
Fix PPO tests

/develop/critic-op-lstm-currentmem
Ervin Teng, 4 years ago
Commit 40f51774
1 file changed, 9 insertions(+), 3 deletions(-)
ml-agents/mlagents/trainers/tests/torch/test_ppo.py (+9 −3)


             RewardSignalUtil.value_estimates_key("extrinsic"),
         ],
     )
+    # Copy memories to critic memories
+    copy_buffer_fields(update_buffer, BufferKey.MEMORY, [BufferKey.CRITIC_MEMORY])
     return_stats = optimizer.update(
         update_buffer,

             RewardSignalUtil.value_estimates_key("curiosity"),
         ],
     )
+    # Copy memories to critic memories
+    copy_buffer_fields(update_buffer, BufferKey.MEMORY, [BufferKey.CRITIC_MEMORY])
     optimizer.update(
         update_buffer,
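
Both update hunks mirror the actor's recorded memories into the critic-specific buffer key before calling optimizer.update, so the optimizer's separate critic LSTM has initial states to read during the update. For reference, a minimal sketch of what a helper like copy_buffer_fields does, assuming the buffer acts as a mapping from keys to per-step lists (an illustration, not the mlagents test helper's actual source):

    from typing import Any, Hashable, List


    def copy_buffer_fields(buffer: Any, src_key: Hashable, dst_keys: List[Hashable]) -> None:
        # Duplicate one buffer field under each destination key. Copying the
        # list (rather than assigning a reference) keeps the fields
        # independent if either is mutated later in the test.
        for dst_key in dst_keys:
            buffer[dst_key] = list(buffer[src_key])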

         action_spec=DISCRETE_ACTION_SPEC if discrete else CONTINUOUS_ACTION_SPEC,
         max_step_complete=True,
     )
-    run_out, final_value_out = optimizer.get_trajectory_value_estimates(
+    run_out, final_value_out, all_memories = optimizer.get_trajectory_value_estimates(
         trajectory.to_agentbuffer(), trajectory.next_obs, done=False
     )
+    if all_memories is not None:
+        assert len(all_memories) == 15
 ...
-    run_out, final_value_out = optimizer.get_trajectory_value_estimates(
+    run_out, final_value_out, _ = optimizer.get_trajectory_value_estimates(
         trajectory.to_agentbuffer(), trajectory.next_obs, done=True
     )
     for key, val in final_value_out.items():

     # Check if we ignore terminal states properly
     optimizer.reward_signals["extrinsic"].use_terminal_states = False
-    run_out, final_value_out = optimizer.get_trajectory_value_estimates(
+    run_out, final_value_out, _ = optimizer.get_trajectory_value_estimates(
         trajectory.to_agentbuffer(), trajectory.next_obs, done=False
     )
     for key, val in final_value_out.items():
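
The signature change is the substance of the fix: on this branch, get_trajectory_value_estimates returns a third value carrying the critic's per-step memories (None when the critic is not recurrent), and the 15 in the assertion presumably matches the trajectory length these tests construct. A hedged sketch of how a caller might consume the new tuple; agent_buffer and the storage loop are illustrative assumptions, not code from this commit:

    # run_out and final_value_out keep their old meaning; all_memories is new.
    run_out, final_value_out, all_memories = optimizer.get_trajectory_value_estimates(
        trajectory.to_agentbuffer(), trajectory.next_obs, done=False
    )
    if all_memories is not None:
        # One critic memory per trajectory step. Storing them under
        # BufferKey.CRITIC_MEMORY (an assumption about trainer-side usage)
        # would let the update path resume the critic LSTM from the states
        # seen during collection.
        for mem in all_memories:
            agent_buffer[BufferKey.CRITIC_MEMORY].append(mem)

Call sites that do not need the memories simply discard them, which is why the other updated calls in the test unpack the third value into _.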
