浏览代码

Fix multiple obs

/develop/unified-obs
Ervin Teng 4 年前
当前提交
5a5bd515
共有 2 个文件被更改,包括 3 次插入5 次删除
  1. 4
      ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
  2. 4
      ml-agents/mlagents/trainers/torch/networks.py

4
ml-agents/mlagents/trainers/optimizer/torch_optimizer.py


obs = ModelUtils.list_to_tensor_list(
AgentBuffer.obs_list_to_obs_batch(batch["obs"])
)
next_obs = ModelUtils.list_to_tensor_list(
AgentBuffer.obs_list_to_obs_batch(batch["next_obs"])
)
next_obs = ModelUtils.list_to_tensor_list(next_obs)
memory = torch.zeros([1, 1, self.policy.m_size])

4
ml-agents/mlagents/trainers/torch/networks.py


def forward(
self,
inputs: List[torch.Tensor],
net_inputs: List[torch.Tensor],
actions: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,

net_input = inputs[idx]
net_input = net_inputs[idx]
if not exporting_to_onnx.is_exporting() and len(net_input.shape) > 3:
net_input = net_input.permute([0, 3, 1, 2])
processed_vec = processor(net_input)

正在加载...
取消
保存