Unity 机器学习代理工具包（ML-Agents）是一个开源项目，它使游戏和模拟能够作为训练智能代理的环境。
您最多可选择 25 个主题。主题必须以中文、字母或数字开头，可以包含连字符（-），并且长度不得超过 35 个字符。
 
 
 
 
 

195 行
7.3 KiB

from unittest import mock
from unittest.mock import Mock, MagicMock
import unittest
import pytest
from queue import Empty as EmptyQueue
from mlagents.trainers.subprocess_env_manager import (
SubprocessEnvManager,
EnvironmentResponse,
StepResponse,
)
from mlagents.trainers.env_manager import EnvironmentStep
from mlagents_envs.base_env import BaseEnv
from mlagents_envs.side_channel.engine_configuration_channel import EngineConfig
from mlagents.trainers.tests.simple_test_envs import Simple1DEnvironment
from mlagents.trainers.stats import StatsReporter
from mlagents.trainers.tests.test_simple_rl import (
_check_environment_trains,
PPO_CONFIG,
generate_config,
DebugWriter,
)
def mock_env_factory(worker_id):
    """Return a fresh autospecced mock of BaseEnv.

    ``worker_id`` is accepted only to match the factory call signature; it is
    not used.
    """
    autospec_env = mock.create_autospec(spec=BaseEnv)
    return autospec_env
class MockEnvWorker:
    """Lightweight stand-in for an environment worker used by these tests.

    It has no real process or connection. ``send`` records every call, and
    ``recv`` always returns the ``resp`` given at construction (``None`` by
    default).
    """

    def __init__(self, worker_id, resp=None):
        self.worker_id = worker_id
        # No subprocess or pipe backs this fake worker.
        self.process, self.conn = None, None
        # Call-recording mocks: send() is inspected by tests, recv() replays resp.
        self.send, self.recv = Mock(), Mock(return_value=resp)
        self.waiting = False
def create_worker_mock(worker_id, step_queue, env_factor, engine_c):
    """Fake ``SubprocessEnvManager.create_worker``, installed via ``side_effect``.

    Only ``worker_id`` is used. The returned MockEnvWorker's ``recv()`` yields
    a "reset" EnvironmentResponse whose second and third arguments are both
    the worker id, so tests can tell which worker produced which result.
    """
    reset_reply = EnvironmentResponse("reset", worker_id, worker_id)
    return MockEnvWorker(worker_id, reset_reply)
class SubprocessEnvManagerTest(unittest.TestCase):
    """Unit tests for SubprocessEnvManager.

    Every test patches ``SubprocessEnvManager.create_worker`` with
    ``create_worker_mock``. No real worker processes are started; the manager
    drives MockEnvWorker instances instead.
    """

    @mock.patch(
        "mlagents.trainers.subprocess_env_manager.SubprocessEnvManager.create_worker"
    )
    def test_environments_are_created(self, mock_create_worker):
        """Constructing the manager calls create_worker once per requested env."""
        mock_create_worker.side_effect = create_worker_mock
        env = SubprocessEnvManager(mock_env_factory, EngineConfig.default_config(), 2)
        # One create_worker call per worker index, sharing the manager's step queue.
        env.create_worker.assert_has_calls(
            [
                mock.call(
                    0, env.step_queue, mock_env_factory, EngineConfig.default_config()
                ),
                mock.call(
                    1, env.step_queue, mock_env_factory, EngineConfig.default_config()
                ),
            ]
        )
        self.assertEqual(len(env.env_workers), 2)

    @mock.patch(
        "mlagents.trainers.subprocess_env_manager.SubprocessEnvManager.create_worker"
    )
    def test_reset_passes_reset_params(self, mock_create_worker):
        """_reset_env forwards the given parameters with a "reset" command."""
        mock_create_worker.side_effect = create_worker_mock
        manager = SubprocessEnvManager(
            mock_env_factory, EngineConfig.default_config(), 1
        )
        params = {"test": "params"}
        manager._reset_env(params)
        manager.env_workers[0].send.assert_called_with("reset", params)

    @mock.patch(
        "mlagents.trainers.subprocess_env_manager.SubprocessEnvManager.create_worker"
    )
    def test_reset_collects_results_from_all_envs(self, mock_create_worker):
        """_reset_env resets every worker and returns each one's new previous_step."""
        mock_create_worker.side_effect = create_worker_mock
        manager = SubprocessEnvManager(
            mock_env_factory, EngineConfig.default_config(), 4
        )
        params = {"test": "params"}
        res = manager._reset_env(params)
        for i, env in enumerate(manager.env_workers):
            env.send.assert_called_with("reset", params)
            env.recv.assert_called()
            # create_worker_mock gives worker i a reset response carrying i, so
            # the stored "last step" must reflect that value.
            self.assertEqual(env.previous_step.current_all_step_result, i)
        assert res == [ew.previous_step for ew in manager.env_workers]

    @mock.patch(
        "mlagents.trainers.subprocess_env_manager.SubprocessEnvManager.create_worker"
    )
    def test_step_takes_steps_for_all_non_waiting_envs(self, mock_create_worker):
        """_step steps only non-waiting workers and returns the steps that arrived."""
        mock_create_worker.side_effect = create_worker_mock
        manager = SubprocessEnvManager(
            mock_env_factory, EngineConfig.default_config(), 3
        )
        manager.step_queue = Mock()
        # Workers 0 and 1 report a step; then the queue runs dry.
        manager.step_queue.get_nowait.side_effect = [
            EnvironmentResponse("step", 0, StepResponse(0, None)),
            EnvironmentResponse("step", 1, StepResponse(1, None)),
            EmptyQueue(),
        ]
        step_mock = Mock()
        last_steps = [Mock(), Mock(), Mock()]
        for worker, last_step in zip(manager.env_workers, last_steps):
            worker.previous_step = last_step
        # Worker 2 is marked as waiting (presumably on an in-flight step), so
        # _step must not send it another one.
        manager.env_workers[2].waiting = True
        manager._take_step = Mock(return_value=step_mock)
        res = manager._step()
        manager.step_queue.get_nowait.assert_called()
        for i, env in enumerate(manager.env_workers):
            if i < 2:
                env.send.assert_called_with("step", step_mock)
                # Check that the "last steps" are set to the value returned for each step
                self.assertEqual(env.previous_step.current_all_step_result, i)
            else:
                env.send.assert_not_called()
        assert res == [
            manager.env_workers[0].previous_step,
            manager.env_workers[1].previous_step,
        ]

    @mock.patch("mlagents.trainers.subprocess_env_manager.SubprocessEnvManager._step")
    @mock.patch(
        "mlagents.trainers.subprocess_env_manager.SubprocessEnvManager.external_brains",
        new_callable=mock.PropertyMock,
    )
    @mock.patch(
        "mlagents.trainers.subprocess_env_manager.SubprocessEnvManager.create_worker"
    )
    def test_advance(self, mock_create_worker, external_brains_mock, step_mock):
        """advance() feeds step results to the agent manager and adopts queued policies."""
        brain_name = "testbrain"
        action_info_dict = {brain_name: MagicMock()}
        mock_create_worker.side_effect = create_worker_mock
        env_manager = SubprocessEnvManager(
            mock_env_factory, EngineConfig.default_config(), 3
        )
        external_brains_mock.return_value = [brain_name]
        agent_manager_mock = mock.Mock()
        env_manager.set_agent_manager(brain_name, agent_manager_mock)

        step_info_dict = {brain_name: Mock()}
        step_info = EnvironmentStep(step_info_dict, 0, action_info_dict)
        step_mock.return_value = [step_info]
        env_manager.advance()

        # Test add_experiences
        env_manager._step.assert_called_once()
        agent_manager_mock.add_experiences.assert_called_once_with(
            step_info.current_all_step_result[brain_name],
            0,
            step_info.brain_name_to_action_info[brain_name],
        )

        # Test policy queue
        mock_policy = mock.Mock()
        agent_manager_mock.policy_queue.get_nowait.return_value = mock_policy
        env_manager.advance()
        assert env_manager.policies[brain_name] == mock_policy
        assert agent_manager_mock.policy == mock_policy
def simple_env_factory(worker_id, config):
    """Build a new discrete Simple1DEnvironment constructed with ``["1D"]``.

    Both arguments are ignored; they presumably mirror the signature that
    SubprocessEnvManager uses when invoking an env factory.
    """
    return Simple1DEnvironment(["1D"], use_discrete=True)
@pytest.mark.parametrize("num_envs", [1, 4])
def test_subprocess_env_endtoend(num_envs):
    """Check that PPO trains Simple1DEnvironment through SubprocessEnvManager.

    Runs once with 1 and once with 4 environments. Success is judged from the
    StatsReporter's DebugWriter, because the environments live in separate
    processes and their rewards cannot be read directly. The manager is always
    closed, even when training or an assertion fails, so worker subprocesses
    are not leaked.
    """
    env_manager = SubprocessEnvManager(
        simple_env_factory, EngineConfig.default_config(), num_envs
    )
    try:
        trainer_config = generate_config(PPO_CONFIG)
        # Run PPO using env_manager
        _check_environment_trains(
            simple_env_factory(0, []),
            trainer_config,
            env_manager=env_manager,
            success_threshold=None,
        )
        # Note we can't check the env's rewards directly (since they're in separate processes) so we
        # check the StatsReporter's debug stat writer's last reward.
        assert isinstance(StatsReporter.writers[0], DebugWriter)
        assert all(
            val > 0.99 for val in StatsReporter.writers[0].get_last_rewards().values()
        )
    finally:
        # Shut down the worker subprocesses on every exit path.
        env_manager.close()