|
|
|
|
|
|
from unittest import mock |
|
|
|
from unittest.mock import Mock, MagicMock |
|
|
|
import unittest |
|
|
|
import pytest |
|
|
|
from queue import Empty as EmptyQueue |
|
|
|
|
|
|
|
from mlagents.trainers.subprocess_env_manager import ( |
|
|
|
|
|
|
from mlagents.trainers.env_manager import EnvironmentStep |
|
|
|
from mlagents_envs.base_env import BaseEnv |
|
|
|
from mlagents_envs.side_channel.engine_configuration_channel import EngineConfig |
|
|
|
from mlagents.trainers.tests.simple_test_envs import Simple1DEnvironment |
|
|
|
from mlagents.trainers.stats import StatsReporter |
|
|
|
from mlagents.trainers.tests.test_simple_rl import ( |
|
|
|
_check_environment_trains, |
|
|
|
PPO_CONFIG, |
|
|
|
generate_config, |
|
|
|
DebugWriter, |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def mock_env_factory(worker_id): |
|
|
|
|
|
|
self.waiting = False |
|
|
|
|
|
|
|
|
|
|
|
def create_worker_mock(worker_id, step_queue, env_factor, engine_c): |
|
|
|
return MockEnvWorker(worker_id, EnvironmentResponse("reset", worker_id, worker_id)) |
|
|
|
|
|
|
|
|
|
|
|
def test_environments_are_created(self): |
|
|
|
SubprocessEnvManager.create_worker = MagicMock() |
|
|
|
@mock.patch( |
|
|
|
"mlagents.trainers.subprocess_env_manager.SubprocessEnvManager.create_worker" |
|
|
|
) |
|
|
|
def test_environments_are_created(self, mock_create_worker): |
|
|
|
mock_create_worker.side_effect = create_worker_mock |
|
|
|
env = SubprocessEnvManager(mock_env_factory, EngineConfig.default_config(), 2) |
|
|
|
# Creates two processes |
|
|
|
env.create_worker.assert_has_calls( |
|
|
|
|
|
|
) |
|
|
|
self.assertEqual(len(env.env_workers), 2) |
|
|
|
|
|
|
|
def test_reset_passes_reset_params(self): |
|
|
|
SubprocessEnvManager.create_worker = lambda em, worker_id, step_queue, env_factory, engine_c: MockEnvWorker( |
|
|
|
worker_id, EnvironmentResponse("reset", worker_id, worker_id) |
|
|
|
) |
|
|
|
@mock.patch( |
|
|
|
"mlagents.trainers.subprocess_env_manager.SubprocessEnvManager.create_worker" |
|
|
|
) |
|
|
|
def test_reset_passes_reset_params(self, mock_create_worker): |
|
|
|
mock_create_worker.side_effect = create_worker_mock |
|
|
|
manager = SubprocessEnvManager( |
|
|
|
mock_env_factory, EngineConfig.default_config(), 1 |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
def test_reset_collects_results_from_all_envs(self): |
|
|
|
SubprocessEnvManager.create_worker = lambda em, worker_id, step_queue, env_factory, engine_c: MockEnvWorker( |
|
|
|
worker_id, EnvironmentResponse("reset", worker_id, worker_id) |
|
|
|
) |
|
|
|
@mock.patch( |
|
|
|
"mlagents.trainers.subprocess_env_manager.SubprocessEnvManager.create_worker" |
|
|
|
) |
|
|
|
def test_reset_collects_results_from_all_envs(self, mock_create_worker): |
|
|
|
mock_create_worker.side_effect = create_worker_mock |
|
|
|
manager = SubprocessEnvManager( |
|
|
|
mock_env_factory, EngineConfig.default_config(), 4 |
|
|
|
) |
|
|
|
|
|
|
) |
|
|
|
assert res == list(map(lambda ew: ew.previous_step, manager.env_workers)) |
|
|
|
|
|
|
|
def test_step_takes_steps_for_all_non_waiting_envs(self): |
|
|
|
SubprocessEnvManager.create_worker = lambda em, worker_id, step_queue, env_factory, engine_c: MockEnvWorker( |
|
|
|
worker_id, EnvironmentResponse("step", worker_id, worker_id) |
|
|
|
) |
|
|
|
@mock.patch( |
|
|
|
"mlagents.trainers.subprocess_env_manager.SubprocessEnvManager.create_worker" |
|
|
|
) |
|
|
|
def test_step_takes_steps_for_all_non_waiting_envs(self, mock_create_worker): |
|
|
|
mock_create_worker.side_effect = create_worker_mock |
|
|
|
manager = SubprocessEnvManager( |
|
|
|
mock_env_factory, EngineConfig.default_config(), 3 |
|
|
|
) |
|
|
|
|
|
|
"mlagents.trainers.subprocess_env_manager.SubprocessEnvManager.external_brains", |
|
|
|
new_callable=mock.PropertyMock, |
|
|
|
) |
|
|
|
def test_advance(self, external_brains_mock, step_mock): |
|
|
|
@mock.patch( |
|
|
|
"mlagents.trainers.subprocess_env_manager.SubprocessEnvManager.create_worker" |
|
|
|
) |
|
|
|
def test_advance(self, mock_create_worker, external_brains_mock, step_mock): |
|
|
|
SubprocessEnvManager.create_worker = lambda em, worker_id, step_queue, env_factory, engine_c: MockEnvWorker( |
|
|
|
worker_id, EnvironmentResponse("step", worker_id, worker_id) |
|
|
|
) |
|
|
|
mock_create_worker.side_effect = create_worker_mock |
|
|
|
env_manager = SubprocessEnvManager( |
|
|
|
mock_env_factory, EngineConfig.default_config(), 3 |
|
|
|
) |
|
|
|
|
|
|
env_manager.advance() |
|
|
|
assert env_manager.policies[brain_name] == mock_policy |
|
|
|
assert agent_manager_mock.policy == mock_policy |
|
|
|
|
|
|
|
|
|
|
|
def simple_env_factory(worker_id, config): |
|
|
|
env = Simple1DEnvironment(["1D"], use_discrete=True) |
|
|
|
return env |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("num_envs", [1, 4]) |
|
|
|
def test_subprocess_env_endtoend(num_envs): |
|
|
|
env_manager = SubprocessEnvManager( |
|
|
|
simple_env_factory, EngineConfig.default_config(), num_envs |
|
|
|
) |
|
|
|
trainer_config = generate_config(PPO_CONFIG) |
|
|
|
# Run PPO using env_manager |
|
|
|
_check_environment_trains( |
|
|
|
simple_env_factory(0, []), |
|
|
|
trainer_config, |
|
|
|
env_manager=env_manager, |
|
|
|
success_threshold=None, |
|
|
|
) |
|
|
|
# Note we can't check the env's rewards directly (since they're in separate processes) so we |
|
|
|
# check the StatsReporter's debug stat writer's last reward. |
|
|
|
assert isinstance(StatsReporter.writers[0], DebugWriter) |
|
|
|
assert all( |
|
|
|
val > 0.99 for val in StatsReporter.writers[0].get_last_rewards().values() |
|
|
|
) |
|
|
|
env_manager.close() |