浏览代码

Improve tests

/develop/lstm-burnin
Ervin Teng 4 年前
当前提交
b3499848
共有 2 个文件被更改,包括 5 次插入和 5 次删除
  1. 2
      ml-agents/mlagents/trainers/tests/torch/test_poca.py
  2. 8
      ml-agents/mlagents/trainers/tests/torch/test_ppo.py

2
ml-agents/mlagents/trainers/tests/torch/test_poca.py


optimizer = create_test_poca_optimizer(
dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual
)
- time_horizon = 15
+ time_horizon = 30
trajectory = make_fake_trajectory(
length=time_horizon,
observation_specs=optimizer.policy.behavior_spec.observation_specs,

8
ml-agents/mlagents/trainers/tests/torch/test_ppo.py


trainer_settings = attr.evolve(dummy_config)
trainer_settings.network_settings.memory = (
- NetworkSettings.MemorySettings(sequence_length=8, memory_size=10)
+ NetworkSettings.MemorySettings(sequence_length=16, memory_size=10)
if use_rnn
else None
)

)
# Time horizon is longer than sequence length, make sure to test
# process trajectory on multiple sequences in trajectory + some padding
- time_horizon = 15
+ time_horizon = 30
trajectory = make_fake_trajectory(
length=time_horizon,
observation_specs=optimizer.policy.behavior_spec.observation_specs,

for key, val in run_out.items():
assert type(key) is str
- assert len(val) == 15
+ assert len(val) == time_horizon
- assert len(all_memories) == 15
+ assert len(all_memories) == time_horizon
run_out, final_value_out, _ = optimizer.get_trajectory_value_estimates(
trajectory.to_agentbuffer(), trajectory.next_obs, done=True

正在加载...
取消
保存