Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多可以选择25个主题。主题必须以中文、字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符。
 
 
 
 
 

75 行
2.8 KiB

from typing import List, Optional

from mlagents.trainers.buffer import AgentBuffer, BufferException
class ProcessingBuffer(dict):
    """
    ProcessingBuffer contains a dictionary of AgentBuffer. The AgentBuffers are indexed by agent_id.

    Indexing with an agent_id that is not present yet creates and stores an empty
    AgentBuffer for it (see __getitem__) instead of raising KeyError.
    """

    def __str__(self) -> str:
        """
        Returns a multi-line description listing every agent id and its buffer.
        """
        agent_lines = [
            "\tagent {0} :{1}".format(agent_id, str(agent_buffer))
            for agent_id, agent_buffer in self.items()
        ]
        return "local_buffers :\n{0}".format("\n".join(agent_lines))

    def __getitem__(self, key):
        """
        Returns the buffer stored under key, creating an empty AgentBuffer first
        when the key is not present yet.
        :param key: The id of the agent whose buffer is requested.
        :return: The buffer stored for that agent.
        """
        # Membership is tested on the dict itself; no keys() view is needed.
        if key not in self:
            self[key] = AgentBuffer()
        return super().__getitem__(key)

    def reset_local_buffers(self) -> None:
        """
        Resets all the local AgentBuffers.
        """
        for agent_buffer in self.values():
            agent_buffer.reset_agent()

    def append_to_update_buffer(
        self,
        update_buffer: AgentBuffer,
        agent_id: str,
        key_list: Optional[List[str]] = None,
        batch_size: Optional[int] = None,
        training_length: Optional[int] = None,
    ) -> None:
        """
        Appends the buffer of an agent to the update buffer.
        Note that an unknown agent_id gets an empty AgentBuffer created for it
        as a side effect of the lookup.
        :param update_buffer: A reference to an AgentBuffer to append the agent's buffer to
        :param agent_id: The id of the agent which data will be appended
        :param key_list: The fields that must be added. If None: all fields will be appended.
        :param batch_size: The number of elements that must be appended. If None: All of them will be.
        :param training_length: The length of the samples that must be appended. If None: only takes one element.
        :raises BufferException: If check_length reports that the requested fields
            of the agent's buffer do not all have the same length.
        """
        # Resolve the agent's buffer once rather than re-indexing self each time.
        agent_buffer = self[agent_id]
        if key_list is None:
            key_list = agent_buffer.keys()
        if not agent_buffer.check_length(key_list):
            raise BufferException(
                "The length of the fields {0} for agent {1} were not of same length".format(
                    key_list, agent_id
                )
            )
        for field_key in key_list:
            update_buffer[field_key].extend(
                agent_buffer[field_key].get_batch(
                    batch_size=batch_size, training_length=training_length
                )
            )

    def append_all_agent_batch_to_update_buffer(
        self,
        update_buffer: AgentBuffer,
        key_list: Optional[List[str]] = None,
        batch_size: Optional[int] = None,
        training_length: Optional[int] = None,
    ) -> None:
        """
        Appends the buffer of all agents to the update buffer.
        :param update_buffer: A reference to an AgentBuffer to append the agents' buffers to
        :param key_list: The fields that must be added. If None: all fields will be appended.
        :param batch_size: The number of elements that must be appended. If None: All of them will be.
        :param training_length: The length of the samples that must be appended. If None: only takes one element.
        :raises BufferException: Propagated from append_to_update_buffer when an
            agent's requested fields do not all have the same length.
        """
        # Every agent_id iterated here already exists, so the per-agent call
        # never inserts a new key while this loop is running.
        for agent_id in self:
            self.append_to_update_buffer(
                update_buffer, agent_id, key_list, batch_size, training_length
            )