|
|
|
|
|
|
import numpy as np |
|
|
|
from mlagents.trainers.buffer import AgentBuffer |
|
|
|
from mlagents.trainers.agent_processor import ProcessingBuffer |
|
|
|
|
|
|
|
|
|
|
|
def assert_array(a, b): |
|
|
|
|
|
|
assert la[i] == lb[i] |
|
|
|
|
|
|
|
|
|
|
|
def construct_fake_processing_buffer(): |
|
|
|
b = ProcessingBuffer() |
|
|
|
for fake_agent_id in range(4): |
|
|
|
for step in range(9): |
|
|
|
b[fake_agent_id]["vector_observation"].append( |
|
|
|
[ |
|
|
|
100 * fake_agent_id + 10 * step + 1, |
|
|
|
100 * fake_agent_id + 10 * step + 2, |
|
|
|
100 * fake_agent_id + 10 * step + 3, |
|
|
|
] |
|
|
|
) |
|
|
|
b[fake_agent_id]["action"].append( |
|
|
|
[ |
|
|
|
100 * fake_agent_id + 10 * step + 4, |
|
|
|
100 * fake_agent_id + 10 * step + 5, |
|
|
|
] |
|
|
|
) |
|
|
|
def construct_fake_buffer(fake_agent_id): |
|
|
|
b = AgentBuffer() |
|
|
|
for step in range(9): |
|
|
|
b["vector_observation"].append( |
|
|
|
[ |
|
|
|
100 * fake_agent_id + 10 * step + 1, |
|
|
|
100 * fake_agent_id + 10 * step + 2, |
|
|
|
100 * fake_agent_id + 10 * step + 3, |
|
|
|
] |
|
|
|
) |
|
|
|
b["action"].append( |
|
|
|
[100 * fake_agent_id + 10 * step + 4, 100 * fake_agent_id + 10 * step + 5] |
|
|
|
) |
|
|
|
b = construct_fake_processing_buffer() |
|
|
|
a = b[1]["vector_observation"].get_batch( |
|
|
|
agent_1_buffer = construct_fake_buffer(1) |
|
|
|
agent_2_buffer = construct_fake_buffer(2) |
|
|
|
agent_3_buffer = construct_fake_buffer(3) |
|
|
|
a = agent_1_buffer["vector_observation"].get_batch( |
|
|
|
a = b[2]["vector_observation"].get_batch( |
|
|
|
a = agent_2_buffer["vector_observation"].get_batch( |
|
|
|
batch_size=2, training_length=3, sequential=True |
|
|
|
) |
|
|
|
assert_array( |
|
|
|
|
|
|
] |
|
|
|
), |
|
|
|
) |
|
|
|
a = b[2]["vector_observation"].get_batch( |
|
|
|
a = agent_2_buffer["vector_observation"].get_batch( |
|
|
|
batch_size=2, training_length=3, sequential=False |
|
|
|
) |
|
|
|
assert_array( |
|
|
|
|
|
|
] |
|
|
|
), |
|
|
|
) |
|
|
|
b[4].reset_agent() |
|
|
|
assert len(b[4]) == 0 |
|
|
|
agent_1_buffer.reset_agent() |
|
|
|
assert agent_1_buffer.num_experiences == 0 |
|
|
|
b.append_to_update_buffer(update_buffer, 3, batch_size=None, training_length=2) |
|
|
|
b.append_to_update_buffer(update_buffer, 2, batch_size=None, training_length=2) |
|
|
|
agent_2_buffer.resequence_and_append( |
|
|
|
update_buffer, batch_size=None, training_length=2 |
|
|
|
) |
|
|
|
agent_3_buffer.resequence_and_append( |
|
|
|
update_buffer, batch_size=None, training_length=2 |
|
|
|
) |
|
|
|
assert len(update_buffer["action"]) == 20 |
|
|
|
|
|
|
|
assert np.array(update_buffer["action"]).shape == (20, 2) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_buffer_sample(): |
|
|
|
b = construct_fake_processing_buffer() |
|
|
|
agent_1_buffer = construct_fake_buffer(1) |
|
|
|
agent_2_buffer = construct_fake_buffer(2) |
|
|
|
b.append_to_update_buffer(update_buffer, 3, batch_size=None, training_length=2) |
|
|
|
b.append_to_update_buffer(update_buffer, 2, batch_size=None, training_length=2) |
|
|
|
agent_1_buffer.resequence_and_append( |
|
|
|
update_buffer, batch_size=None, training_length=2 |
|
|
|
) |
|
|
|
agent_2_buffer.resequence_and_append( |
|
|
|
update_buffer, batch_size=None, training_length=2 |
|
|
|
) |
|
|
|
# Test non-LSTM |
|
|
|
mb = update_buffer.sample_mini_batch(batch_size=4, sequence_length=1) |
|
|
|
assert mb.keys() == update_buffer.keys() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_num_experiences(): |
|
|
|
b = construct_fake_processing_buffer() |
|
|
|
agent_1_buffer = construct_fake_buffer(1) |
|
|
|
agent_2_buffer = construct_fake_buffer(2) |
|
|
|
|
|
|
|
b.append_to_update_buffer(update_buffer, 3, batch_size=None, training_length=2) |
|
|
|
b.append_to_update_buffer(update_buffer, 2, batch_size=None, training_length=2) |
|
|
|
agent_1_buffer.resequence_and_append( |
|
|
|
update_buffer, batch_size=None, training_length=2 |
|
|
|
) |
|
|
|
agent_2_buffer.resequence_and_append( |
|
|
|
update_buffer, batch_size=None, training_length=2 |
|
|
|
) |
|
|
|
|
|
|
|
assert len(update_buffer["action"]) == 20 |
|
|
|
assert update_buffer.num_experiences == 20 |
|
|
|
|
|
|
b = construct_fake_processing_buffer() |
|
|
|
agent_1_buffer = construct_fake_buffer(1) |
|
|
|
agent_2_buffer = construct_fake_buffer(2) |
|
|
|
b.append_to_update_buffer(update_buffer, 3, batch_size=None, training_length=2) |
|
|
|
b.append_to_update_buffer(update_buffer, 2, batch_size=None, training_length=2) |
|
|
|
agent_1_buffer.resequence_and_append( |
|
|
|
update_buffer, batch_size=None, training_length=2 |
|
|
|
) |
|
|
|
agent_2_buffer.resequence_and_append( |
|
|
|
update_buffer, batch_size=None, training_length=2 |
|
|
|
) |
|
|
|
b.append_to_update_buffer(update_buffer, 3, batch_size=None, training_length=2) |
|
|
|
b.append_to_update_buffer(update_buffer, 2, batch_size=None, training_length=2) |
|
|
|
agent_1_buffer.resequence_and_append( |
|
|
|
update_buffer, batch_size=None, training_length=2 |
|
|
|
) |
|
|
|
agent_2_buffer.resequence_and_append( |
|
|
|
update_buffer, batch_size=None, training_length=2 |
|
|
|
) |
|
|
|
# Test LSTM, truncate should be some multiple of sequence_length |
|
|
|
update_buffer.truncate(4, sequence_length=3) |
|
|
|
assert update_buffer.num_experiences == 3 |