浏览代码

Make buffer typing neater

/develop/coma2/samenet
Ervin Teng 4 年前
当前提交
3d0abb03
共有 1 个文件被更改,包括 8 次插入32 次删除
  1. 40
      ml-agents/mlagents/trainers/buffer.py

40
ml-agents/mlagents/trainers/buffer.py


import numpy as np
import h5py
from typing import List, BinaryIO, Any
from typing import List, BinaryIO, Any, Union
# Elements in the buffer can be np.ndarray, or in the case of teammate obs, actions, rewards,
# a List of np.ndarray. This is done so that we don't have duplicated np.ndarrays, only references.
BufferEntry = Union[np.ndarray, List[np.ndarray]]
class BufferException(UnityException):

def __str__(self):
return str(np.array(self).shape)
def append(self, element: Any, padding_value: Any = 0.0) -> None:
def append(self, element: BufferEntry, padding_value: Any = 0.0) -> None:
"""
Adds an element to this AgentBuffer. Also lets you change the padding
type, so that it can be set on append (e.g. action_masks should

super().append(element)
self.padding_value = padding_value
def set(self, data: List[Any]) -> None:
def set(self, data: List[BufferEntry]) -> None:
"""
Sets the AgentBuffer to the provided list
:param data: The list to be set.

batch_size: int = None,
training_length: int = 1,
sequential: bool = True,
) -> List[Any]:
) -> List[BufferEntry]:
"""
Retrieve the last batch_size elements of length training_length
from the AgentBuffer.

if key not in self.keys():
return False
if (length is not None) and (length != len(self[key])):
print(length, key, len(self[key]))
return False
length = len(self[key])
return True

return len(next(iter(self.values())))
else:
return 0
@staticmethod
def obs_list_to_obs_batch(obs_list: List[List[np.ndarray]]) -> List[np.ndarray]:
"""
Converts a List of obs (an obs itself consinsting of a List of np.ndarray) to
a List of np.ndarray, with the observations batchwise.
"""
# Transpose and convert List of Lists
new_list = list(map(lambda x: np.asanyarray(list(x)), zip(*obs_list)))
return new_list
@staticmethod
def obs_list_list_to_obs_batch(
obs_list_list: List[List[List[np.ndarray]]]
) -> List[List[np.ndarray]]:
"""
Convert a List of List of obs, where one of the dimension is time and the other is number (e.g. in the
case of a variable number of critic observations) to a List of obs, where time is in the batch dimension
of the obs, and the List is the variable number of agents.
"""
new_list = list(
map(
lambda x: AgentBuffer.obs_list_to_obs_batch(list(x)),
zip(*obs_list_list),
)
)
return new_list
正在加载...
取消
保存