Compare commits

...
This merge request has changes that conflict with the target branch.
/ml-agents/mlagents/torch_utils/torch.py
/ml-agents/mlagents/trainers/learn.py
/ml-agents/mlagents/trainers/policy/torch_policy.py
/ml-agents/mlagents/trainers/sac/optimizer_torch.py
/ml-agents/mlagents/trainers/torch/utils.py
/config/sac/CrawlerStatic.yaml

6 commits

Author       SHA1       Message                                                  Commit date
Ervin Teng   3a7cd3ad   Merge experiments                                        4 years ago
Ervin Teng   228ea059   Try futures in Optimizer                                 4 years ago
Ervin Teng   f59f35ea   Remove stuff in policy                                   4 years ago
Ervin Teng   a305a41b   Try futures in Optimizer                                 4 years ago
Ervin Teng   fdc887a1   Some experimental stuff                                  4 years ago
Ervin Teng   60eacc0d   Merge branch 'master' into develop-adjust-cpu-settings   4 years ago
6 files changed: 130 insertions(+), 31 deletions(-)
  1. config/sac/CrawlerStatic.yaml (4 changes)
  2. ml-agents/mlagents/torch_utils/torch.py (11 changes)
  3. ml-agents/mlagents/trainers/sac/optimizer_torch.py (106 changes)
  4. ml-agents/mlagents/trainers/policy/torch_policy.py (11 changes)
  5. ml-agents/mlagents/trainers/torch/utils.py (19 changes)
  6. ml-agents/mlagents/trainers/learn.py (10 changes)

config/sac/CrawlerStatic.yaml (4 changes)


  gamma: 0.995
  strength: 1.0
  keep_checkpoints: 5
- max_steps: 3000000
+ max_steps: 100000
- threaded: true
+ threaded: false

ml-agents/mlagents/torch_utils/torch.py (11 changes)


+import os
 # Detect availability of torch package here.
 # NOTE: this try/except is temporary until torch is required for ML-Agents.
 try:

     torch.set_num_threads(1)
+    if "TORCH_NUM_THREADS" in os.environ:
+        torch.set_num_threads(int(os.environ.get("TORCH_NUM_THREADS")))
+    if "TORCH_NUM_INTEROP" in os.environ:
+        torch.set_num_interop_threads(int(os.environ.get("TORCH_NUM_INTEROP")))
+    # torch.set_num_interop_threads(4)
+    # os.environ["KMP_AFFINITY"] = "granularity=fine,compact,1,0"
+    # os.environ["KMP_BLOCKTIME"] = "1"
     # Known PyLint compatibility with PyTorch https://github.com/pytorch/pytorch/issues/701
     # pylint: disable=E1101
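
The added block lets thread counts be tuned from the environment when profiling CPU usage. A minimal standalone sketch of the same override pattern (independent of the ML-Agents module; the environment variable names follow the hunk above):

    import os
    import torch

    # Default to a single intra-op thread, as the hunk above does.
    torch.set_num_threads(1)

    # Allow experiments to override thread counts without code changes.
    if "TORCH_NUM_THREADS" in os.environ:
        torch.set_num_threads(int(os.environ["TORCH_NUM_THREADS"]))
    if "TORCH_NUM_INTEROP" in os.environ:
        # Must run before any inter-op parallel work (e.g. torch.jit._fork) starts.
        torch.set_num_interop_threads(int(os.environ["TORCH_NUM_INTEROP"]))

    print(torch.get_num_threads(), torch.get_num_interop_threads())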

ml-agents/mlagents/trainers/sac/optimizer_torch.py (106 changes)


import numpy as np
from typing import Dict, List, Mapping, cast, Tuple, Optional
from mlagents.torch_utils import torch, nn, default_device
import time
from mlagents_envs.logging_util import get_logger
from mlagents_envs.base_env import ActionType
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer

         memories: Optional[torch.Tensor] = None,
         sequence_length: int = 1,
     ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
-        q1_out, _ = self.q1_network(
+        q1_fut = torch.jit._fork(
+            self.q1_network,
             vec_inputs,
             vis_inputs,
             actions=actions,

-        q2_out, _ = self.q2_network(
+        q2_fut = torch.jit._fork(
+            self.q2_network,
             vec_inputs,
             vis_inputs,
             actions=actions,

+        q1_out, _ = torch.jit._wait(q1_fut)
+        q2_out, _ = torch.jit._wait(q2_fut)
         return q1_out, q2_out
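
The rewritten forward pass launches both Q-network evaluations as TorchScript futures so they can overlap on inter-op threads. A small self-contained sketch of that fork/wait pattern (torch.jit._fork/_wait are internal APIs; the modules below are stand-ins, not the ML-Agents networks):

    import torch
    from torch import nn

    # Two toy networks standing in for q1_network / q2_network.
    q1 = nn.Linear(8, 1)
    q2 = nn.Linear(8, 1)
    x = torch.randn(4, 8)

    # Launch both forward passes as asynchronous TorchScript tasks...
    fut1 = torch.jit._fork(q1, x)
    fut2 = torch.jit._fork(q2, x)
    # ...then block until each result is ready. The two calls only overlap
    # if inter-op threads are available (see torch.set_num_interop_threads).
    out1 = torch.jit._wait(fut1)
    out2 = torch.jit._wait(fut2)
    print(out1.shape, out2.shape)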
     def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings):

             policy_network_settings,
             self.policy.behavior_spec.action_type,
             self.act_size,
         )
+        (
+            dummy_vec_obs,
+            dummy_vis_obs,
+            dummy_masks,
+            dummy_memories,
+        ) = ModelUtils.create_dummy_input(self.policy)
+        example_inputs = (
+            dummy_vec_obs,
+            dummy_vis_obs,
+            torch.zeros((1, sum(self.act_size)))
+            if self.policy.use_continuous_act
+            else None,
+            dummy_memories,
+            torch.tensor(self.policy.sequence_length),
+        )
+        self.value_network = torch.jit.trace(
+            self.value_network, example_inputs, strict=False
+        )
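
The block above traces the value network once with representative dummy inputs so the traced module can later be called (and forked) as a TorchScript graph; strict=False is what allows the dict-valued outputs seen in the forward signature earlier. A toy sketch of torch.jit.trace with strict=False (TinyQNet is illustrative, not the ML-Agents ValueNetwork):

    import torch
    from torch import nn

    class TinyQNet(nn.Module):
        def forward(self, obs: torch.Tensor):
            # Dict outputs require strict=False when tracing.
            return {"extrinsic": obs.sum(dim=1, keepdim=True)}

    net = TinyQNet()
    example_inputs = (torch.zeros(1, 8),)
    traced = torch.jit.trace(net, example_inputs, strict=False)

    # The traced module replays the recorded graph; value-dependent control
    # flow (e.g. branches on None) is frozen at trace time.
    out = traced(torch.randn(4, 8))
    print(out["extrinsic"].shape)  # torch.Size([4, 1])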
         self.target_network = ValueNetwork(

         indexed by name. If none, don't update the reward signals.
         :return: Output from update process.
         """
+        # t0 = time.time()
         rewards = {}
         for name in self.reward_signals:
             rewards[name] = ModelUtils.list_to_tensor(batch[f"{name}_rewards"])

         next_memories = None
         # Q network memories are 0'ed out, since we don't have them during inference.
         q_memories = (
-            torch.zeros_like(next_memories) if next_memories is not None else None
+            torch.zeros_like(next_memories)
+            if next_memories is not None
+            else torch.empty((1, 1, 0))
         )
         vis_obs: List[torch.Tensor] = []

                 next_vis_obs.append(next_vis_ob)
         # Copy normalizers from policy
-        self.value_network.q1_network.network_body.copy_normalization(
-            self.policy.actor_critic.network_body
-        )
-        self.value_network.q2_network.network_body.copy_normalization(
-            self.policy.actor_critic.network_body
-        )
+        # self.value_network.q1_network.network_body.copy_normalization(
+        #     self.policy.actor_critic.network_body
+        # )
+        # self.value_network.q2_network.network_body.copy_normalization(
+        #     self.policy.actor_critic.network_body
+        # )
+        # t1 = time.time()
         (
             sampled_actions,
             log_probs,

             seq_len=self.policy.sequence_length,
             all_log_probs=not self.policy.use_continuous_act,
         )
+        # t2 = time.time()
-            q1p_out, q2p_out = self.value_network(
+            # q1p_out, q2p_out = self.value_network(
+            #     vec_obs,
+            #     torch.tensor(vis_obs),
+            #     sampled_actions,
+            #     q_memories,
+            #     torch.tensor(self.policy.sequence_length),
+            # )
+            qp_fut = torch.jit._fork(
+                self.value_network,
-                vis_obs,
+                torch.tensor(vis_obs),
-                memories=q_memories,
-                sequence_length=self.policy.sequence_length,
+                q_memories,
+                torch.tensor(self.policy.sequence_length),
-            q1_out, q2_out = self.value_network(
+            q_fut = torch.jit._fork(
+                self.value_network,
-                vis_obs,
+                torch.tensor(vis_obs),
-                memories=q_memories,
-                sequence_length=self.policy.sequence_length,
+                q_memories,
+                torch.tensor(self.policy.sequence_length),
+            # q1_out, q2_out = self.value_network(
+            #     vec_obs,
+            #     torch.tensor(vis_obs),
+            #     squeezed_actions,
+            #     q_memories,
+            #     torch.tensor(self.policy.sequence_length),
+            # )
+            q1p_out, q2p_out = torch.jit._wait(qp_fut)
+            q1_out, q2_out = torch.jit._wait(q_fut)
             q1_stream, q2_stream = q1_out, q2_out
         else:
             with torch.no_grad():

                 )
             q1_stream = self._condense_q_streams(q1_out, actions)
             q2_stream = self._condense_q_streams(q2_out, actions)
+        # t3 = time.time()
         with torch.no_grad():
             target_values, _ = self.target_network(
                 next_vec_obs,

         masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
         use_discrete = not self.policy.use_continuous_act
         dones = ModelUtils.list_to_tensor(batch["done"])
+        # t4 = time.time()
         q1_loss, q2_loss = self.sac_q_loss(
             q1_stream, q2_stream, target_values, dones, rewards, masks
         )

         entropy_loss = self.sac_entropy_loss(log_probs, masks, use_discrete)
         total_value_loss = q1_loss + q2_loss + value_loss
+        # t5 = time.time()
+        # t6 = time.time()
+        # t7 = time.time()
+        # t8 = time.time()
         # Update target network
         self.soft_update(self.policy.actor_critic.critic, self.target_network, self.tau)
         update_stats = {

             "Policy/Entropy Coeff": torch.mean(torch.exp(self._log_ent_coef)).item(),
             "Policy/Learning Rate": decay_lr,
         }
+        # t9 = time.time()
+        # print(
+        #     t9 - t8,
+        #     t8 - t7,
+        #     t7 - t6,
+        #     t6 - t5,
+        #     t5 - t4,
+        #     t4 - t3,
+        #     t3 - t2,
+        #     t2 - t1,
+        #     t1 - t0,
+        # )
         return update_stats

     def update_reward_signals(
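
The self.soft_update(...) call kept above is the standard Polyak averaging step that keeps the target network trailing the critic. A generic sketch of that operation under the usual definition (the helper below is illustrative, not the ML-Agents implementation):

    import torch
    from torch import nn

    def soft_update(source: nn.Module, target: nn.Module, tau: float) -> None:
        # target <- tau * source + (1 - tau) * target, parameter by parameter.
        with torch.no_grad():
            for src, tgt in zip(source.parameters(), target.parameters()):
                tgt.data.mul_(1.0 - tau)
                tgt.data.add_(tau * src.data)

    # Example: keep a slowly moving copy of a small critic.
    critic = nn.Linear(8, 1)
    target = nn.Linear(8, 1)
    target.load_state_dict(critic.state_dict())  # start in sync
    soft_update(critic, target, tau=0.005)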

ml-agents/mlagents/trainers/policy/torch_policy.py (11 changes)


         self.m_size = self.actor_critic.memory_size
         self.actor_critic.to(default_device())
+        # dummy_vec, dummy_vis, dummy_masks, dummy_mem = ModelUtils.create_dummy_input(
+        #     self
+        # )
+        # dist_val_dummy = (dummy_vec, dummy_vis, dummy_masks, dummy_mem)
+        # critic_pass_dummy = (dummy_vec, dummy_vis, dummy_mem)
+        # # example_in = {"get_dist_and_value": dist_val_dummy, "critic_pass": critic_pass_dummy}
+        # self.sample_actions = torch.jit.trace(
+        #     self.sample_actions,
+        #     dist_val_dummy,
+        # )

     @property
     def export_memory_size(self) -> int:

ml-agents/mlagents/trainers/torch/utils.py (19 changes)


         return (tensor.T * masks).sum() / torch.clamp(
             (torch.ones_like(tensor.T) * masks).float().sum(), min=1.0
         )

+    @staticmethod
+    def create_dummy_input(policy):
+        batch_dim = [1]
+        seq_len_dim = [1]
+        dummy_vec_obs = [torch.zeros(batch_dim + [policy.vec_obs_size])]
+        # create input shape of NCHW
+        # (It's NHWC in self.policy.behavior_spec.observation_shapes)
+        dummy_vis_obs = [
+            torch.zeros(batch_dim + [shape[2], shape[0], shape[1]])
+            for shape in policy.behavior_spec.observation_shapes
+            if len(shape) == 3
+        ]
+        dummy_masks = torch.ones(batch_dim + [sum(policy.actor_critic.act_size)])
+        dummy_memories = torch.zeros(
+            batch_dim + seq_len_dim + [policy.export_memory_size]
+        )
+        return dummy_vec_obs, torch.Tensor(dummy_vis_obs), dummy_masks, dummy_memories
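
create_dummy_input builds placeholder observations so a network can be traced without touching a real environment; note the NHWC-to-NCHW reordering of the visual shapes. A tiny sketch of that shape handling in isolation (the observation shapes below are hypothetical, not tied to the ML-Agents classes):

    import torch

    # Shapes as reported by the environment: (H, W, C) for visual observations,
    # a flat length for vector observations.
    observation_shapes = [(84, 84, 3), (56,)]
    batch_dim = [1]

    dummy_vec_obs = [
        torch.zeros(batch_dim + [shape[0]])
        for shape in observation_shapes
        if len(shape) == 1
    ]
    # PyTorch convolutions expect NCHW, so reorder (H, W, C) -> (C, H, W).
    dummy_vis_obs = [
        torch.zeros(batch_dim + [shape[2], shape[0], shape[1]])
        for shape in observation_shapes
        if len(shape) == 3
    ]
    print(dummy_vec_obs[0].shape)  # torch.Size([1, 56])
    print(dummy_vis_obs[0].shape)  # torch.Size([1, 3, 84, 84])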

ml-agents/mlagents/trainers/learn.py (10 changes)


 import os
 import numpy as np
 import json
+import cProfile
+from mlagents_envs.registry import default_registry
 from typing import Callable, Optional, List

     ) -> UnityEnvironment:
         # Make sure that each environment gets a different seed
         env_seed = seed + worker_id
-        return UnityEnvironment(
-            file_name=env_path,
+        return default_registry["CrawlerStaticTarget"].make(
             worker_id=worker_id,
             seed=env_seed,
             no_graphics=no_graphics,

 def main():
-    run_cli(parse_command_line())
+    with cProfile.Profile() as pr:
+        run_cli(parse_command_line())
+    pr.dump_stats("pytorch_sac.prof")

 # For python debugger to directly run this script
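
Wrapping the training entry point in cProfile, as above, dumps a profile that can be inspected offline. A minimal sketch of the same pattern plus reading the dump with pstats (the workload and file name here are placeholders, not the ML-Agents entry point):

    import cProfile
    import pstats

    def workload() -> None:
        # Placeholder for the code being profiled (run_cli(...) in the diff above).
        sum(i * i for i in range(100_000))

    with cProfile.Profile() as pr:  # context-manager form requires Python 3.8+
        workload()
    pr.dump_stats("example.prof")

    # Inspect the dump: top 10 entries by cumulative time.
    stats = pstats.Stats("example.prof")
    stats.sort_stats("cumulative").print_stats(10)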