|
|
|
|
|
|
b[4].reset_agent() |
|
|
|
assert len(b[4]) == 0 |
|
|
|
update_buffer = AgentBuffer() |
|
|
|
b.append_update_buffer(update_buffer, 3, batch_size=None, training_length=2) |
|
|
|
b.append_update_buffer(update_buffer, 2, batch_size=None, training_length=2) |
|
|
|
b.append_to_update_buffer(update_buffer, 3, batch_size=None, training_length=2) |
|
|
|
b.append_to_update_buffer(update_buffer, 2, batch_size=None, training_length=2) |
|
|
|
assert len(update_buffer["action"]) == 20 |
|
|
|
assert np.array(update_buffer["action"]).shape == (20, 2) |
|
|
|
|
|
|
|
|
|
|
def test_buffer_sample():
    """Check that a sampled mini-batch has the same keys as the update buffer.

    Builds the fake processing buffer, copies the data stored under keys 3
    and 2 (presumably agent ids -- confirm against the fixture) into a fresh
    ``AgentBuffer`` with ``training_length=2``, then samples a mini-batch of
    4 with ``sequence_length=1``.
    """
    b = construct_fake_processing_buffer()
    update_buffer = AgentBuffer()

    # One append per key; running the same append twice would duplicate the
    # data placed in the update buffer.
    b.append_to_update_buffer(update_buffer, 3, batch_size=None, training_length=2)
    b.append_to_update_buffer(update_buffer, 2, batch_size=None, training_length=2)

    # Test non-LSTM
    mb = update_buffer.sample_mini_batch(batch_size=4, sequence_length=1)
    assert mb.keys() == update_buffer.keys()
|
|
|
|
|
|
def test_buffer_truncate():
    """Check truncation of the update buffer when a sequence length is given.

    Builds the fake processing buffer, appends the data stored under keys 3
    and 2 (presumably agent ids -- confirm against the fixture) to a fresh
    ``AgentBuffer`` in two rounds, then truncates to a maximum of 4 with
    ``sequence_length=3`` and expects 3 ``"action"`` entries to remain.
    """
    b = construct_fake_processing_buffer()
    update_buffer = AgentBuffer()

    # First round of appends: one call per key.
    b.append_to_update_buffer(update_buffer, 3, batch_size=None, training_length=2)
    b.append_to_update_buffer(update_buffer, 2, batch_size=None, training_length=2)

    # Second round of appends: same keys again, one call per key.
    b.append_to_update_buffer(update_buffer, 3, batch_size=None, training_length=2)
    b.append_to_update_buffer(update_buffer, 2, batch_size=None, training_length=2)

    # Test LSTM, truncate should be some multiple of sequence_length
    update_buffer.truncate(4, sequence_length=3)
    assert len(update_buffer["action"]) == 3