浏览代码

Fix SAC and make utility method

/develop/permutepytorch
Ervin Teng 4 年前
当前提交
77c810fb
共有 5 个文件被更改,包括 10 次插入6 次删除
  1. 4
      ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
  2. 2
      ml-agents/mlagents/trainers/policy/torch_policy.py
  3. 2
      ml-agents/mlagents/trainers/ppo/optimizer_torch.py
  4. 4
      ml-agents/mlagents/trainers/sac/optimizer_torch.py
  5. 4
      ml-agents/mlagents/trainers/torch/utils.py

4
ml-agents/mlagents/trainers/optimizer/torch_optimizer.py


):
visual_ob = ModelUtils.list_to_tensor(batch["visual_obs%d" % idx])
# Make sure to permute visual obs, as PyTorch uses NCHW
visual_obs.append(visual_ob.permute([0, 3, 1, 2]))
visual_obs.append(ModelUtils.nhwc_to_nchw(visual_ob))
else:
visual_obs = []

ModelUtils.list_to_tensor(vec_vis_obs.vector_observations).unsqueeze(0)
]
next_vis_obs = [
ModelUtils.list_to_tensor(_vis_ob).unsqueeze(0).permute([0, 3, 1, 2])
ModelUtils.nhwc_to_nchw(ModelUtils.list_to_tensor(_vis_ob).unsqueeze(0))
for _vis_ob in vec_vis_obs.visual_observations
]

2
ml-agents/mlagents/trainers/policy/torch_policy.py


vec_obs = [torch.as_tensor(vec_vis_obs.vector_observations)]
# Make sure to permute visual obs, as PyTorch uses NCHW
vis_obs = [
torch.as_tensor(vis_ob).permute([0, 3, 1, 2])
ModelUtils.nhwc_to_nchw(torch.as_tensor(vis_ob))
for vis_ob in vec_vis_obs.visual_observations
]
memories = torch.as_tensor(self.retrieve_memories(global_agent_ids)).unsqueeze(

2
ml-agents/mlagents/trainers/ppo/optimizer_torch.py


):
vis_ob = ModelUtils.list_to_tensor(batch["visual_obs%d" % idx])
# Make sure to permute visual obs, as PyTorch uses NCHW
vis_obs.append(vis_ob.permute([0, 3, 1, 2]))
vis_obs.append(ModelUtils.nhwc_to_nchw(vis_ob))
else:
vis_obs = []
log_probs, entropy, values = self.policy.evaluate_actions(

4
ml-agents/mlagents/trainers/sac/optimizer_torch.py


):
vis_ob = ModelUtils.list_to_tensor(batch["visual_obs%d" % idx])
# Make sure to permute visual obs, as PyTorch uses NCHW
vis_obs.append(vis_ob.permute([0, 3, 1, 2]))
vis_obs.append(ModelUtils.nhwc_to_nchw(vis_ob))
next_vis_obs.append(next_vis_ob)
next_vis_obs.append(ModelUtils.nhwc_to_nchw(next_vis_ob))
# Copy normalizers from policy
self.value_network.q1_network.network_body.copy_normalization(

4
ml-agents/mlagents/trainers/torch/utils.py


return (tensor.T * masks).sum() / torch.clamp(
(torch.ones_like(tensor.T) * masks).float().sum(), min=1.0
)
@staticmethod
def nhwc_to_nchw(tensor: torch.Tensor) -> torch.Tensor:
return tensor.permute([0, 3, 1, 2])
正在加载...
取消
保存