|
|
|
|
|
|
from mlagents.trainers.trainer.rl_trainer import RLTrainer |
|
|
|
from mlagents.trainers.tests.test_buffer import construct_fake_buffer |
|
|
|
from mlagents.trainers.agent_processor import AgentManagerQueue |
|
|
|
from mlagents.trainers.settings import TrainerSettings |
|
|
|
from mlagents.trainers.settings import TrainerSettings, FrameworkType |
|
|
|
|
|
|
|
|
|
|
|
# Add concrete implementations of abstract methods |
|
|
|
|
|
|
super()._process_trajectory(trajectory) |
|
|
|
|
|
|
|
|
|
|
|
def create_rl_trainer(): |
|
|
|
def create_rl_trainer(framework=FrameworkType.TENSORFLOW): |
|
|
|
TrainerSettings(max_steps=100, checkpoint_interval=10, summary_freq=20), |
|
|
|
TrainerSettings( |
|
|
|
max_steps=100, checkpoint_interval=10, summary_freq=20, framework=framework |
|
|
|
), |
|
|
|
True, |
|
|
|
False, |
|
|
|
"mock_model_path", |
|
|
|
|
|
|
assert mocked_save_model.call_count == 0 |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize( |
|
|
|
"framework", [FrameworkType.TENSORFLOW, FrameworkType.PYTORCH], ids=["tf", "torch"] |
|
|
|
) |
|
|
|
def test_summary_checkpoint(mock_add_checkpoint, mock_write_summary): |
|
|
|
trainer = create_rl_trainer() |
|
|
|
def test_summary_checkpoint(mock_add_checkpoint, mock_write_summary, framework): |
|
|
|
trainer = create_rl_trainer(framework) |
|
|
|
mock_policy = mock.Mock() |
|
|
|
trainer.add_policy("TestBrain", mock_policy) |
|
|
|
trajectory_queue = AgentManagerQueue("testbrain") |
|
|
|
|
|
|
) |
|
|
|
calls = [mock.call(trainer.brain_name, step) for step in checkpoint_range] |
|
|
|
trainer.saver.save_checkpoint.assert_has_calls(calls, any_order=True) |
|
|
|
export_ext = "nn" if trainer.framework == FrameworkType.TENSORFLOW else "onnx" |
|
|
|
|
|
|
|
add_checkpoint_calls = [ |
|
|
|
mock.call( |
|
|
|
|
|
|
f"{trainer.saver.model_path}/{trainer.brain_name}-{step}.nn", |
|
|
|
f"{trainer.saver.model_path}/{trainer.brain_name}-{step}.{export_ext}", |
|
|
|
None, |
|
|
|
mock.ANY, |
|
|
|
), |
|
|
|