import numpy as np
import h5py

from typing import List, BinaryIO

from mlagents.envs.exception import UnityException


class BufferException(UnityException):
    """
    Related to errors with the Buffer.
    """

    pass


class AgentBuffer(dict):
    """
    AgentBuffer contains a dictionary of AgentBufferFields. Each agent has its
    own AgentBuffer. The keys correspond to the name of the field.
    Example: state, action
    """

    class AgentBufferField(list):
        """
        AgentBufferField is a list of numpy arrays. When an agent collects a
        field, it is added to its AgentBufferField with the append method.
        """

        def __init__(self):
            self.padding_value = 0
            super().__init__()

        def append(self, element: np.ndarray, padding_value: float = 0.0) -> None:
            """
            Adds an element to this list. Also lets you change the padding
            type, so that it can be set on append (e.g. action_masks should
            be padded with 1.)
            :param element: The element to append to the list.
            :param padding_value: The value to pad with if padding is needed
            when get_batch is called.
            """
            super().append(element)
            self.padding_value = padding_value
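
        # Illustrative usage (a sketch, not part of the original module): action
        # masks are appended with padding_value=1.0 so that any padded timesteps
        # leave every action available.
        #
        #   masks = AgentBuffer.AgentBufferField()
        #   masks.append(np.array([1.0, 1.0, 0.0]), padding_value=1.0)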

        def extend(self, data: np.ndarray) -> None:
            """
            Adds a list of np.arrays to the end of the list of np.arrays.
            :param data: The np.array list to append.
            """
            self += list(np.array(data))

        def set(self, data: np.ndarray) -> None:
            """
            Sets the list of np.array to the input data
            :param data: The np.array list to be set.
            """
            self[:] = list(np.array(data))

        def get_batch(
            self,
            batch_size: int = None,
            training_length: int = 1,
            sequential: bool = True,
        ) -> np.ndarray:
            """
            Retrieve the last batch_size elements of length training_length
            from the list of np.array
            :param batch_size: The number of sequences to retrieve. If None:
            as many as possible are retrieved.
            :param training_length: The length of each sequence to retrieve. If
            None: single elements are retrieved.
            :param sequential: If True, the retrieved sequences do not overlap
            (the first sequence may be padded). If False, sequences may overlap.
            """
            if training_length is None:
                # A training_length of None means single elements, no sequences.
                training_length = 1
            if sequential:
                # Sequences do not overlap; the first sequence may need padding.
                leftover = len(self) % training_length
                if batch_size is None:
                    batch_size = len(self) // training_length + 1 * (leftover != 0)
                if batch_size > len(self) // training_length + 1 * (leftover != 0):
                    raise BufferException(
                        "The batch size and training length requested for get_batch were"
                        " too large given the current number of data points."
                    )
                if batch_size * training_length > len(self):
                    padding = np.array(self[-1]) * self.padding_value
                    return np.array([padding] * (training_length - leftover) + self[:])
                return np.array(self[len(self) - batch_size * training_length :])
            else:
                # Sequences may overlap.
                if batch_size is None:
                    batch_size = len(self) - training_length + 1
                if len(self) - training_length + 1 < batch_size:
                    raise BufferException(
                        "The batch size and training length requested for get_batch were"
                        " too large given the current number of data points."
                    )
                tmp_list: List[np.ndarray] = []
                for end in range(len(self) - batch_size + 1, len(self) + 1):
                    tmp_list += self[end - training_length : end]
                return np.array(tmp_list)
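
        # Retrieval semantics, sketched for elements [a, b, c, d, e] with
        # training_length=2: sequential=True yields the non-overlapping
        # sequences [pad, a], [b, c], [d, e] (the first one is padded), while
        # sequential=False yields the overlapping sequences [a, b], [b, c],
        # [c, d], [d, e]. Both are returned concatenated into a single array.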

        def reset_field(self) -> None:
            """
            Resets the AgentBufferField
            """
            self[:] = []

    def __init__(self):
        self.last_brain_info = None
        self.last_take_action_outputs = None
        super().__init__()

    def reset_agent(self) -> None:
        """
        Resets the AgentBuffer
        """
        for k in self.keys():
            self[k].reset_field()
        self.last_brain_info = None
        self.last_take_action_outputs = None

    def __getitem__(self, key):
        if key not in self.keys():
            self[key] = self.AgentBufferField()
        return super().__getitem__(key)

    def check_length(self, key_list: List[str]) -> bool:
        """
        Some methods will require that some fields have the same length.
        check_length will return true if the fields in key_list
        have the same length.
        :param key_list: The fields whose lengths will be compared.
        """
        if len(key_list) < 2:
            return True
        length = None
        for key in key_list:
            if key not in self.keys():
                return False
            if (length is not None) and (length != len(self[key])):
                return False
            length = len(self[key])
        return True
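
    # Illustrative sketch (the key names are made up for the example):
    # check_length(["observations", "actions"]) returns True only when both
    # fields hold the same number of entries.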

    def shuffle(self, sequence_length: int, key_list: List[str] = None) -> None:
        """
        Shuffles the fields in key_list in a consistent way: The reordering will
        be the same across fields.
        :param sequence_length: The length of the sequences that must stay together.
        :param key_list: The fields to shuffle. If None: all fields are shuffled.
        """
        if key_list is None:
            key_list = list(self.keys())
        if not self.check_length(key_list):
            raise BufferException(
                "Unable to shuffle if the fields are not of the same length"
            )
        s = np.arange(len(self[key_list[0]]) // sequence_length)
        np.random.shuffle(s)
        for key in key_list:
            tmp: List[np.ndarray] = []
            for i in s:
                tmp += self[key][i * sequence_length : (i + 1) * sequence_length]
            self[key][:] = tmp
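
    # Illustrative sketch (made-up key names): shuffling with sequence_length=4
    # reorders blocks of 4 consecutive entries, keeping each block intact and
    # applying the same ordering to every listed field.
    #
    #   buffer.shuffle(sequence_length=4, key_list=["observations", "actions"])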

    def make_mini_batch(self, start: int, end: int) -> "AgentBuffer":
        """
        Creates a mini-batch from buffer.
        :param start: Starting index of buffer.
        :param end: Ending index of buffer.
        :return: AgentBuffer with the sliced fields.
        """
        mini_batch = AgentBuffer()
        for key in self:
            mini_batch[key] = self[key][start:end]
        return mini_batch

    def sample_mini_batch(
        self, batch_size: int, sequence_length: int = 1
    ) -> "AgentBuffer":
        """
        Creates a mini-batch from a random start and end.
        :param batch_size: number of elements to withdraw.
        :param sequence_length: Length of sequences to sample.
            Number of sequences to sample will be batch_size/sequence_length.
        """
        num_seq_to_sample = batch_size // sequence_length
        mini_batch = AgentBuffer()
        buff_len = len(next(iter(self.values())))
        num_sequences_in_buffer = buff_len // sequence_length
        # Sample random, sequence-aligned start indices (with replacement).
        start_idxes = (
            np.random.randint(num_sequences_in_buffer, size=num_seq_to_sample)
            * sequence_length
        )
        for i in start_idxes:
            for key in self:
                mini_batch[key].extend(self[key][i : i + sequence_length])
        return mini_batch
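
    # Illustrative sketch: draw 64 random sequences of length 4 (256 elements
    # in total); batch_size is expected to be a multiple of sequence_length.
    #
    #   mini_batch = buffer.sample_mini_batch(batch_size=256, sequence_length=4)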

    def save_to_file(self, file_object: BinaryIO) -> None:
        """
        Saves the AgentBuffer to a file-like object.
        """
        with h5py.File(file_object, "w") as write_file:
            for key, data in self.items():
                write_file.create_dataset(
                    key, data=data, dtype="f", compression="gzip"
                )

    def load_from_file(self, file_object: BinaryIO) -> None:
        """
        Loads the AgentBuffer from a file-like object.
        """
        with h5py.File(file_object, "r") as read_file:
            for key in list(read_file.keys()):
                self[key] = AgentBuffer.AgentBufferField()
                # extend() will convert the numpy array's first dimension into list
                self[key].extend(read_file[key][()])
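
    # Illustrative sketch of a save/load round trip through an in-memory
    # binary stream:
    #
    #   import io
    #   stream = io.BytesIO()
    #   buffer.save_to_file(stream)
    #   stream.seek(0)
    #   restored = AgentBuffer()
    #   restored.load_from_file(stream)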

    def truncate(self, max_length: int, sequence_length: int = 1) -> None:
        """
        Truncates the buffer to a certain length.
        :param max_length: The maximum number of elements to keep.
        :param sequence_length: The buffer is truncated to a multiple of
            sequence_length so that sequences are never split.
        """
        current_length = len(next(iter(self.values())))
        # Keep an integer number of sequence_lengths.
        max_length -= max_length % sequence_length
        if current_length > max_length:
            for _key in self.keys():
                self[_key][:] = self[_key][current_length - max_length :]


class AgentProcessorBuffer(dict):
    """
    AgentProcessorBuffer contains a dictionary of AgentBuffers, indexed by
    agent_id. It holds the data collected locally for each agent before it is
    appended to an update buffer.
    """

    def __getitem__(self, key):
        if key not in self.keys():
            self[key] = AgentBuffer()
        return super().__getitem__(key)

    def reset_local_buffers(self) -> None:
        """
        Resets all the local AgentBuffers.
        """
        agent_ids = list(self.keys())
        for k in agent_ids:
            self[k].reset_agent()

    def append_update_buffer(
        self,
        update_buffer: AgentBuffer,
        agent_id: str,
        key_list: List[str] = None,
        batch_size: int = None,
        training_length: int = None,
    ) -> None:
        """
        Appends the buffer of an agent to the update buffer.
        :param update_buffer: The AgentBuffer to append the agent's data to.
        :param agent_id: The id of the agent which data will be appended
        :param key_list: The fields that must be added. If None: all fields will be appended.
        :param batch_size: The number of elements that must be appended. If None: all of them will be.
        :param training_length: The length of the samples that must be appended. If None: only takes one element.
        """
        if key_list is None:
            key_list = self[agent_id].keys()
        if not self[agent_id].check_length(key_list):
            raise BufferException(
                "The fields {0} for agent {1} are not of the same length".format(
                    key_list, agent_id
                )
            )
        for field_key in key_list:
            update_buffer[field_key].extend(
                self[agent_id][field_key].get_batch(
                    batch_size=batch_size, training_length=training_length
                )
            )

    def append_all_agent_batch_to_update_buffer(
        self,
        update_buffer: AgentBuffer,
        key_list: List[str] = None,
        batch_size: int = None,
        training_length: int = None,
    ) -> None:
        """
        Appends the buffer of all agents to the update buffer.
        :param update_buffer: The AgentBuffer to append the agents' data to.
        :param key_list: The fields that must be added. If None: all fields will be appended.
        :param batch_size: The number of elements that must be appended. If None: all of them will be.
        :param training_length: The length of the samples that must be appended. If None: only takes one element.
        """
        for agent_id in self.keys():
            self.append_update_buffer(
                update_buffer, agent_id, key_list, batch_size, training_length
            )
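

# The demo below is an illustrative sketch of how these buffers fit together
# during data collection and training. The key names ("observations",
# "actions", "action_masks"), the agent ids and the shapes are made up for
# the example and are not part of this module's API.
if __name__ == "__main__":
    processor_buffer = AgentProcessorBuffer()
    update_buffer = AgentBuffer()

    # Collect a few fake steps for two agents.
    for _step in range(8):
        for agent_id in ("agent-0", "agent-1"):
            processor_buffer[agent_id]["observations"].append(np.random.rand(4))
            processor_buffer[agent_id]["actions"].append(np.random.rand(2))
            processor_buffer[agent_id]["action_masks"].append(
                np.ones(2), padding_value=1.0
            )

    # Move everything collected so far into the shared update buffer,
    # then clear the per-agent buffers.
    processor_buffer.append_all_agent_batch_to_update_buffer(
        update_buffer, training_length=1
    )
    processor_buffer.reset_local_buffers()

    # Shuffle in whole sequences, then draw a mini-batch for an update.
    update_buffer.shuffle(sequence_length=1)
    mini_batch = update_buffer.sample_mini_batch(batch_size=8, sequence_length=1)
    print({key: len(field) for key, field in mini_batch.items()})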