|
|
|
|
|
|
@mock.patch( |
|
|
|
"mlagents.trainers.subprocess_env_manager.SubprocessEnvManager.create_worker" |
|
|
|
) |
|
|
|
def test_training_behaviors_collects_results_from_all_envs( |
|
|
|
self, mock_create_worker |
|
|
|
): |
|
|
|
def create_worker_mock(worker_id, step_queue, env_factor, engine_c): |
|
|
|
return MockEnvWorker( |
|
|
|
worker_id, |
|
|
|
EnvironmentResponse( |
|
|
|
EnvironmentCommand.RESET, worker_id, {f"key{worker_id}": worker_id} |
|
|
|
), |
|
|
|
) |
|
|
|
|
|
|
|
mock_create_worker.side_effect = create_worker_mock |
|
|
|
manager = SubprocessEnvManager( |
|
|
|
mock_env_factory, EngineConfig.default_config(), 4 |
|
|
|
) |
|
|
|
|
|
|
|
res = manager.training_behaviors |
|
|
|
for env in manager.env_workers: |
|
|
|
env.send.assert_called_with(EnvironmentCommand.BEHAVIOR_SPECS) |
|
|
|
env.recv.assert_called() |
|
|
|
for worker_id in range(4): |
|
|
|
assert f"key{worker_id}" in res |
|
|
|
assert res[f"key{worker_id}"] == worker_id |
|
|
|
|
|
|
|
@mock.patch( |
|
|
|
"mlagents.trainers.subprocess_env_manager.SubprocessEnvManager.create_worker" |
|
|
|
) |
|
|
|
def test_step_takes_steps_for_all_non_waiting_envs(self, mock_create_worker): |
|
|
|
mock_create_worker.side_effect = create_worker_mock |
|
|
|
manager = SubprocessEnvManager( |
|
|
|