import unittest.mock as mock
import pytest
from unittest.mock import *  # provides MagicMock and patch used below

from mlagents.trainers import learn, TrainerController


@pytest.fixture
def basic_options():
    """Return a fresh options dict holding default command-line values.

    Keys are the flag names passed to ``learn.run_training`` in these tests.
    The ``''`` key maps to ``'basic_path'`` -- presumably the positional
    trainer-config path; TODO confirm against how ``learn`` parses its CLI.

    The fixture is function-scoped, so each test gets its own dict and may
    mutate it freely.
    """
    return {
        '--docker-target-name': 'None',
        '--env': 'None',
        '--run-id': 'ppo',
        '--load': False,
        '--train': False,
        '--save-freq': '50000',
        '--keep-checkpoints': '5',
        '--worker-id': '0',
        '--curriculum': 'None',
        '--lesson': '0',
        '--slow': False,
        '--no-graphics': False,
        '': 'basic_path',
    }


def _run_with_mocked_controller(load_config, init_environment, options):
    """Run ``learn.run_training`` with its collaborators replaced by mocks.

    Args:
        load_config: Mock patched over ``learn.load_config``.
        init_environment: Mock patched over ``learn.init_environment``.
        options: Options dict passed through to ``learn.run_training``.

    Returns:
        The ``MagicMock`` that stood in for ``TrainerController.__init__``,
        so callers can inspect the constructor arguments.
    """
    mock_env = MagicMock()
    mock_env.external_brain_names = []
    mock_env.academy_name = 'TestAcademyName'
    init_environment.return_value = mock_env
    load_config.return_value = MagicMock()

    # __init__ must return None. A MagicMock class attribute is not bound to
    # the instance, so the mock records the constructor args without `self`.
    mock_init = MagicMock(return_value=None)
    with patch.object(TrainerController, "__init__", mock_init), \
            patch.object(TrainerController, "start_learning", MagicMock()):
        learn.run_training(0, 0, options, MagicMock())
    return mock_init


# The fixture is requested by parameter name rather than called directly:
# calling a fixture function is an error since pytest 4.0. pytest skips the
# positional arguments injected by the @patch decorators when resolving
# fixtures, so the two mechanisms combine safely.
@patch('mlagents.trainers.learn.init_environment')
@patch('mlagents.trainers.learn.load_config')
def test_run_training(load_config, init_environment, basic_options):
    """TrainerController is built with the expected args for default options."""
    mock_init = _run_with_mocked_controller(
        load_config, init_environment, basic_options)
    # Positional order presumably mirrors TrainerController.__init__'s
    # signature -- TODO confirm if that signature changes.
    mock_init.assert_called_once_with(
        './models/ppo', './summaries', 'ppo-0', 50000, None, False, False,
        5, 0, {}, 0,
    )


@patch('mlagents.trainers.learn.init_environment')
@patch('mlagents.trainers.learn.load_config')
def test_docker_target_path(load_config, init_environment, basic_options):
    """A docker target name is prefixed onto the model and summary paths."""
    basic_options['--docker-target-name'] = 'dockertarget'
    mock_init = _run_with_mocked_controller(
        load_config, init_environment, basic_options)
    mock_init.assert_called_once()
    positional_args = mock_init.call_args[0]
    assert positional_args[0] == '/dockertarget/models/ppo'
    assert positional_args[1] == '/dockertarget/summaries'