浏览代码

Fix SeparateActorCritic and add test

/develop/add-fire/memoryclass
Ervin Teng 4 年前
当前提交
cded4c6c
共有 2 个文件被更改,包括 6 次插入2 次删除
  1. 6
      ml-agents/mlagents/trainers/tests/torch/test_networks.py
  2. 2
      ml-agents/mlagents/trainers/torch/networks.py

6
ml-agents/mlagents/trainers/tests/torch/test_networks.py


assert value_out[stream].shape == (1,)
# Test get_dist_and_value
dists, value_out, _ = actor.get_dist_and_value([sample_obs], [], memories=memories)
dists, value_out, mem_out = actor.get_dist_and_value(
[sample_obs], [], memories=memories
)
if mem_out is not None:
assert mem_out.shape == memories.shape
for dist in dists:
assert isinstance(dist, GaussianDistInstance)
for stream in stream_names:

2
ml-agents/mlagents/trainers/torch/networks.py


vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
)
if self.use_lstm:
mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=1)
mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1)
else:
mem_out = None
return dists, value_outputs, mem_out

正在加载...
取消
保存