您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
200 行
6.0 KiB
200 行
6.0 KiB
import numpy as np
|
|
from mlagents.trainers.buffer import (
|
|
AgentBuffer,
|
|
AgentBufferField,
|
|
BufferKey,
|
|
ObservationKeyPrefix,
|
|
RewardSignalKeyPrefix,
|
|
)
|
|
from mlagents.trainers.trajectory import ObsUtil
|
|
|
|
|
|
def assert_array(a, b):
    """Assert that two NumPy arrays have the same shape and equal elements.

    Elements are compared pairwise with ``==`` after flattening, so the two
    arrays may have different dtypes as long as their values compare equal.

    :param a: Actual array.
    :param b: Expected array.
    :raises AssertionError: If the shapes differ or any element differs.
    """
    assert a.shape == b.shape
    # The shapes match, so both flattened arrays have the same length and
    # zip() pairs every element without silently truncating either side.
    for index, (actual, expected) in enumerate(zip(a.flatten(), b.flatten())):
        assert (
            actual == expected
        ), f"mismatch at flat index {index}: {actual} != {expected}"
|
|
|
|
|
|
def construct_fake_buffer(fake_agent_id):
    """Build an AgentBuffer holding nine steps of synthetic data.

    For each step ``s`` in 0..8 the buffer receives:

    * under ``ObsUtil.get_name_at(0)``: a float32 array of 3 values,
      ``100 * fake_agent_id + 10 * s + k`` for k = 1, 2, 3;
    * under ``BufferKey.CONTINUOUS_ACTION``: a float32 array of 2 values,
      ``100 * fake_agent_id + 10 * s + k`` for k = 4, 5.

    This encoding lets a test tell from a value which agent, step and slot
    it came from.

    :param fake_agent_id: Number scaled by 100 into every stored value.
    :return: The populated AgentBuffer.
    """
    buffer = AgentBuffer()
    for step in range(9):
        offset = 100 * fake_agent_id + 10 * step
        buffer[ObsUtil.get_name_at(0)].append(
            np.array([offset + k for k in (1, 2, 3)], dtype=np.float32)
        )
        buffer[BufferKey.CONTINUOUS_ACTION].append(
            np.array([offset + k for k in (4, 5)], dtype=np.float32)
        )
    return buffer
|
|
|
|
|
|
def test_buffer():
    """Exercise get_batch, reset_agent, resequence_and_append and make_mini_batch."""
    agent_1_buffer = construct_fake_buffer(1)
    agent_2_buffer = construct_fake_buffer(2)
    agent_3_buffer = construct_fake_buffer(3)
    obs_key = ObsUtil.get_name_at(0)

    # Sequential, training_length=1: expects the observations of steps 7 and 8.
    a = agent_1_buffer[obs_key].get_batch(
        batch_size=2, training_length=1, sequential=True
    )
    assert_array(
        np.array(a), np.array([[171, 172, 173], [181, 182, 183]], dtype=np.float32)
    )

    # Sequential, training_length=3: two back-to-back sequences, steps 3-8.
    a = agent_2_buffer[obs_key].get_batch(
        batch_size=2, training_length=3, sequential=True
    )
    assert_array(
        np.array(a),
        np.array(
            [
                [231, 232, 233],
                [241, 242, 243],
                [251, 252, 253],
                [261, 262, 263],
                [271, 272, 273],
                [281, 282, 283],
            ],
            dtype=np.float32,
        ),
    )

    # Non-sequential, training_length=3: the expected sequences overlap
    # (steps 5-7 followed by steps 6-8).
    a = agent_2_buffer[obs_key].get_batch(
        batch_size=2, training_length=3, sequential=False
    )
    assert_array(
        np.array(a),
        np.array(
            [
                [251, 252, 253],
                [261, 262, 263],
                [271, 272, 273],
                [261, 262, 263],
                [271, 272, 273],
                [281, 282, 283],
            ],
            # Same dtype as the other expectations in this test.
            dtype=np.float32,
        ),
    )

    agent_1_buffer.reset_agent()
    assert agent_1_buffer.num_experiences == 0

    update_buffer = AgentBuffer()
    agent_2_buffer.resequence_and_append(
        update_buffer, batch_size=None, training_length=2
    )
    agent_3_buffer.resequence_and_append(
        update_buffer, batch_size=None, training_length=2
    )
    # Two 9-step agents resequenced with training_length=2 are expected to
    # yield 20 entries (presumably each agent is padded up to 10 — confirm).
    assert len(update_buffer[BufferKey.CONTINUOUS_ACTION]) == 20
    assert np.array(update_buffer[BufferKey.CONTINUOUS_ACTION]).shape == (20, 2)

    c = update_buffer.make_mini_batch(start=0, end=1)
    assert c.keys() == update_buffer.keys()
    assert np.array(c[BufferKey.CONTINUOUS_ACTION]).shape == (1, 2)
|
|
|
|
|
|
def fakerandint(values):
    """Deterministic stand-in for a random-integer draw.

    The argument is ignored; the result is always 19.  Presumably intended
    to patch a random integer generator in tests — confirm against callers.
    """
    fixed_choice = 19
    return fixed_choice
|
|
|
|
|
|
def test_buffer_sample():
    """Check sample_mini_batch for non-recurrent and recurrent sampling."""
    source_buffers = [construct_fake_buffer(agent_id) for agent_id in (1, 2)]
    update_buffer = AgentBuffer()
    for source in source_buffers:
        source.resequence_and_append(
            update_buffer, batch_size=None, training_length=2
        )

    # Non-recurrent sampling: sequence_length=1, so the batch holds 4 steps.
    mb = update_buffer.sample_mini_batch(batch_size=4, sequence_length=1)
    assert mb.keys() == update_buffer.keys()
    assert np.array(mb[BufferKey.CONTINUOUS_ACTION]).shape == (4, 2)

    # Recurrent sampling. A long sequence_length (19) is used to maximize
    # the probability of hitting a "breaking" start index, if one exists.
    mb = update_buffer.sample_mini_batch(batch_size=20, sequence_length=19)
    assert mb.keys() == update_buffer.keys()
    # Only one sequence is expected back, i.e. 19 steps.
    assert np.array(mb[BufferKey.CONTINUOUS_ACTION]).shape == (19, 2)
|
|
|
|
|
|
def test_num_experiences():
    """num_experiences and field length should track resequenced appends."""
    first, second = construct_fake_buffer(1), construct_fake_buffer(2)
    update_buffer = AgentBuffer()

    # A freshly created buffer starts out empty.
    assert len(update_buffer[BufferKey.CONTINUOUS_ACTION]) == 0
    assert update_buffer.num_experiences == 0

    for agent_buffer in (first, second):
        agent_buffer.resequence_and_append(
            update_buffer, batch_size=None, training_length=2
        )

    assert len(update_buffer[BufferKey.CONTINUOUS_ACTION]) == 20
    assert update_buffer.num_experiences == 20
|
|
|
|
|
|
def test_buffer_truncate():
    """truncate() should cap num_experiences, respecting sequence_length."""
    agent_buffers = (construct_fake_buffer(1), construct_fake_buffer(2))
    update_buffer = AgentBuffer()

    def _fill():
        # Append both agents' data to update_buffer, in agent order.
        for agent_buffer in agent_buffers:
            agent_buffer.resequence_and_append(
                update_buffer, batch_size=None, training_length=2
            )

    # Non-recurrent truncation keeps exactly the requested number of entries.
    _fill()
    update_buffer.truncate(2)
    assert update_buffer.num_experiences == 2

    # Recurrent truncation: the kept count is expected to be a multiple of
    # sequence_length, so asking for 4 with sequence_length=3 leaves 3.
    _fill()
    update_buffer.truncate(4, sequence_length=3)
    assert update_buffer.num_experiences == 3
    assert all(
        isinstance(field, AgentBufferField) for field in update_buffer.values()
    )
|
|
|
|
|
|
def test_key_encode_decode():
    """Every kind of buffer key should survive an encode/decode round trip."""
    plain_keys = list(BufferKey)
    observation_keys = [(prefix, 42) for prefix in ObservationKeyPrefix]
    reward_keys = [(prefix, "gail") for prefix in RewardSignalKeyPrefix]

    for key in plain_keys + observation_keys + reward_keys:
        round_tripped = AgentBuffer._decode_key(AgentBuffer._encode_key(key))
        assert key == round_tripped
|
|
|
|
|
|
def test_buffer_save_load():
    """A buffer written with save_to_file should load back with equal data."""
    from io import BytesIO

    original = construct_fake_buffer(3)
    stream = BytesIO()
    original.save_to_file(stream)

    loaded = AgentBuffer()
    # NOTE(review): the stream is not rewound before loading; this relies on
    # load_from_file managing the stream position itself — confirm.
    loaded.load_from_file(stream)

    assert len(original) == len(loaded)
    for key in original.keys():
        assert np.allclose(original[key], loaded[key])
|