浏览代码

Remove "external_brains" arg for TrainerController (#2213)

TrainerController depended on an external_brains dictionary with
brain params in its constructor but only used it in a single function
call.  The same function call (start_learning) takes the environment
as an argument, which is the source of the external_brains.

This change removes the dependency of TrainerController on external
brains and removes the two class members related to external_brains
and retrieves the brains directly from the environment.
/develop-generalizationTraining-TrainerController
GitHub 5 年前
当前提交
966d8efb
共有 4 个文件被更改,包括 58 次插入53 次删除
  1. 1
      ml-agents/mlagents/trainers/learn.py
  2. 1
      ml-agents/mlagents/trainers/tests/test_learn.py
  3. 87
      ml-agents/mlagents/trainers/tests/test_trainer_controller.py
  4. 22
      ml-agents/mlagents/trainers/trainer_controller.py

1
ml-agents/mlagents/trainers/learn.py


train_model,
keep_checkpoints,
lesson,
env.external_brains,
run_seed,
fast_simulation,
)

1
ml-agents/mlagents/trainers/tests/test_learn.py


False,
5,
0,
subproc_env_mock.return_value.external_brains,
0,
True,
)

87
ml-agents/mlagents/trainers/tests/test_trainer_controller.py


@pytest.fixture
def basic_trainer_controller(brain_info):
def basic_trainer_controller():
return TrainerController(
model_path="test_model_path",
summaries_dir="test_summaries_dir",

train=True,
keep_checkpoints=False,
lesson=None,
external_brains={"testbrain": brain_info},
training_seed=99,
fast_simulation=True,
)

@patch("tensorflow.set_random_seed")
def test_initialization_seed(numpy_random_seed, tensorflow_set_seed):
seed = 27
TrainerController("", "", "1", 1, None, True, False, False, None, {}, seed, True)
TrainerController("", "", "1", 1, None, True, False, False, None, seed, True)
trainer_cls, input_config, tc, expected_brain_info, expected_config
trainer_cls, input_config, tc, expected_brain_params, expected_config
external_brains = {"testbrain": expected_brain_params}
assert brain == expected_brain_info
assert brain == expected_brain_params
assert trainer_params == expected_config
assert training == tc.train_model
assert load == tc.load_model

with patch.object(trainer_cls, "__init__", mock_constructor):
tc.initialize_trainers(input_config)
tc.initialize_trainers(input_config, external_brains)
input_config, tc, expected_brain_info, expected_config, expected_reward_buff_cap=0
input_config, tc, expected_brain_params, expected_config, expected_reward_buff_cap=0
external_brains = {"testbrain": expected_brain_params}
assert brain == expected_brain_info
assert brain == expected_brain_params
assert trainer_parameters == expected_config
assert reward_buff_cap == expected_reward_buff_cap
assert training == tc.train_model

with patch.object(PPOTrainer, "__init__", mock_constructor):
tc.initialize_trainers(input_config)
tc.initialize_trainers(input_config, external_brains)
@patch("mlagents.envs.BrainInfo")
def test_initialize_trainer_parameters_uses_defaults(BrainInfoMock):
brain_info_mock = BrainInfoMock()
tc = basic_trainer_controller(brain_info_mock)
@patch("mlagents.envs.BrainParameters")
def test_initialize_trainer_parameters_uses_defaults(BrainParametersMock):
brain_params_mock = BrainParametersMock()
tc = basic_trainer_controller()
full_config = dummy_offline_bc_config()
expected_config = full_config["default"]

assert_bc_trainer_constructed(
OfflineBCTrainer, full_config, tc, brain_info_mock, expected_config
OfflineBCTrainer, full_config, tc, brain_params_mock, expected_config
@patch("mlagents.envs.BrainInfo")
def test_initialize_trainer_parameters_override_defaults(BrainInfoMock):
brain_info_mock = BrainInfoMock()
tc = basic_trainer_controller(brain_info_mock)
@patch("mlagents.envs.BrainParameters")
def test_initialize_trainer_parameters_override_defaults(BrainParametersMock):
brain_params_mock = BrainParametersMock()
tc = basic_trainer_controller()
full_config = dummy_offline_bc_config_with_override()
expected_config = full_config["default"]

expected_config["normalize"] = False
assert_bc_trainer_constructed(
OfflineBCTrainer, full_config, tc, brain_info_mock, expected_config
OfflineBCTrainer, full_config, tc, brain_params_mock, expected_config
@patch("mlagents.envs.BrainInfo")
def test_initialize_online_bc_trainer(BrainInfoMock):
brain_info_mock = BrainInfoMock()
tc = basic_trainer_controller(brain_info_mock)
@patch("mlagents.envs.BrainParameters")
def test_initialize_online_bc_trainer(BrainParametersMock):
brain_params_mock = BrainParametersMock()
tc = basic_trainer_controller()
full_config = dummy_online_bc_config()
expected_config = full_config["default"]

assert_bc_trainer_constructed(
OnlineBCTrainer, full_config, tc, brain_info_mock, expected_config
OnlineBCTrainer, full_config, tc, brain_params_mock, expected_config
@patch("mlagents.envs.BrainInfo")
def test_initialize_ppo_trainer(BrainInfoMock):
brain_info_mock = BrainInfoMock()
tc = basic_trainer_controller(brain_info_mock)
@patch("mlagents.envs.BrainParameters")
def test_initialize_ppo_trainer(BrainParametersMock):
brain_params_mock = BrainParametersMock()
tc = basic_trainer_controller()
full_config = dummy_config()
expected_config = full_config["default"]

assert_ppo_trainer_constructed(full_config, tc, brain_info_mock, expected_config)
assert_ppo_trainer_constructed(full_config, tc, brain_params_mock, expected_config)
@patch("mlagents.envs.BrainInfo")
def test_initialize_invalid_trainer_raises_exception(BrainInfoMock):
brain_info_mock = BrainInfoMock()
tc = basic_trainer_controller(brain_info_mock)
@patch("mlagents.envs.BrainParameters")
def test_initialize_invalid_trainer_raises_exception(BrainParametersMock):
tc = basic_trainer_controller()
external_brains = {"testbrain": BrainParametersMock()}
tc.initialize_trainers(bad_config)
tc.initialize_trainers(bad_config, external_brains)
def trainer_controller_with_start_learning_mocks():

trainer_mock.parameters = {"some": "parameter"}
trainer_mock.write_tensorboard_text = MagicMock()
brain_info_mock = MagicMock()
tc = basic_trainer_controller(brain_info_mock)
tc = basic_trainer_controller()
tc.initialize_trainers = MagicMock()
tc.trainers = {"testbrain": trainer_mock}
tc.advance = MagicMock()

env_mock = MagicMock()
env_mock.close = MagicMock()
env_mock.reset = MagicMock()
env_mock.external_brains = MagicMock()
tc.initialize_trainers.assert_called_once_with(trainer_config)
tc.initialize_trainers.assert_called_once_with(
trainer_config, env_mock.external_brains
)
env_mock.reset.assert_called_once()
assert tc.advance.call_count == 11
tc._export_graph.assert_not_called()

env_mock = MagicMock()
env_mock.close = MagicMock()
env_mock.reset = MagicMock(return_value=brain_info_mock)
env_mock.external_brains = MagicMock()
tc.initialize_trainers.assert_called_once_with(trainer_config)
tc.initialize_trainers.assert_called_once_with(
trainer_config, env_mock.external_brains
)
env_mock.reset.assert_called_once()
assert tc.advance.call_count == trainer_mock.get_max_steps + 1
env_mock.close.assert_called_once()

trainer_mock.parameters = {"some": "parameter"}
trainer_mock.write_tensorboard_text = MagicMock()
brain_info_mock = MagicMock()
tc = basic_trainer_controller(brain_info_mock)
tc = basic_trainer_controller()
tc.trainers = {"testbrain": trainer_mock}
return tc, trainer_mock

22
ml-agents/mlagents/trainers/trainer_controller.py


train: bool,
keep_checkpoints: int,
lesson: Optional[int],
external_brains: Dict[str, BrainParameters],
training_seed: int,
fast_simulation: bool,
):

:param train: Whether to train model, or only run inference.
:param keep_checkpoints: How many model checkpoints to keep.
:param lesson: Start learning from this lesson.
:param external_brains: dictionary of external brain names to BrainInfo objects.
self.external_brains = external_brains
self.external_brain_names = external_brains.keys()
self.logger = logging.getLogger("mlagents.envs")
self.run_id = run_id
self.save_freq = save_freq

for brain_name in self.trainers.keys():
self.trainers[brain_name].export_model()
def initialize_trainers(self, trainer_config: Dict[str, Any]) -> None:
def initialize_trainers(
self,
trainer_config: Dict[str, Any],
external_brains: Dict[str, BrainParameters],
) -> None:
for brain_name in self.external_brains:
for brain_name in external_brains:
trainer_parameters = trainer_config["default"].copy()
trainer_parameters["summary_path"] = "{basedir}/{name}".format(
basedir=self.summaries_dir, name=str(self.run_id) + "_" + brain_name

_brain_key = trainer_config[_brain_key]
trainer_parameters.update(trainer_config[_brain_key])
trainer_parameters_dict[brain_name] = trainer_parameters.copy()
for brain_name in self.external_brains:
for brain_name in external_brains:
self.external_brains[brain_name],
external_brains[brain_name],
trainer_parameters_dict[brain_name],
self.train_model,
self.load_model,

elif trainer_parameters_dict[brain_name]["trainer"] == "online_bc":
self.trainers[brain_name] = OnlineBCTrainer(
self.external_brains[brain_name],
external_brains[brain_name],
trainer_parameters_dict[brain_name],
self.train_model,
self.load_model,

elif trainer_parameters_dict[brain_name]["trainer"] == "ppo":
self.trainers[brain_name] = PPOTrainer(
self.external_brains[brain_name],
external_brains[brain_name],
self.meta_curriculum.brains_to_curriculums[
brain_name
].min_lesson_length

tf.reset_default_graph()
# Prevent a single session from taking all GPU memory.
self.initialize_trainers(trainer_config)
self.initialize_trainers(trainer_config, env.external_brains)
for _, t in self.trainers.items():
self.logger.info(t)

正在加载...
取消
保存