浏览代码

Fix slicing typing and string printing in AgentBufferField

/develop/coma2/samenet
Ervin Teng 4 年前
当前提交
50ab983e
共有 2 个文件被更改,包括 39 次插入3 次删除
  1. 13
      ml-agents/mlagents/trainers/buffer.py
  2. 29
      ml-agents/mlagents/trainers/tests/test_buffer.py

13
ml-agents/mlagents/trainers/buffer.py


When an agent collects a field, you can add it to its AgentBufferField with the append method.
"""
def __init__(self):
def __init__(self, *args, **kwargs):
super().__init__()
super().__init__(*args, **kwargs)
return str(np.array(self).shape)
return f"AgentBufferField: {super().__str__()}"
def __getitem__(self, index):
return_data = super().__getitem__(index)
if isinstance(return_data, list):
return AgentBufferField(return_data)
else:
return return_data
def append(self, element: np.ndarray, padding_value: float = 0.0) -> None:
"""

29
ml-agents/mlagents/trainers/tests/test_buffer.py


c = update_buffer.make_mini_batch(start=0, end=1)
assert c.keys() == update_buffer.keys()
# Make sure the values of c are AgentBufferField
for val in c.values():
assert isinstance(val, AgentBufferField)
def test_agentbufferfield():
# Test constructor
a = AgentBufferField([0, 1, 2])
for i, num in enumerate(a):
assert num == i
# Test indexing
assert a[i] == num
# Test slicing
b = a[1:3]
assert b == [1, 2]
assert isinstance(b, AgentBufferField)
# Test padding
c = AgentBufferField()
for _ in range(2):
c.append([np.array(1), np.array(2)])
for _ in range(2):
c.append([np.array(1)])
padded = c.padded_to_batch(pad_value=3)
assert np.array_equal(padded[0], np.array([1, 1, 1, 1]))
assert np.array_equal(padded[1], np.array([2, 2, 3, 3]))
def fakerandint(values):

正在加载...
取消
保存