|
|
|
|
|
|
import numpy as np |
|
|
|
import h5py |
|
|
|
from typing import List, BinaryIO, Any |
|
|
|
from typing import List, BinaryIO, Any, Union |
|
|
|
|
|
|
|
# Elements in the buffer can be np.ndarray, or in the case of teammate obs, actions, rewards, |
|
|
|
# a List of np.ndarray. This is done so that we don't have duplicated np.ndarrays, only references. |
|
|
|
BufferEntry = Union[np.ndarray, List[np.ndarray]] |
|
|
|
|
|
|
|
|
|
|
|
class BufferException(UnityException): |
|
|
|
|
|
|
def __str__(self): |
|
|
|
return str(np.array(self).shape) |
|
|
|
|
|
|
|
def append(self, element: Any, padding_value: Any = 0.0) -> None: |
|
|
|
def append(self, element: BufferEntry, padding_value: Any = 0.0) -> None: |
|
|
|
""" |
|
|
|
Adds an element to this AgentBuffer. Also lets you change the padding |
|
|
|
type, so that it can be set on append (e.g. action_masks should |
|
|
|
|
|
|
super().append(element) |
|
|
|
self.padding_value = padding_value |
|
|
|
|
|
|
|
def set(self, data: List[Any]) -> None: |
|
|
|
def set(self, data: List[BufferEntry]) -> None: |
|
|
|
""" |
|
|
|
Sets the AgentBuffer to the provided list |
|
|
|
:param data: The list to be set. |
|
|
|
|
|
|
batch_size: int = None, |
|
|
|
training_length: int = 1, |
|
|
|
sequential: bool = True, |
|
|
|
) -> List[Any]: |
|
|
|
) -> List[BufferEntry]: |
|
|
|
""" |
|
|
|
Retrieve the last batch_size elements of length training_length |
|
|
|
from the AgentBuffer. |
|
|
|
|
|
|
if key not in self.keys(): |
|
|
|
return False |
|
|
|
if (length is not None) and (length != len(self[key])): |
|
|
|
print(length, key, len(self[key])) |
|
|
|
return False |
|
|
|
length = len(self[key]) |
|
|
|
return True |
|
|
|
|
|
|
return len(next(iter(self.values()))) |
|
|
|
else: |
|
|
|
return 0 |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def obs_list_to_obs_batch(obs_list: List[List[np.ndarray]]) -> List[np.ndarray]: |
|
|
|
""" |
|
|
|
Converts a List of obs (an obs itself consinsting of a List of np.ndarray) to |
|
|
|
a List of np.ndarray, with the observations batchwise. |
|
|
|
""" |
|
|
|
# Transpose and convert List of Lists |
|
|
|
new_list = list(map(lambda x: np.asanyarray(list(x)), zip(*obs_list))) |
|
|
|
return new_list |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def obs_list_list_to_obs_batch( |
|
|
|
obs_list_list: List[List[List[np.ndarray]]] |
|
|
|
) -> List[List[np.ndarray]]: |
|
|
|
""" |
|
|
|
Convert a List of List of obs, where one of the dimension is time and the other is number (e.g. in the |
|
|
|
case of a variable number of critic observations) to a List of obs, where time is in the batch dimension |
|
|
|
of the obs, and the List is the variable number of agents. |
|
|
|
""" |
|
|
|
new_list = list( |
|
|
|
map( |
|
|
|
lambda x: AgentBuffer.obs_list_to_obs_batch(list(x)), |
|
|
|
zip(*obs_list_list), |
|
|
|
) |
|
|
|
) |
|
|
|
return new_list |