
Buffer fixes

(cherry picked from commit 2c03d2b544d0c615e7b60d939f01532674d80753)
/develop/coma2/samenet
Ervin Teng · 4 years ago
Current commit: 2f209c12
3 files changed, with 35 additions and 15 deletions
  1. ml-agents/mlagents/trainers/buffer.py (25 changes)
  2. ml-agents/mlagents/trainers/ppo/trainer.py (2 changes)
  3. ml-agents/mlagents/trainers/tests/test_buffer.py (23 changes)
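
The common thread across the three files is the BufferEntry type: a buffer field entry is now either a single np.ndarray (one agent's data for a step) or a list of arrays (one per member of an agent group). A minimal sketch of the distinction, assuming BufferEntry is the Union alias defined elsewhere in buffer.py:

    from typing import List, Union
    import numpy as np

    # Assumed alias from buffer.py: one agent's array, or a list of
    # arrays holding one entry per group member.
    BufferEntry = Union[np.ndarray, List[np.ndarray]]

    # Per-agent entry: a single observation/action vector for one step.
    solo_entry: BufferEntry = np.array([0.1, 0.2], dtype=np.float32)

    # Group entry: one array per teammate. The teammate count can change
    # between steps, so a batch of these may be ragged.
    group_entry: BufferEntry = [
        np.array([1.0, 2.0], dtype=np.float32),
        np.array([3.0, 4.0], dtype=np.float32),
    ]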

ml-agents/mlagents/trainers/buffer.py (25 changes)


 class AgentBufferField(list):
     """
-    AgentBufferField is a list of numpy arrays. When an agent collects a field, you can add it to its
-    AgentBufferField with the append method.
+    AgentBufferField is a list of numpy arrays, or List[np.ndarray] for group entries.
+    When an agent collects a field, you can add it to its AgentBufferField with the append method.
     """

     def __init__(self):

-    def __str__(self):
+    def __str__(self) -> str:
         return str(np.array(self).shape)

     def append(self, element: np.ndarray, padding_value: float = 0.0) -> None:

         super().append(element)
         self.padding_value = padding_value

-    def set(self, data):
+    def set(self, data: List[BufferEntry]) -> None:
         """
-        Sets the list of np.array to the input data
-        :param data: The np.array list to be set.
+        Sets the list of BufferEntry to the input data
+        :param data: The BufferEntry list to be set.
         """
         self[:] = []
         self[:] = data

         batch_size: int = None,
         training_length: Optional[int] = 1,
         sequential: bool = True,
-    ) -> np.ndarray:
+    ) -> List[BufferEntry]:
         """
         Retrieve the last batch_size elements of length training_length
         from the list of np.array

             )
             if batch_size * training_length > len(self):
                 padding = np.array(self[-1], dtype=np.float32) * self.padding_value
-                return np.array(
-                    [padding] * (training_length - leftover) + self[:], dtype=np.float32
-                )
+                return [padding] * (training_length - leftover) + self[:]
-            return np.array(
-                self[len(self) - batch_size * training_length :], dtype=np.float32
-            )
+            return self[len(self) - batch_size * training_length :]
         else:
             # The sequences will have overlapping elements
             if batch_size is None:

             tmp_list: List[np.ndarray] = []
             for end in range(len(self) - batch_size + 1, len(self) + 1):
                 tmp_list += self[end - training_length : end]
-            return np.array(tmp_list, dtype=np.float32)
+            return tmp_list

     def reset_field(self) -> None:
         """

ml-agents/mlagents/trainers/ppo/trainer.py (2 changes)


             int(self.hyperparameters.batch_size / self.policy.sequence_length), 1
         )
-        advantages = self.update_buffer[BufferKey.ADVANTAGES].get_batch()
+        advantages = np.array(self.update_buffer[BufferKey.ADVANTAGES].get_batch())
         self.update_buffer[BufferKey.ADVANTAGES].set(
             (advantages - advantages.mean()) / (advantages.std() + 1e-10)
         )
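
Because get_batch() now returns a plain list, the trainer converts it explicitly before normalizing; ADVANTAGES entries are flat per-step scalars, so the conversion is always well-defined. A toy sketch of the same normalization with dummy values:

    import numpy as np

    advantages = np.array([0.5, -1.0, 2.0, 0.25], dtype=np.float32)
    # Standardize to zero mean / unit std; the 1e-10 guards against
    # division by a zero standard deviation.
    normalized = (advantages - advantages.mean()) / (advantages.std() + 1e-10)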

ml-agents/mlagents/trainers/tests/test_buffer.py (23 changes)


                     dtype=np.float32,
                 )
             )
+            b[BufferKey.GROUP_CONTINUOUS_ACTION].append(
+                [
+                    np.array(
+                        [
+                            100 * fake_agent_id + 10 * step + 4,
+                            100 * fake_agent_id + 10 * step + 5,
+                        ],
+                        dtype=np.float32,
+                    )
+                ]
+                * 3
+            )
     return b

     agent_3_buffer = construct_fake_buffer(3)

     # Test get_batch
     a = agent_1_buffer[ObsUtil.get_name_at(0)].get_batch(
         batch_size=2, training_length=1, sequential=True
     )

     # Test get_batch
     a = agent_2_buffer[ObsUtil.get_name_at(0)].get_batch(
         batch_size=2, training_length=3, sequential=True
     )

             ]
         ),
     )
+    # Test group entries return Lists of Lists
+    a = agent_2_buffer[BufferKey.GROUP_CONTINUOUS_ACTION].get_batch(
+        batch_size=2, training_length=1, sequential=True
+    )
+    for _group_entry in a:
+        assert len(_group_entry) == 3
     agent_1_buffer.reset_agent()
     assert agent_1_buffer.num_experiences == 0
     update_buffer = AgentBuffer()
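
The new assertions pin down the "list of lists" contract for group fields: with training_length=1, each batch element is one group entry, i.e. the list of 3 per-teammate arrays appended by construct_fake_buffer above. A standalone sketch of the same check (import path taken from the file path above; the fill values are arbitrary):

    import numpy as np
    from mlagents.trainers.buffer import AgentBufferField

    field = AgentBufferField()
    for _ in range(5):
        # Mirror construct_fake_buffer: each group entry is 3 teammate arrays.
        field.append([np.zeros(2, dtype=np.float32)] * 3)

    batch = field.get_batch(batch_size=2, training_length=1, sequential=True)
    assert len(batch) == 2                    # two steps requested
    for group_entry in batch:
        assert len(group_entry) == 3          # 3 teammates per entry
        assert group_entry[0].shape == (2,)   # 2-dim continuous action each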
