|
|
|
|
|
|
|
|
|
|
c = update_buffer.make_mini_batch(start=0, end=1) |
|
|
|
assert c.keys() == update_buffer.keys() |
|
|
|
# Make sure the values of c are AgentBufferField |
|
|
|
for val in c.values(): |
|
|
|
assert isinstance(val, AgentBufferField) |
|
|
|
|
|
|
|
|
|
|
|
def test_agentbufferfield(): |
|
|
|
# Test constructor |
|
|
|
a = AgentBufferField([0, 1, 2]) |
|
|
|
for i, num in enumerate(a): |
|
|
|
assert num == i |
|
|
|
# Test indexing |
|
|
|
assert a[i] == num |
|
|
|
|
|
|
|
# Test slicing |
|
|
|
b = a[1:3] |
|
|
|
assert b == [1, 2] |
|
|
|
assert isinstance(b, AgentBufferField) |
|
|
|
|
|
|
|
# Test padding |
|
|
|
c = AgentBufferField() |
|
|
|
for _ in range(2): |
|
|
|
c.append([np.array(1), np.array(2)]) |
|
|
|
|
|
|
|
for _ in range(2): |
|
|
|
c.append([np.array(1)]) |
|
|
|
|
|
|
|
padded = c.padded_to_batch(pad_value=3) |
|
|
|
assert np.array_equal(padded[0], np.array([1, 1, 1, 1])) |
|
|
|
assert np.array_equal(padded[1], np.array([2, 2, 3, 3])) |
|
|
|
|
|
|
|
|
|
|
|
def fakerandint(values): |
|
|
|