|
|
|
|
|
|
import yaml |
|
|
|
from unittest import mock |
|
|
|
import pytest |
|
|
|
import mlagents.trainers.tests.mock_brain as mb |
|
|
|
from mlagents.trainers.trainer.rl_trainer import RLTrainer |
|
|
|
from mlagents.trainers.tests.test_buffer import construct_fake_buffer |
|
|
|
|
|
|
|
|
|
|
# Add concrete implementations of abstract methods |
|
|
|
class FakeTrainer(RLTrainer): |
|
|
|
def set_is_policy_updating(self, is_updating): |
|
|
|
self.update_policy = is_updating |
|
|
|
|
|
|
|
def get_policy(self, name_behavior_id): |
|
|
|
return mock.Mock() |
|
|
|
|
|
|
|
|
|
|
def _update_policy(self): |
|
|
|
return True |
|
|
|
return self.update_policy |
|
|
|
|
|
|
|
def add_policy(self): |
|
|
|
pass |
|
|
|
|
|
|
def create_rl_trainer(): |
|
|
|
mock_brainparams = create_mock_brain() |
|
|
|
trainer = FakeTrainer(mock_brainparams, dummy_config(), True, 0) |
|
|
|
trainer.set_is_policy_updating(True) |
|
|
|
return trainer |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_advance(mocked_clear_update_buffer): |
|
|
|
trainer = create_rl_trainer() |
|
|
|
trajectory_queue = AgentManagerQueue("testbrain") |
|
|
|
policy_queue = AgentManagerQueue("testbrain") |
|
|
|
time_horizon = 15 |
|
|
|
trainer.publish_policy_queue(policy_queue) |
|
|
|
time_horizon = 10 |
|
|
|
trajectory = mb.make_fake_trajectory( |
|
|
|
length=time_horizon, |
|
|
|
max_step_complete=True, |
|
|
|
|
|
|
trajectory_queue.put(trajectory) |
|
|
|
|
|
|
|
trainer.advance() |
|
|
|
policy_queue.get(block=False) |
|
|
|
for _ in range(0, 5): |
|
|
|
trajectory_queue.put(trajectory) |
|
|
|
trainer.advance() |
|
|
|
# Check that there is stuff in the policy queue |
|
|
|
policy_queue.get(block=False) |
|
|
|
|
|
|
|
# Check that if the policy doesn't update, we don't push it to the queue |
|
|
|
trainer.set_is_policy_updating(False) |
|
|
|
# Check that there nothing in the policy queue |
|
|
|
with pytest.raises(AgentManagerQueue.Empty): |
|
|
|
policy_queue.get(block=False) |
|
|
|
|
|
|
|
# Check that the buffer has been cleared |
|
|
|
assert not trainer.should_still_train |