浏览代码

Permute visual obs outside of network

/develop/permutepytorch
Ervin Teng 4 年前
当前提交
3e771cbb
共有 5 个文件被更改,包括 10 次插入7 次删除
  1. 5
      ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
  2. 4
      ml-agents/mlagents/trainers/policy/torch_policy.py
  3. 3
      ml-agents/mlagents/trainers/ppo/optimizer_torch.py
  4. 3
      ml-agents/mlagents/trainers/sac/optimizer_torch.py
  5. 2
      ml-agents/mlagents/trainers/torch/networks.py

5
ml-agents/mlagents/trainers/optimizer/torch_optimizer.py


self.policy.actor_critic.network_body.visual_processors
):
visual_ob = ModelUtils.list_to_tensor(batch["visual_obs%d" % idx])
visual_obs.append(visual_ob)
# Make sure to permute visual obs, as PyTorch uses NCHW
visual_obs.append(visual_ob.permute([0, 3, 1, 2]))
else:
visual_obs = []

ModelUtils.list_to_tensor(vec_vis_obs.vector_observations).unsqueeze(0)
]
next_vis_obs = [
ModelUtils.list_to_tensor(_vis_ob).unsqueeze(0)
ModelUtils.list_to_tensor(_vis_ob).unsqueeze(0).permute([0, 3, 1, 2])
for _vis_ob in vec_vis_obs.visual_observations
]

4
ml-agents/mlagents/trainers/policy/torch_policy.py


"""
vec_vis_obs, masks = self._split_decision_step(decision_requests)
vec_obs = [torch.as_tensor(vec_vis_obs.vector_observations)]
# Make sure to permute visual obs, as PyTorch uses NCHW
torch.as_tensor(vis_ob) for vis_ob in vec_vis_obs.visual_observations
torch.as_tensor(vis_ob).permute([0, 3, 1, 2])
for vis_ob in vec_vis_obs.visual_observations
]
memories = torch.as_tensor(self.retrieve_memories(global_agent_ids)).unsqueeze(
0

3
ml-agents/mlagents/trainers/ppo/optimizer_torch.py


self.policy.actor_critic.network_body.visual_processors
):
vis_ob = ModelUtils.list_to_tensor(batch["visual_obs%d" % idx])
vis_obs.append(vis_ob)
# Make sure to permute visual obs, as PyTorch uses NCHW
vis_obs.append(vis_ob.permute([0, 3, 1, 2]))
else:
vis_obs = []
log_probs, entropy, values = self.policy.evaluate_actions(

3
ml-agents/mlagents/trainers/sac/optimizer_torch.py


self.policy.actor_critic.network_body.visual_processors
):
vis_ob = ModelUtils.list_to_tensor(batch["visual_obs%d" % idx])
vis_obs.append(vis_ob)
# Make sure to permute visual obs, as PyTorch uses NCHW
vis_obs.append(vis_ob.permute([0, 3, 1, 2]))
next_vis_ob = ModelUtils.list_to_tensor(
batch["next_visual_obs%d" % idx]
)

2
ml-agents/mlagents/trainers/torch/networks.py


for idx, processor in enumerate(self.visual_processors):
vis_input = vis_inputs[idx]
if not torch.onnx.is_in_onnx_export():
vis_input = vis_input.permute([0, 3, 1, 2])
processed_vis = processor(vis_input)
encodes.append(processed_vis)

正在加载...
取消
保存