from unittest import mock

import yaml

import mlagents.trainers.tests.mock_brain as mb
from mlagents.trainers.rl_trainer import RLTrainer
from mlagents.trainers.tests.test_buffer import construct_fake_buffer
from mlagents.trainers.agent_processor import AgentManagerQueue


def dummy_config():
    # Minimal trainer configuration consumed by the fake trainer below.
    return yaml.safe_load(
        """
        summary_freq: 1000
        max_steps: 100
        reward_signals:
            extrinsic:
                strength: 1.0
        """
    )


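# The tests below call create_rl_trainer(), which is not part of this fragment.
# The sketch below is an assumed, minimal stand-in: a FakeTrainer that stubs
# out RLTrainer's abstract hooks, and a factory wiring it to dummy_config().
# The constructor arguments, the stubbed method names, and the brain name are
# assumptions; depending on the ML-Agents version, additional abstract methods
# may need no-op overrides.
class FakeTrainer(RLTrainer):
    def _is_ready_update(self):
        # Pretend the trainer always has enough data to update.
        return True

    def _update_policy(self):
        # No real policy update is needed for these tests.
        pass


def create_rl_trainer():
    # "testbrain" matches the AgentManagerQueue name used in test_advance;
    # the trailing arguments (training=True, run_id=0) are assumed.
    return FakeTrainer("testbrain", dummy_config(), True, 0)

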
def test_clear_update_buffer():
    trainer = create_rl_trainer()
    # Fill the update buffer with fake data; the argument to
    # construct_fake_buffer (a fake agent id) is assumed.
    trainer.update_buffer = construct_fake_buffer(0)
    trainer.clear_update_buffer()
    # After clearing, every field in the buffer should be empty.
    for _, arr in trainer.update_buffer.items():
        assert len(arr) == 0


@mock.patch("mlagents.trainers.rl_trainer.RLTrainer.clear_update_buffer")
def test_advance(mocked_clear_update_buffer):
    trainer = create_rl_trainer()
    trajectory_queue = AgentManagerQueue("testbrain")
    trainer.subscribe_trajectory_queue(trajectory_queue)
    time_horizon = 15
    trajectory = mb.make_fake_trajectory(
        length=time_horizon,
        max_step_complete=True,
        vec_obs_size=1,
        num_vis_obs=0,
        action_space=[2],
    )
    trajectory_queue.put(trajectory)

    trainer.advance()
    # Check that get_step is correct
    assert trainer.get_step == time_horizon

    # Feed in enough additional trajectories to exceed max_steps (100);
    # the trainer should stop training and clear its update buffer.
    for _ in range(0, 10):
        trajectory_queue.put(trajectory)
        trainer.advance()

    # Check that the trainer stopped and the buffer has been cleared
    assert not trainer.should_still_train
    assert mocked_clear_update_buffer.call_count > 0