浏览代码

Fix bug with construct_curr_info (#2490)

* Fix bug with construct_curr_info
* Add more tests
/develop-gpu-test
GitHub 5 年前
当前提交
4bb97e25
共有 2 个文件被更改，包括 8 次插入和 4 次删除
  1. 2
      ml-agents/mlagents/trainers/rl_trainer.py
  2. 10
      ml-agents/mlagents/trainers/tests/test_rl_trainer.py

2
ml-agents/mlagents/trainers/rl_trainer.py


:return: curr_info: Reconstructed BrainInfo to match agents of next_info.
"""
visual_observations: List[List[Any]] = [
[]
[] for _ in next_info.visual_observations
] # TODO add types to brain.py methods
vector_observations = []
text_observations = []

10
ml-agents/mlagents/trainers/tests/test_rl_trainer.py


@mock.patch("mlagents.trainers.rl_trainer.RLTrainer.add_policy_outputs")
@mock.patch("mlagents.trainers.rl_trainer.RLTrainer.add_rewards_outputs")
def test_rl_trainer(add_policy_outputs, add_rewards_outputs):
@pytest.mark.parametrize("num_vis_obs", [0, 1, 2], ids=["vec", "1 viz", "2 viz"])
def test_rl_trainer(add_policy_outputs, add_rewards_outputs, num_vis_obs):
trainer = create_rl_trainer()
trainer.policy = create_mock_policy()
fake_action_outputs = {

num_agents=2,
num_vector_observations=8,
num_vector_acts=2,
num_vis_observations=1,
num_vis_observations=num_vis_obs,
)
trainer.add_experiences(
create_mock_all_brain_info(mock_braininfo),

num_agents=1,
num_vector_observations=8,
num_vector_acts=2,
num_vis_observations=1,
num_vis_observations=num_vis_obs,
assert len(brain_info.visual_observations) == num_vis_obs
assert len(brain_info.vector_observations) == 1
assert len(brain_info.previous_vector_actions) == 1
# Test end episode
trainer.end_episode()

正在加载...
取消
保存