浏览代码

Fixed test code by introducing a brain_name variable instead of hardcoding the brain name string

/develop-newnormalization
Andrew Cohen 5 年前
当前提交
b11f04ea
共有 2 个文件被更改,包括 27 次插入 和 22 次删除
  1. 18
      ml-agents/mlagents/trainers/tests/test_bc.py
  2. 31
      ml-agents/mlagents/trainers/tests/test_trainer_controller.py

18
ml-agents/mlagents/trainers/tests/test_bc.py


trainer, env = create_bc_trainer(dummy_config)
# Test add_experiences
returned_braininfo = env.step()
brain_name = "Ball3DBrain"
returned_braininfo["Ball3DBrain"], returned_braininfo["Ball3DBrain"], {}
returned_braininfo[brain_name], returned_braininfo[brain_name], {}
for agent_id in returned_braininfo["Ball3DBrain"].agents:
for agent_id in returned_braininfo[brain_name].agents:
returned_braininfo["Ball3DBrain"].local_done = 12 * [True]
returned_braininfo[brain_name].local_done = 12 * [True]
returned_braininfo["Ball3DBrain"], returned_braininfo["Ball3DBrain"]
returned_braininfo[brain_name], returned_braininfo[brain_name]
for agent_id in returned_braininfo["Ball3DBrain"].agents:
for agent_id in returned_braininfo[brain_name].agents:
assert trainer.episode_steps[agent_id] == 0
assert trainer.cumulative_rewards[agent_id] == 0

returned_braininfo = env.step()
brain_name = "Ball3DBrain"
returned_braininfo["Ball3DBrain"], returned_braininfo["Ball3DBrain"], {}
returned_braininfo[brain_name], returned_braininfo[brain_name], {}
returned_braininfo["Ball3DBrain"], returned_braininfo["Ball3DBrain"]
returned_braininfo[brain_name], returned_braininfo[brain_name]
for agent_id in returned_braininfo["Ball3DBrain"].agents:
for agent_id in returned_braininfo[brain_name].agents:
assert trainer.episode_steps[agent_id] == 0
assert trainer.cumulative_rewards[agent_id] == 0

31
ml-agents/mlagents/trainers/tests/test_trainer_controller.py


def test_take_step_adds_experiences_to_trainer_and_trains():
tc, trainer_mock = trainer_controller_with_take_step_mocks()
action_info_dict = {"testbrain": MagicMock()}
brain_info_dict = {"testbrain": Mock()}
brain_name = "testbrain"
action_info_dict = {brain_name: MagicMock()}
brain_info_dict = {brain_name: Mock()}
old_step_info = EnvironmentStep(brain_info_dict, brain_info_dict, action_info_dict)
new_step_info = EnvironmentStep(brain_info_dict, brain_info_dict, action_info_dict)
trainer_mock.is_ready_update = MagicMock(return_value=True)

env_mock.reset.assert_not_called()
env_mock.step.assert_called_once()
trainer_mock.add_experiences.assert_called_once_with(
new_step_info.previous_all_brain_info["testbrain"],
new_step_info.current_all_brain_info["testbrain"],
new_step_info.brain_name_to_action_info["testbrain"].outputs,
new_step_info.previous_all_brain_info[brain_name],
new_step_info.current_all_brain_info[brain_name],
new_step_info.brain_name_to_action_info[brain_name].outputs,
new_step_info.previous_all_brain_info["testbrain"],
new_step_info.current_all_brain_info["testbrain"],
new_step_info.previous_all_brain_info[brain_name],
new_step_info.current_all_brain_info[brain_name],
)
trainer_mock.update_policy.assert_called_once()
trainer_mock.increment_step.assert_called_once()

tc, trainer_mock = trainer_controller_with_take_step_mocks()
tc.train_model = False
action_info_dict = {"testbrain": MagicMock()}
brain_name = "testbrain"
action_info_dict = {brain_name: MagicMock()}
brain_info_dict = {"testbrain": Mock()}
brain_info_dict = {brain_name: Mock()}
old_step_info = EnvironmentStep(brain_info_dict, brain_info_dict, action_info_dict)
new_step_info = EnvironmentStep(brain_info_dict, brain_info_dict, action_info_dict)

env_mock.reset.assert_not_called()
env_mock.step.assert_called_once()
trainer_mock.add_experiences.assert_called_once_with(
new_step_info.previous_all_brain_info["testbrain"],
new_step_info.current_all_brain_info["testbrain"],
new_step_info.brain_name_to_action_info["testbrain"].outputs,
new_step_info.previous_all_brain_info[brain_name],
new_step_info.current_all_brain_info[brain_name],
new_step_info.brain_name_to_action_info[brain_name].outputs,
new_step_info.previous_all_brain_info["testbrain"],
new_step_info.current_all_brain_info["testbrain"],
new_step_info.previous_all_brain_info[brain_name],
new_step_info.current_all_brain_info[brain_name],
)
trainer_mock.clear_update_buffer.assert_called_once()
正在加载...
取消
保存