浏览代码

Fix groupmate obs, add tests

/develop/lstm-burnin
Ervin Teng 4 年前
当前提交
c05ec9af
共有 3 个文件被更改,包括 9 次插入和 7 次删除
  1. 10
      ml-agents/mlagents/trainers/poca/optimizer_torch.py
  2. 2
      ml-agents/mlagents/trainers/tests/torch/test_poca.py
  3. 4
      ml-agents/mlagents/trainers/tests/torch/test_ppo.py

10
ml-agents/mlagents/trainers/poca/optimizer_torch.py


seq_obs.append(_self_obs[start:end])
self_seq_obs.append(seq_obs)
for groupmate_obs, team_action in zip(obs, actions):
for groupmate_obs, groupmate_action in zip(obs, actions):
for (_obs,) in groupmate_obs:
first_seq_obs = _obs[start:end]
seq_obs.append(first_seq_obs)
for _obs in groupmate_obs:
sliced_seq_obs = _obs[start:end]
seq_obs.append(sliced_seq_obs)
_act = team_action.slice(start, end)
_act = groupmate_action.slice(start, end)
groupmate_seq_act.append(_act)
all_seq_obs = self_seq_obs + groupmate_seq_obs

2
ml-agents/mlagents/trainers/tests/torch/test_poca.py


}
trainer_settings.network_settings.memory = (
NetworkSettings.MemorySettings(sequence_length=16, memory_size=10)
NetworkSettings.MemorySettings(sequence_length=8, memory_size=10)
if use_rnn
else None
)

4
ml-agents/mlagents/trainers/tests/torch/test_ppo.py


trainer_settings = attr.evolve(dummy_config)
trainer_settings.network_settings.memory = (
NetworkSettings.MemorySettings(sequence_length=16, memory_size=10)
NetworkSettings.MemorySettings(sequence_length=8, memory_size=10)
if use_rnn
else None
)

optimizer = create_test_ppo_optimizer(
dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual
)
# The time horizon is longer than the sequence length, so this exercises
# trajectory processing across multiple sequences plus some padding.
time_horizon = 15
trajectory = make_fake_trajectory(
length=time_horizon,

正在加载...
取消
保存