|
|
|
|
|
|
use_recurrent: false |
|
|
|
sequence_length: 32 |
|
|
|
memory_size: 32 |
|
|
|
batches_per_epoch: 1 |
|
|
|
batches_per_epoch: 100 # Force code to use all possible batches |
|
|
|
batch_size: 32 |
|
|
|
summary_freq: 2000 |
|
|
|
max_steps: 4000 |
|
|
|
|
|
|
|
|
|
|
def create_bc_trainer(dummy_config, is_discrete=False): |
|
|
|
def create_bc_trainer(dummy_config, is_discrete=False, use_recurrent=False): |
|
|
|
mock_env = mock.Mock() |
|
|
|
if is_discrete: |
|
|
|
mock_brain = mb.create_mock_pushblock_brain() |
|
|
|
|
|
|
trainer_parameters["demo_path"] = ( |
|
|
|
os.path.dirname(os.path.abspath(__file__)) + "/test.demo" |
|
|
|
) |
|
|
|
trainer_parameters["use_recurrent"] = use_recurrent |
|
|
|
trainer = BCTrainer( |
|
|
|
mock_brain, trainer_parameters, training=True, load=False, seed=0, run_id=0 |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
def test_bc_trainer_step(dummy_config): |
|
|
|
trainer, env = create_bc_trainer(dummy_config) |
|
|
|
@pytest.mark.parametrize("use_recurrent", [True, False]) |
|
|
|
def test_bc_trainer_step(dummy_config, use_recurrent): |
|
|
|
trainer, env = create_bc_trainer(dummy_config, use_recurrent=use_recurrent) |
|
|
|
# Test get_step |
|
|
|
assert trainer.get_step == 0 |
|
|
|
# Test update policy |
|
|
|