浏览代码

Non-working commit

/develop-newnormalization
Ervin Teng 5 年前
当前提交
3434352a
共有 3 个文件被更改,包括 96 次插入74 次删除
  1. 73
      ml-agents/mlagents/trainers/buffer.py
  2. 8
      ml-agents/mlagents/trainers/trainer_controller.py
  3. 89
      ml-agents/mlagents/trainers/agent_processor.py

73
ml-agents/mlagents/trainers/buffer.py


if current_length > max_length:
for _key in self.keys():
self[_key] = self[_key][current_length - max_length :]
class AgentProcessorBuffer(dict):
"""
AgentProcessorBuffer contains a dictionary of AgentBuffer. The AgentBuffers are indexed by agent_id.
Buffer also contains an update_buffer that corresponds to the buffer used when updating the model.
"""
def __str__(self):
return "local_buffers :\n{0}".format(
"\n".join(["\tagent {0} :{1}".format(k, str(self[k])) for k in self.keys()])
)
def __getitem__(self, key):
if key not in self.keys():
self[key] = AgentBuffer()
return super().__getitem__(key)
def reset_local_buffers(self) -> None:
"""
Resets all the local local_buffers
"""
agent_ids = list(self.keys())
for k in agent_ids:
self[k].reset_agent()
def append_to_update_buffer(
self,
update_buffer: AgentBuffer,
agent_id: str,
key_list: List[str] = None,
batch_size: int = None,
training_length: int = None,
) -> None:
"""
Appends the buffer of an agent to the update buffer.
:param agent_id: The id of the agent which data will be appended
:param key_list: The fields that must be added. If None: all fields will be appended.
:param batch_size: The number of elements that must be appended. If None: All of them will be.
:param training_length: The length of the samples that must be appended. If None: only takes one element.
"""
if key_list is None:
key_list = self[agent_id].keys()
if not self[agent_id].check_length(key_list):
raise BufferException(
"The length of the fields {0} for agent {1} where not of same length".format(
key_list, agent_id
)
)
for field_key in key_list:
update_buffer[field_key].extend(
self[agent_id][field_key].get_batch(
batch_size=batch_size, training_length=training_length
)
)
def append_all_agent_batch_to_update_buffer(
self,
update_buffer: AgentBuffer,
key_list: List[str] = None,
batch_size: int = None,
training_length: int = None,
) -> None:
"""
Appends the buffer of all agents to the update buffer.
:param key_list: The fields that must be added. If None: all fields will be appended.
:param batch_size: The number of elements that must be appended. If None: All of them will be.
:param training_length: The length of the samples that must be appended. If None: only takes one element.
"""
for agent_id in self.keys():
self.append_to_update_buffer(
update_buffer, agent_id, key_list, batch_size, training_length
)

8
ml-agents/mlagents/trainers/trainer_controller.py


import os
import json
import logging
from typing import Dict, List, Optional, Set
from typing import Dict, List, Optional, Set, NamedTuple
import numpy as np
from mlagents.tf_utils import tf

from mlagents.trainers.trainer import Trainer, TrainerMetrics
from mlagents.trainers.meta_curriculum import MetaCurriculum
from mlagents.trainers.trainer_util import TrainerFactory
from mlagents.trainers.agent_processor import AgentProcessor
class AgentManager(NamedTuple):
processor: AgentProcessor
class TrainerController(object):

:param resampling_interval: Specifies number of simulation steps after which reset parameters are resampled.
"""
self.trainers: Dict[str, Trainer] = {}
self.managers: Dict[str, AgentManager] = {}
self.trainer_factory = trainer_factory
self.model_path = model_path
self.summaries_dir = summaries_dir

89
ml-agents/mlagents/trainers/agent_processor.py


from typing import List
from collections import defaultdict
from mlagents.trainers.buffer import AgentBuffer
from mlagents.envs.exception import UnityException
class AgentProcessorException(UnityException):
"""
Related to errors with the AgentProcessor.
"""
pass
class AgentProcessor:
"""
AgentProcessor contains a dictionary of AgentBuffer. The AgentBuffers are indexed by agent_id.
Buffer also contains an update_buffer that corresponds to the buffer used when updating the model.
"""
def __init__(self):
self.agent_buffers = defaultdict(AgentBuffer)
def __str__(self):
return "local_buffers :\n{0}".format(
"\n".join(
[
"\tagent {0} :{1}".format(k, str(self.agent_buffers[k]))
for k in self.agent_buffers.keys()
]
)
)
def reset_local_buffers(self) -> None:
"""
Resets all the local local_buffers
"""
agent_ids = list(self.agent_buffers.keys())
for k in agent_ids:
self.agent_buffers[k].reset_agent()
def append_to_update_buffer(
self,
update_buffer: AgentBuffer,
agent_id: str,
key_list: List[str] = None,
batch_size: int = None,
training_length: int = None,
) -> None:
"""
Appends the buffer of an agent to the update buffer.
:param agent_id: The id of the agent which data will be appended
:param key_list: The fields that must be added. If None: all fields will be appended.
:param batch_size: The number of elements that must be appended. If None: All of them will be.
:param training_length: The length of the samples that must be appended. If None: only takes one element.
"""
if key_list is None:
key_list = self.agent_buffers[agent_id].keys()
if not self.agent_buffers[agent_id].check_length(key_list):
raise AgentProcessorException(
"The length of the fields {0} for agent {1} where not of same length".format(
key_list, agent_id
)
)
for field_key in key_list:
update_buffer[field_key].extend(
self.agent_buffers[agent_id][field_key].get_batch(
batch_size=batch_size, training_length=training_length
)
)
def append_all_agent_batch_to_update_buffer(
self,
update_buffer: AgentBuffer,
key_list: List[str] = None,
batch_size: int = None,
training_length: int = None,
) -> None:
"""
Appends the buffer of all agents to the update buffer.
:param key_list: The fields that must be added. If None: all fields will be appended.
:param batch_size: The number of elements that must be appended. If None: All of them will be.
:param training_length: The length of the samples that must be appended. If None: only takes one element.
"""
for agent_id in self.agent_buffers.keys():
self.append_to_update_buffer(
update_buffer, agent_id, key_list, batch_size, training_length
)
正在加载...
取消
保存