浏览代码

Make buffer type-agnostic

/develop/multitype-buffer
Ervin Teng 4 年前
当前提交
184f27c6
共有 3 个文件被更改,包括 36 次插入39 次删除
  1. 40
      ml-agents/mlagents/trainers/buffer.py
  2. 6
      ml-agents/mlagents/trainers/ppo/trainer.py
  3. 29
      ml-agents/mlagents/trainers/tests/test_buffer.py

40
ml-agents/mlagents/trainers/buffer.py


import numpy as np
import h5py
from typing import List, BinaryIO
from typing import List, BinaryIO, Any
import itertools
from mlagents_envs.exception import UnityException

def __str__(self):
return str(np.array(self).shape)
def append(self, element: np.ndarray, padding_value: float = 0.0) -> None:
def append(self, element: Any, padding_value: float = 0.0) -> None:
Adds an element to this list. Also lets you change the padding
Adds an element to this AgentBuffer. Also lets you change the padding
type, so that it can be set on append (e.g. action_masks should
be padded with 1.)
:param element: The element to append to the list.

self.padding_value = padding_value
def extend(self, data: np.ndarray) -> None:
"""
Adds a list of np.arrays to the end of the list of np.arrays.
:param data: The np.array list to append.
"""
self += list(np.array(data, dtype=np.float32))
def set(self, data):
def set(self, data: List[Any]) -> None:
Sets the list of np.array to the input data
:param data: The np.array list to be set.
Sets the AgentBuffer to the provided list
:param data: The list to be set.
dtype = None
if data is not None and len(data) and isinstance(data[0], float):
dtype = np.float32
self[:] = list(np.array(data, dtype=dtype))
self[:] = data
def get_batch(
self,

) -> np.ndarray:
) -> List[Any]:
from the list of np.array
from the AgentBuffer.
:param batch_size: The number of elements to retrieve. If None:
All elements will be retrieved.
:param training_length: The length of the sequence to be retrieved. If

)
if batch_size * training_length > len(self):
padding = np.array(self[-1], dtype=np.float32) * self.padding_value
return np.array(
[padding] * (training_length - leftover) + self[:],
dtype=np.float32,
)
return [padding] * (training_length - leftover) + self[:]
return np.array(
self[len(self) - batch_size * training_length :],
dtype=np.float32,
)
return self[len(self) - batch_size * training_length :]
else:
# The sequences will have overlapping elements
if batch_size is None:

tmp_list: List[np.ndarray] = []
for end in range(len(self) - batch_size + 1, len(self) + 1):
tmp_list += self[end - training_length : end]
return np.array(tmp_list, dtype=np.float32)
return tmp_list
def reset_field(self) -> None:
"""

6
ml-agents/mlagents/trainers/ppo/trainer.py


n_sequences = max(
int(self.hyperparameters.batch_size / self.policy.sequence_length), 1
)
advantages = self.update_buffer["advantages"].get_batch()
# Normalize advantages
advantages = np.array(self.update_buffer["advantages"].get_batch())
(advantages - advantages.mean()) / (advantages.std() + 1e-10)
list((advantages - advantages.mean()) / (advantages.std() + 1e-10))
)
num_epoch = self.hyperparameters.num_epoch
batch_update_stats = defaultdict(list)

29
ml-agents/mlagents/trainers/tests/test_buffer.py


b = AgentBuffer()
for step in range(9):
b["vector_observation"].append(
[
100 * fake_agent_id + 10 * step + 1,
100 * fake_agent_id + 10 * step + 2,
100 * fake_agent_id + 10 * step + 3,
]
np.array(
[
100 * fake_agent_id + 10 * step + 1,
100 * fake_agent_id + 10 * step + 2,
100 * fake_agent_id + 10 * step + 3,
],
dtype=np.float32,
)
[100 * fake_agent_id + 10 * step + 4, 100 * fake_agent_id + 10 * step + 5]
np.array(
[
100 * fake_agent_id + 10 * step + 4,
100 * fake_agent_id + 10 * step + 5,
],
dtype=np.float32,
)
)
return b

a = agent_1_buffer["vector_observation"].get_batch(
batch_size=2, training_length=1, sequential=True
)
assert_array(np.array(a), np.array([[171, 172, 173], [181, 182, 183]]))
assert len(a) == 2
assert_array(
np.array(a), np.array([[171, 172, 173], [181, 182, 183]], dtype=np.float32)
)
a = agent_2_buffer["vector_observation"].get_batch(
batch_size=2, training_length=3, sequential=True
)

[261, 262, 263],
[271, 272, 273],
[281, 282, 283],
]
],
dtype=np.float32,
),
)
a = agent_2_buffer["vector_observation"].get_batch(

正在加载...
取消
保存