浏览代码

Add type hints to Buffer

/develop-newnormalization
Ervin Teng 5 年前
当前提交
c2d216ca
共有 1 个文件被更改,包括 42 次插入和 30 次删除
  1. 72
      ml-agents/mlagents/trainers/buffer.py

72
ml-agents/mlagents/trainers/buffer.py


import numpy as np
import h5py
from typing import List, BinaryIO
from mlagents.envs.exception import UnityException

def __init__(self):
self.padding_value = 0
super(AgentBuffer.AgentBufferField, self).__init__()
super().__init__()
def append(self, element, padding_value=0):
def append(self, element: np.ndarray, padding_value: float = 0.0) -> None:
"""
Adds an element to this list. Also lets you change the padding
type, so that it can be set on append (e.g. action_masks should

"""
super(AgentBuffer.AgentBufferField, self).append(element)
super().append(element)
def extend(self, data):
def extend(self, data: np.ndarray) -> None:
"""
Adds a list of np.arrays to the end of the list of np.arrays.
:param data: The np.array list to append.

def set(self, data):
def set(self, data: np.ndarray) -> None:
"""
Sets the list of np.array to the input data
:param data: The np.array list to be set.

def get_batch(self, batch_size=None, training_length=1, sequential=True):
def get_batch(
self,
batch_size: int = None,
training_length: int = 1,
sequential: bool = True,
) -> np.ndarray:
"""
Retrieve the last batch_size elements of length training_length
from the list of np.array

"The batch size and training length requested for get_batch where"
" too large given the current number of data points."
)
tmp_list = []
tmp_list: List[np.ndarray] = []
def reset_field(self):
def reset_field(self) -> None:
"""
Resets the AgentBufferField
"""

self.last_brain_info = None
self.last_take_action_outputs = None
super(AgentBuffer, self).__init__()
super().__init__()
def reset_agent(self):
def reset_agent(self) -> None:
"""
Resets the AgentBuffer
"""

def __getitem__(self, key):
if key not in self.keys():
self[key] = self.AgentBufferField()
return super(AgentBuffer, self).__getitem__(key)
return super().__getitem__(key)
def check_length(self, key_list):
def check_length(self, key_list: List[str]) -> bool:
"""
Some methods will require that some fields have the same length.
check_length will return true if the fields in key_list

length = len(self[key])
return True
def shuffle(self, sequence_length, key_list=None):
def shuffle(self, sequence_length: int, key_list: List[str] = None) -> None:
"""
Shuffles the fields in key_list in a consistent way: The reordering will
be the same across fields.

s = np.arange(len(self[key_list[0]]) // sequence_length)
np.random.shuffle(s)
for key in key_list:
tmp = []
tmp: List[np.ndarray] = []
def make_mini_batch(self, start, end):
def make_mini_batch(self, start: int, end: int) -> "AgentBuffer":
"""
Creates a mini-batch from buffer.
:param start: Starting index of buffer.

mini_batch = {}
mini_batch = AgentBuffer()
def sample_mini_batch(self, batch_size, sequence_length=1):
def sample_mini_batch(
self, batch_size: int, sequence_length: int = 1
) -> "AgentBuffer":
"""
Creates a mini-batch from a random start and end.
:param batch_size: number of elements to withdraw.

mini_batch[key].extend(self[key][i : i + sequence_length])
return mini_batch
def save_to_file(self, file_object):
def save_to_file(self, file_object: BinaryIO) -> None:
"""
Saves the AgentBuffer to a file-like object.
"""

def load_from_file(self, file_object):
def load_from_file(self, file_object: BinaryIO) -> None:
"""
Loads the AgentBuffer from a file-like object.
"""

# extend() will convert the numpy array's first dimension into list
self[key].extend(read_file[key][()])
def truncate(self, max_length, sequence_length=1):
def truncate(self, max_length: int, sequence_length: int = 1) -> None:
"""
Truncates the buffer to a certain length.

def __getitem__(self, key):
if key not in self.keys():
self[key] = AgentBuffer()
return super(AgentProcessorBuffer, self).__getitem__(key)
return super().__getitem__(key)
def reset_local_buffers(self):
def reset_local_buffers(self) -> None:
"""
Resets all the local buffers
"""

def append_update_buffer(
self,
update_buffer,
agent_id,
key_list=None,
batch_size=None,
training_length=None,
):
update_buffer: AgentBuffer,
agent_id: str,
key_list: List[str] = None,
batch_size: int = None,
training_length: int = None,
) -> None:
"""
Appends the buffer of an agent to the update buffer.
:param agent_id: The id of the agent whose data will be appended

)
def append_all_agent_batch_to_update_buffer(
self, update_buffer, key_list=None, batch_size=None, training_length=None
):
self,
update_buffer: AgentBuffer,
key_list: List[str] = None,
batch_size: int = None,
training_length: int = None,
) -> None:
"""
Appends the buffer of all agents to the update buffer.
:param key_list: The fields that must be added. If None: all fields will be appended.

正在加载...
取消
保存