
Split buffer into two buffers (PPO works)

/develop-newnormalization
Ervin Teng, 5 years ago
Current commit: df5ee7bf
11 files changed, 318 insertions(+) and 320 deletions(-)
  1. ml-agents/mlagents/trainers/bc/trainer.py (16 lines changed)
  2. ml-agents/mlagents/trainers/buffer.py (439 lines changed)
  3. ml-agents/mlagents/trainers/components/bc/module.py (10 lines changed)
  4. ml-agents/mlagents/trainers/components/reward_signals/gail/signal.py (5 lines changed)
  5. ml-agents/mlagents/trainers/demo_loader.py (13 lines changed)
  6. ml-agents/mlagents/trainers/ppo/trainer.py (53 lines changed)
  7. ml-agents/mlagents/trainers/rl_trainer.py (41 lines changed)
  8. ml-agents/mlagents/trainers/sac/trainer.py (1 line changed)
  9. ml-agents/mlagents/trainers/tests/test_buffer.py (49 lines changed)
  10. ml-agents/mlagents/trainers/tests/test_rl_trainer.py (10 lines changed)
  11. ml-agents/MANIFEST.in (1 line changed)
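
Summary of the refactor, ahead of the per-file diffs: the old Buffer class, a dict of per-agent AgentBuffers that also carried an internal update_buffer, is split in two. AgentProcessorBuffer keeps the per-agent dictionaries, while the update buffer becomes a standalone AgentBuffer that each trainer creates and indexes directly. A minimal sketch of the new call-site shape, based only on what the diffs below show (the agent id and field names are illustrative):

from mlagents.trainers.buffer import AgentBuffer, AgentProcessorBuffer

# Per-agent collection: AgentProcessorBuffer is a dict of AgentBuffer keyed by agent_id.
processing_buffer = AgentProcessorBuffer()
processing_buffer["agent_0"]["actions"].append([0.1, 0.2])

# The flat buffer used for model updates is created separately by the trainer,
# rather than living on the old Buffer as .update_buffer.
update_buffer = AgentBuffer()

# Copy an agent's collected field into the flat buffer (trainers do this through
# append_update_buffer, whose new signature takes the target buffer explicitly).
update_buffer["actions"].extend(processing_buffer["agent_0"]["actions"])

# Update code now indexes the flat buffer directly.
update_buffer.shuffle(sequence_length=1)
mini_batch = update_buffer.make_mini_batch(0, len(update_buffer["actions"]))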

ml-agents/mlagents/trainers/bc/trainer.py (16 lines changed)


from mlagents.envs.brain import BrainInfo
from mlagents.envs.action_info import ActionInfoOutputs
from mlagents.trainers.bc.policy import BCPolicy
from mlagents.trainers.buffer import Buffer
from mlagents.trainers.buffer import AgentBuffer, AgentProcessorBuffer
from mlagents.trainers.trainer import Trainer
logger = logging.getLogger("mlagents.trainers")

self.batches_per_epoch = trainer_parameters["batches_per_epoch"]
self.demonstration_buffer = Buffer()
self.evaluation_buffer = Buffer()
self.demonstration_buffer = AgentBuffer()
self.evaluation_buffer = AgentProcessorBuffer()
def add_experiences(
self,

Returns whether or not the trainer has enough elements to run update model
:return: A boolean corresponding to whether or not update_model() can be run
"""
return (
len(self.demonstration_buffer.update_buffer["actions"]) > self.n_sequences
)
return len(self.demonstration_buffer["actions"]) > self.n_sequences
self.demonstration_buffer.update_buffer.shuffle(self.policy.sequence_length)
self.demonstration_buffer.shuffle(self.policy.sequence_length)
len(self.demonstration_buffer.update_buffer["actions"]) // self.n_sequences,
len(self.demonstration_buffer["actions"]) // self.n_sequences,
self.batches_per_epoch,
)

update_buffer = self.demonstration_buffer.update_buffer
update_buffer = self.demonstration_buffer
mini_batch = update_buffer.make_mini_batch(i, i + batch_size)
run_out = self.policy.update(mini_batch, self.n_sequences)
loss = run_out["policy_loss"]

ml-agents/mlagents/trainers/buffer.py (439 lines changed)


pass
class Buffer(dict):
class AgentBuffer(dict):
Buffer contains a dictionary of AgentBuffer. The AgentBuffers are indexed by agent_id.
Buffer also contains an update_buffer that corresponds to the buffer used when updating the model.
AgentBuffer contains a dictionary of AgentBufferFields. Each agent has his own AgentBuffer.
The keys correspond to the name of the field. Example: state, action
class AgentBuffer(dict):
class AgentBufferField(list):
AgentBuffer contains a dictionary of AgentBufferFields. Each agent has his own AgentBuffer.
The keys correspond to the name of the field. Example: state, action
AgentBufferField is a list of numpy arrays. When an agent collects a field, you can add it to his
AgentBufferField with the append method.
class AgentBufferField(list):
"""
AgentBufferField is a list of numpy arrays. When an agent collects a field, you can add it to his
AgentBufferField with the append method.
"""
def __init__(self):
self.padding_value = 0
super(Buffer.AgentBuffer.AgentBufferField, self).__init__()
def __str__(self):
return str(np.array(self).shape)
def append(self, element, padding_value=0):
"""
Adds an element to this list. Also lets you change the padding
type, so that it can be set on append (e.g. action_masks should
be padded with 1.)
:param element: The element to append to the list.
:param padding_value: The value used to pad when get_batch is called.
"""
super(Buffer.AgentBuffer.AgentBufferField, self).append(element)
self.padding_value = padding_value
def extend(self, data):
"""
Adds a list of np.arrays to the end of the list of np.arrays.
:param data: The np.array list to append.
"""
self += list(np.array(data))
def set(self, data):
"""
Sets the list of np.array to the input data
:param data: The np.array list to be set.
"""
self[:] = []
self[:] = list(np.array(data))
def get_batch(self, batch_size=None, training_length=1, sequential=True):
"""
Retrieve the last batch_size elements of length training_length
from the list of np.array
:param batch_size: The number of elements to retrieve. If None:
All elements will be retrieved.
:param training_length: The length of the sequence to be retrieved. If
None: only takes one element.
:param sequential: If true and training_length is not None: the elements
will not repeat in the sequence. [a,b,c,d,e] with training_length = 2 and
sequential=True gives [[0,a],[b,c],[d,e]]. If sequential=False gives
[[a,b],[b,c],[c,d],[d,e]]
"""
if sequential:
# The sequences will not have overlapping elements (this involves padding)
leftover = len(self) % training_length
# leftover is the number of elements in the first sequence (this sequence might need 0 padding)
if batch_size is None:
# retrieve the maximum number of elements
batch_size = len(self) // training_length + 1 * (leftover != 0)
# The maximum number of sequences taken from a list of length len(self) without overlapping
# with padding is equal to batch_size
if batch_size > (
len(self) // training_length + 1 * (leftover != 0)
):
raise BufferException(
"The batch size and training length requested for get_batch where"
" too large given the current number of data points."
)
if batch_size * training_length > len(self):
padding = np.array(self[-1]) * self.padding_value
return np.array(
[padding] * (training_length - leftover) + self[:],
dtype=np.float32,
)
else:
return np.array(
self[len(self) - batch_size * training_length :],
dtype=np.float32,
)
else:
# The sequences will have overlapping elements
if batch_size is None:
# retrieve the maximum number of elements
batch_size = len(self) - training_length + 1
# The number of sequences of length training_length taken from a list of len(self) elements
# with overlapping is equal to batch_size
if (len(self) - training_length + 1) < batch_size:
raise BufferException(
"The batch size and training length requested for get_batch where"
" too large given the current number of data points."
)
tmp_list = []
for end in range(len(self) - batch_size + 1, len(self) + 1):
tmp_list += self[end - training_length : end]
return np.array(tmp_list, dtype=np.float32)
def reset_field(self):
"""
Resets the AgentBufferField
"""
self[:] = []
self.last_brain_info = None
self.last_take_action_outputs = None
super(Buffer.AgentBuffer, self).__init__()
self.padding_value = 0
super(AgentBuffer.AgentBufferField, self).__init__()
return ", ".join(
["'{0}' : {1}".format(k, str(self[k])) for k in self.keys()]
)
return str(np.array(self).shape)
def reset_agent(self):
def append(self, element, padding_value=0):
Resets the AgentBuffer
Adds an element to this list. Also lets you change the padding
type, so that it can be set on append (e.g. action_masks should
be padded with 1.)
:param element: The element to append to the list.
:param padding_value: The value used to pad when get_batch is called.
for k in self.keys():
self[k].reset_field()
self.last_brain_info = None
self.last_take_action_outputs = None
super(AgentBuffer.AgentBufferField, self).append(element)
self.padding_value = padding_value
def __getitem__(self, key):
if key not in self.keys():
self[key] = self.AgentBufferField()
return super(Buffer.AgentBuffer, self).__getitem__(key)
def check_length(self, key_list):
def extend(self, data):
Some methods will require that some fields have the same length.
check_length will return true if the fields in key_list
have the same length.
:param key_list: The fields which length will be compared
Adds a list of np.arrays to the end of the list of np.arrays.
:param data: The np.array list to append.
if len(key_list) < 2:
return True
length = None
for key in key_list:
if key not in self.keys():
return False
if (length is not None) and (length != len(self[key])):
return False
length = len(self[key])
return True
self += list(np.array(data))
def shuffle(self, sequence_length, key_list=None):
def set(self, data):
Shuffles the fields in key_list in a consistent way: The reordering will
be the same across fields.
:param key_list: The fields that must be shuffled.
Sets the list of np.array to the input data
:param data: The np.array list to be set.
if key_list is None:
key_list = list(self.keys())
if not self.check_length(key_list):
raise BufferException(
"Unable to shuffle if the fields are not of same length"
)
s = np.arange(len(self[key_list[0]]) // sequence_length)
np.random.shuffle(s)
for key in key_list:
tmp = []
for i in s:
tmp += self[key][i * sequence_length : (i + 1) * sequence_length]
self[key][:] = tmp
self[:] = []
self[:] = list(np.array(data))
def make_mini_batch(self, start, end):
def get_batch(self, batch_size=None, training_length=1, sequential=True):
Creates a mini-batch from buffer.
:param start: Starting index of buffer.
:param end: Ending index of buffer.
:return: Dict of mini batch.
Retrieve the last batch_size elements of length training_length
from the list of np.array
:param batch_size: The number of elements to retrieve. If None:
All elements will be retrieved.
:param training_length: The length of the sequence to be retrieved. If
None: only takes one element.
:param sequential: If true and training_length is not None: the elements
will not repeat in the sequence. [a,b,c,d,e] with training_length = 2 and
sequential=True gives [[0,a],[b,c],[d,e]]. If sequential=False gives
[[a,b],[b,c],[c,d],[d,e]]
mini_batch = {}
for key in self:
mini_batch[key] = self[key][start:end]
return mini_batch
def sample_mini_batch(self, batch_size, sequence_length=1):
"""
Creates a mini-batch from a random start and end.
:param batch_size: number of elements to withdraw.
:param sequence_length: Length of sequences to sample.
Number of sequences to sample will be batch_size/sequence_length.
"""
num_seq_to_sample = batch_size // sequence_length
mini_batch = Buffer.AgentBuffer()
buff_len = len(next(iter(self.values())))
num_sequences_in_buffer = buff_len // sequence_length
start_idxes = (
np.random.randint(num_sequences_in_buffer, size=num_seq_to_sample)
* sequence_length
) # Sample random sequence starts
for i in start_idxes:
for key in self:
mini_batch[key].extend(self[key][i : i + sequence_length])
return mini_batch
def save_to_file(self, file_object):
"""
Saves the AgentBuffer to a file-like object.
"""
with h5py.File(file_object) as write_file:
for key, data in self.items():
write_file.create_dataset(
key, data=data, dtype="f", compression="gzip"
if sequential:
# The sequences will not have overlapping elements (this involves padding)
leftover = len(self) % training_length
# leftover is the number of elements in the first sequence (this sequence might need 0 padding)
if batch_size is None:
# retrieve the maximum number of elements
batch_size = len(self) // training_length + 1 * (leftover != 0)
# The maximum number of sequences taken from a list of length len(self) without overlapping
# with padding is equal to batch_size
if batch_size > (len(self) // training_length + 1 * (leftover != 0)):
raise BufferException(
"The batch size and training length requested for get_batch where"
" too large given the current number of data points."
)
if batch_size * training_length > len(self):
padding = np.array(self[-1]) * self.padding_value
return np.array(
[padding] * (training_length - leftover) + self[:],
dtype=np.float32,
)
else:
return np.array(
self[len(self) - batch_size * training_length :],
dtype=np.float32,
)
else:
# The sequences will have overlapping elements
if batch_size is None:
# retrieve the maximum number of elements
batch_size = len(self) - training_length + 1
# The number of sequences of length training_length taken from a list of len(self) elements
# with overlapping is equal to batch_size
if (len(self) - training_length + 1) < batch_size:
raise BufferException(
"The batch size and training length requested for get_batch where"
" too large given the current number of data points."
tmp_list = []
for end in range(len(self) - batch_size + 1, len(self) + 1):
tmp_list += self[end - training_length : end]
return np.array(tmp_list, dtype=np.float32)
def load_from_file(self, file_object):
def reset_field(self):
Loads the AgentBuffer from a file-like object.
Resets the AgentBufferField
with h5py.File(file_object) as read_file:
for key in list(read_file.keys()):
self[key] = Buffer.AgentBuffer.AgentBufferField()
# extend() will convert the numpy array's first dimension into list
self[key].extend(read_file[key][()])
self[:] = []
self.update_buffer = self.AgentBuffer()
super(Buffer, self).__init__()
self.last_brain_info = None
self.last_take_action_outputs = None
super(AgentBuffer, self).__init__()
return "update buffer :\n\t{0}\nlocal_buffers :\n{1}".format(
str(self.update_buffer),
"\n".join(
["\tagent {0} :{1}".format(k, str(self[k])) for k in self.keys()]
),
)
return ", ".join(["'{0}' : {1}".format(k, str(self[k])) for k in self.keys()])
def reset_agent(self):
"""
Resets the AgentBuffer
"""
for k in self.keys():
self[k].reset_field()
self.last_brain_info = None
self.last_take_action_outputs = None
self[key] = self.AgentBuffer()
return super(Buffer, self).__getitem__(key)
self[key] = self.AgentBufferField()
return super(AgentBuffer, self).__getitem__(key)
def check_length(self, key_list):
"""
Some methods will require that some fields have the same length.
check_length will return true if the fields in key_list
have the same length.
:param key_list: The fields which length will be compared
"""
if len(key_list) < 2:
return True
length = None
for key in key_list:
if key not in self.keys():
return False
if (length is not None) and (length != len(self[key])):
return False
length = len(self[key])
return True
def shuffle(self, sequence_length, key_list=None):
"""
Shuffles the fields in key_list in a consistent way: The reordering will
be the same across fields.
:param key_list: The fields that must be shuffled.
"""
if key_list is None:
key_list = list(self.keys())
if not self.check_length(key_list):
raise BufferException(
"Unable to shuffle if the fields are not of same length"
)
s = np.arange(len(self[key_list[0]]) // sequence_length)
np.random.shuffle(s)
for key in key_list:
tmp = []
for i in s:
tmp += self[key][i * sequence_length : (i + 1) * sequence_length]
self[key][:] = tmp
def make_mini_batch(self, start, end):
"""
Creates a mini-batch from buffer.
:param start: Starting index of buffer.
:param end: Ending index of buffer.
:return: Dict of mini batch.
"""
mini_batch = {}
for key in self:
mini_batch[key] = self[key][start:end]
return mini_batch
def sample_mini_batch(self, batch_size, sequence_length=1):
"""
Creates a mini-batch from a random start and end.
:param batch_size: number of elements to withdraw.
:param sequence_length: Length of sequences to sample.
Number of sequences to sample will be batch_size/sequence_length.
"""
num_seq_to_sample = batch_size // sequence_length
mini_batch = AgentBuffer()
buff_len = len(next(iter(self.values())))
num_sequences_in_buffer = buff_len // sequence_length
start_idxes = (
np.random.randint(num_sequences_in_buffer, size=num_seq_to_sample)
* sequence_length
) # Sample random sequence starts
for i in start_idxes:
for key in self:
mini_batch[key].extend(self[key][i : i + sequence_length])
return mini_batch
def save_to_file(self, file_object):
"""
Saves the AgentBuffer to a file-like object.
"""
with h5py.File(file_object) as write_file:
for key, data in self.items():
write_file.create_dataset(key, data=data, dtype="f", compression="gzip")
def reset_update_buffer(self):
def load_from_file(self, file_object):
Resets the update buffer
Loads the AgentBuffer from a file-like object.
self.update_buffer.reset_agent()
with h5py.File(file_object) as read_file:
for key in list(read_file.keys()):
self[key] = AgentBuffer.AgentBufferField()
# extend() will convert the numpy array's first dimension into list
self[key].extend(read_file[key][()])
def truncate_update_buffer(self, max_length, sequence_length=1):
def truncate(self, max_length, sequence_length=1):
Truncates the update buffer to a certain length.
Truncates the buffer to a certain length.
current_length = len(next(iter(self.update_buffer.values())))
current_length = len(next(iter(self)))
for _key in self.update_buffer.keys():
self.update_buffer[_key] = self.update_buffer[_key][
current_length - max_length :
]
for _key in self.keys():
self[_key] = self[_key][current_length - max_length :]
class AgentProcessorBuffer(dict):
"""
AgentProcessorBuffer contains a dictionary of AgentBuffer. The AgentBuffers are indexed by agent_id.
Buffer also contains an update_buffer that corresponds to the buffer used when updating the model.
"""
def __str__(self):
return "local_buffers :\n{0}".format(
"\n".join(["\tagent {0} :{1}".format(k, str(self[k])) for k in self.keys()])
)
def __getitem__(self, key):
if key not in self.keys():
self[key] = AgentBuffer()
return super(AgentProcessorBuffer, self).__getitem__(key)
def reset_local_buffers(self):
"""

self[k].reset_agent()
def append_update_buffer(
self, agent_id, key_list=None, batch_size=None, training_length=None
self,
update_buffer,
agent_id,
key_list=None,
batch_size=None,
training_length=None,
):
"""
Appends the buffer of an agent to the update buffer.

)
)
for field_key in key_list:
self.update_buffer[field_key].extend(
update_buffer[field_key].extend(
self[agent_id][field_key].get_batch(
batch_size=batch_size, training_length=training_length
)

self, key_list=None, batch_size=None, training_length=None
self, update_buffer, key_list=None, batch_size=None, training_length=None
):
"""
Appends the buffer of all agents to the update buffer.

"""
for agent_id in self.keys():
self.append_update_buffer(agent_id, key_list, batch_size, training_length)
self.append_update_buffer(
update_buffer, agent_id, key_list, batch_size, training_length
)
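
The get_batch docstring in this file is easier to follow with a concrete run. A small standalone illustration of the sequential and overlapping modes, using made-up values in place of [a, b, c, d, e] (the default padding_value of 0 applies, since padding is set per append):

import numpy as np
from mlagents.trainers.buffer import AgentBuffer

buf = AgentBuffer()
for x in [1.0, 2.0, 3.0, 4.0, 5.0]:      # stands in for [a, b, c, d, e]
    buf["obs"].append(x)

# Sequential (non-overlapping) sequences of length 2: the first, shorter sequence is
# left-padded, i.e. the [[0, a], [b, c], [d, e]] of the docstring. The array comes
# back flattened: [0. 1. 2. 3. 4. 5.]
print(buf["obs"].get_batch(batch_size=None, training_length=2, sequential=True))

# Overlapping sequences of length 2, [[a, b], [b, c], [c, d], [d, e]], also flattened:
# [1. 2. 2. 3. 3. 4. 4. 5.]
print(buf["obs"].get_batch(batch_size=None, training_length=2, sequential=False))

# AgentBuffer.sample_mini_batch draws whole sequences from random sequence-aligned
# starts: batch_size=4 with sequence_length=2 yields two length-2 sequences.
mini = buf.sample_mini_batch(batch_size=4, sequence_length=2)
print(np.array(mini["obs"]).shape)       # (4,)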

ml-agents/mlagents/trainers/components/bc/module.py (10 lines changed)


self.batch_size = batch_size if batch_size else default_batch_size
self.num_epoch = num_epoch if num_epoch else default_num_epoch
self.n_sequences = max(
min(
self.batch_size, len(self.demonstration_buffer.update_buffer["actions"])
)
min(self.batch_size, len(self.demonstration_buffer["actions"]))
// policy.sequence_length,
1,
)

batch_losses = []
possible_demo_batches = (
len(self.demonstration_buffer.update_buffer["actions"]) // self.n_sequences
len(self.demonstration_buffer["actions"]) // self.n_sequences
)
possible_batches = possible_demo_batches

for _ in range(n_epoch):
self.demonstration_buffer.update_buffer.shuffle(
self.demonstration_buffer.shuffle(
sequence_length=self.policy.sequence_length
)
if max_batches == 0:

for i in range(num_batches // self.policy.sequence_length):
demo_update_buffer = self.demonstration_buffer.update_buffer
demo_update_buffer = self.demonstration_buffer
start = i * self.n_sequences * self.policy.sequence_length
end = (i + 1) * self.n_sequences * self.policy.sequence_length
mini_batch_demo = demo_update_buffer.make_mini_batch(start, end)

ml-agents/mlagents/trainers/components/reward_signals/gail/signal.py (5 lines changed)


:return: Feed_dict for update process.
"""
max_num_experiences = min(
len(mini_batch["actions"]),
len(self.demonstration_buffer.update_buffer["actions"]),
len(mini_batch["actions"]), len(self.demonstration_buffer["actions"])
)
# If num_sequences is less, we need to shorten the input batch.
for key, element in mini_batch.items():

mini_batch_demo = self.demonstration_buffer.update_buffer.sample_mini_batch(
mini_batch_demo = self.demonstration_buffer.sample_mini_batch(
len(mini_batch["actions"]), 1
)

ml-agents/mlagents/trainers/demo_loader.py (13 lines changed)


import os
from typing import List, Tuple
import numpy as np
from mlagents.trainers.buffer import Buffer
from mlagents.trainers.buffer import AgentBuffer, AgentProcessorBuffer
from mlagents.envs.brain import BrainParameters, BrainInfo
from mlagents.envs.communicator_objects.agent_info_action_pair_pb2 import (
AgentInfoActionPairProto,

pair_infos: List[AgentInfoActionPairProto],
brain_params: BrainParameters,
sequence_length: int,
) -> Buffer:
) -> AgentBuffer:
demo_buffer = Buffer()
demo_buffer = AgentProcessorBuffer()
update_buffer = AgentBuffer()
for idx, experience in enumerate(pair_infos):
if idx > len(pair_infos) - 2:
break

demo_buffer[0]["prev_action"].append(previous_action)
if next_brain_info.local_done[0]:
demo_buffer.append_update_buffer(
0, batch_size=None, training_length=sequence_length
update_buffer, 0, batch_size=None, training_length=sequence_length
0, batch_size=None, training_length=sequence_length
update_buffer, 0, batch_size=None, training_length=sequence_length
)
return demo_buffer

) -> Tuple[BrainParameters, Buffer]:
) -> Tuple[BrainParameters, AgentBuffer]:
"""
Loads demonstration file and uses it to fill training buffer.
:param file_path: Location of demonstration file (.demo).

ml-agents/mlagents/trainers/ppo/trainer.py (53 lines changed)


if self.is_training:
self.policy.update_normalization(next_info.vector_observations)
for l in range(len(next_info.agents)):
agent_actions = self.training_buffer[next_info.agents[l]]["actions"]
agent_actions = self.processing_buffer[next_info.agents[l]]["actions"]
if (
next_info.local_done[l]
or len(agent_actions) > self.trainer_parameters["time_horizon"]

bootstrapping_info = self.training_buffer[agent_id].last_brain_info
bootstrapping_info = self.processing_buffer[
agent_id
].last_brain_info
idx = bootstrapping_info.agents.index(agent_id)
else:
bootstrapping_info = next_info

for name in self.policy.reward_signals:
bootstrap_value = value_next[name]
local_rewards = self.training_buffer[agent_id][
local_rewards = self.processing_buffer[agent_id][
local_value_estimates = self.training_buffer[agent_id][
local_value_estimates = self.processing_buffer[agent_id][
"{}_value_estimates".format(name)
].get_batch()
local_advantage = get_gae(

)
local_return = local_advantage + local_value_estimates
# This is later use as target for the different value estimates
self.training_buffer[agent_id]["{}_returns".format(name)].set(
self.processing_buffer[agent_id]["{}_returns".format(name)].set(
self.training_buffer[agent_id]["{}_advantage".format(name)].set(
self.processing_buffer[agent_id]["{}_advantage".format(name)].set(
local_advantage
)
tmp_advantages.append(local_advantage)

global_returns = list(np.mean(np.array(tmp_returns), axis=0))
self.training_buffer[agent_id]["advantages"].set(global_advantages)
self.training_buffer[agent_id]["discounted_returns"].set(global_returns)
self.processing_buffer[agent_id]["advantages"].set(global_advantages)
self.processing_buffer[agent_id]["discounted_returns"].set(
global_returns
)
self.training_buffer.append_update_buffer(
self.processing_buffer.append_update_buffer(
self.update_buffer,
self.training_buffer[agent_id].reset_agent()
self.processing_buffer[agent_id].reset_agent()
if next_info.local_done[l]:
self.stats["Environment/Episode Length"].append(
self.episode_steps.get(agent_id, 0)

actions = take_action_outputs["action"]
if self.policy.use_continuous_act:
actions_pre = take_action_outputs["pre_action"]
self.training_buffer[agent_id]["actions_pre"].append(actions_pre[agent_idx])
self.processing_buffer[agent_id]["actions_pre"].append(
actions_pre[agent_idx]
)
self.training_buffer[agent_id]["random_normal_epsilon"].append(
self.processing_buffer[agent_id]["random_normal_epsilon"].append(
self.training_buffer[agent_id]["actions"].append(actions[agent_idx])
self.training_buffer[agent_id]["action_probs"].append(a_dist[agent_idx])
self.processing_buffer[agent_id]["actions"].append(actions[agent_idx])
self.processing_buffer[agent_id]["action_probs"].append(a_dist[agent_idx])
def add_rewards_outputs(
self,

"""
for name, reward_result in rewards_out.reward_signals.items():
# 0 because we use the scaled reward to train the agent
self.training_buffer[agent_id]["{}_rewards".format(name)].append(
self.processing_buffer[agent_id]["{}_rewards".format(name)].append(
self.training_buffer[agent_id]["{}_value_estimates".format(name)].append(
self.processing_buffer[agent_id]["{}_value_estimates".format(name)].append(
values[name][agent_idx][0]
)

:return: A boolean corresponding to whether or not update_model() can be run
"""
size_of_buffer = len(self.training_buffer.update_buffer["actions"])
size_of_buffer = len(self.update_buffer["actions"])
return size_of_buffer > self.trainer_parameters["buffer_size"]
def update_policy(self):

"""
buffer_length = len(self.training_buffer.update_buffer["actions"])
buffer_length = len(self.update_buffer["actions"])
self.trainer_metrics.start_policy_update_timer(
number_experiences=buffer_length,
mean_return=float(np.mean(self.cumulative_returns_since_policy_update)),

int(self.trainer_parameters["batch_size"] / self.policy.sequence_length), 1
)
advantages = self.training_buffer.update_buffer["advantages"].get_batch()
self.training_buffer.update_buffer["advantages"].set(
advantages = self.update_buffer["advantages"].get_batch()
self.update_buffer["advantages"].set(
self.training_buffer.update_buffer.shuffle(
sequence_length=self.policy.sequence_length
)
buffer = self.training_buffer.update_buffer
self.update_buffer.shuffle(sequence_length=self.policy.sequence_length)
buffer = self.update_buffer
max_num_batch = buffer_length // batch_size
for l in range(0, max_num_batch * batch_size, batch_size):
update_stats = self.policy.update(
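
The net effect on update_policy: the PPO trainer now reads and writes the flat update_buffer it owns, with no training_buffer.update_buffer indirection. A standalone sketch of that pattern with made-up data; the exact expression passed to set(...) is elided in the hunk above, so the rescaling shown here is only indicative:

import numpy as np
from mlagents.trainers.buffer import AgentBuffer

update_buffer = AgentBuffer()
for adv, act in zip([0.5, -0.2, 1.3, 0.1], [[0.0, 1.0]] * 4):   # made-up experiences
    update_buffer["advantages"].append(adv)
    update_buffer["actions"].append(act)

# Read the advantages out, rescale them, and write them back in place.
advantages = update_buffer["advantages"].get_batch()
update_buffer["advantages"].set(
    (advantages - advantages.mean()) / (advantages.std() + 1e-10)
)

# Shuffle whole sequences, then walk the buffer in fixed-size mini-batches.
update_buffer.shuffle(sequence_length=1)
batch_size = 2
buffer_length = len(update_buffer["actions"])
for start in range(0, (buffer_length // batch_size) * batch_size, batch_size):
    mini_batch = update_buffer.make_mini_batch(start, start + batch_size)
    print(len(mini_batch["actions"]))    # the trainer feeds mini_batch to policy.update()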

ml-agents/mlagents/trainers/rl_trainer.py (41 lines changed)


from mlagents.envs.brain import BrainInfo
from mlagents.envs.action_info import ActionInfoOutputs
from mlagents.trainers.buffer import Buffer
from mlagents.trainers.buffer import AgentBuffer, AgentProcessorBuffer
from mlagents.trainers.trainer import Trainer, UnityTrainerException
from mlagents.trainers.components.reward_signals import RewardSignalResult

# used for reporting only. We always want to report the environment reward to Tensorboard, regardless
# of what reward signals are actually present.
self.collected_rewards = {"environment": {}}
self.training_buffer = Buffer()
self.processing_buffer = AgentProcessorBuffer()
self.update_buffer = AgentBuffer()
self.episode_steps = {}
def construct_curr_info(self, next_info: BrainInfo) -> BrainInfo:

agents = []
action_masks = []
for agent_id in next_info.agents:
agent_brain_info = self.training_buffer[agent_id].last_brain_info
agent_brain_info = self.processing_buffer[agent_id].last_brain_info
if agent_brain_info is None:
agent_brain_info = next_info
agent_index = agent_brain_info.agents.index(agent_id)

)
for agent_id in curr_info.agents:
self.training_buffer[agent_id].last_brain_info = curr_info
self.training_buffer[
self.processing_buffer[agent_id].last_brain_info = curr_info
self.processing_buffer[
agent_id
].last_take_action_outputs = take_action_outputs

)
for agent_id in next_info.agents:
stored_info = self.training_buffer[agent_id].last_brain_info
stored_take_action_outputs = self.training_buffer[
stored_info = self.processing_buffer[agent_id].last_brain_info
stored_take_action_outputs = self.processing_buffer[
agent_id
].last_take_action_outputs
if stored_info is not None:

for i, _ in enumerate(stored_info.visual_observations):
self.training_buffer[agent_id]["visual_obs%d" % i].append(
self.processing_buffer[agent_id]["visual_obs%d" % i].append(
self.training_buffer[agent_id]["next_visual_obs%d" % i].append(
next_info.visual_observations[i][next_idx]
)
self.processing_buffer[agent_id][
"next_visual_obs%d" % i
].append(next_info.visual_observations[i][next_idx])
self.training_buffer[agent_id]["vector_obs"].append(
self.processing_buffer[agent_id]["vector_obs"].append(
self.training_buffer[agent_id]["next_vector_in"].append(
self.processing_buffer[agent_id]["next_vector_in"].append(
self.training_buffer[agent_id]["memory"].append(
self.processing_buffer[agent_id]["memory"].append(
self.training_buffer[agent_id]["masks"].append(1.0)
self.training_buffer[agent_id]["done"].append(
self.processing_buffer[agent_id]["masks"].append(1.0)
self.processing_buffer[agent_id]["done"].append(
next_info.local_done[next_idx]
)
# Add the outputs of the last eval

self.training_buffer[agent_id]["action_mask"].append(
self.processing_buffer[agent_id]["action_mask"].append(
self.training_buffer[agent_id]["prev_action"].append(
self.processing_buffer[agent_id]["prev_action"].append(
self.policy.retrieve_previous_action([agent_id])[0, :]
)

A signal that the Episode has ended. The buffer must be reset.
Get only called when the academy resets.
"""
self.training_buffer.reset_local_buffers()
self.processing_buffer.reset_local_buffers()
for agent_id in self.episode_steps:
self.episode_steps[agent_id] = 0
for rewards in self.collected_rewards.values():

Clear the buffers that have been built up during inference. If
we're not training, this should be called instead of update_policy.
"""
self.training_buffer.reset_update_buffer()
self.update_buffer.reset_agent()
def add_policy_outputs(
self, take_action_outputs: ActionInfoOutputs, agent_id: str, agent_idx: int
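
Taken together, RLTrainer now manages two objects explicitly: a processing_buffer of in-progress per-agent trajectories and a flat update_buffer consumed by update_policy. A minimal sketch of that lifecycle, using only calls visible in this diff (agent id, fields, and lengths are illustrative):

from mlagents.trainers.buffer import AgentBuffer, AgentProcessorBuffer

processing_buffer = AgentProcessorBuffer()   # per-agent, in-progress trajectories
update_buffer = AgentBuffer()                # flat buffer consumed by update_policy

# During stepping, experiences accumulate per agent.
processing_buffer["agent_0"]["masks"].append(1.0)
processing_buffer["agent_0"]["done"].append(False)

# When an agent's episode ends (or the time horizon is reached), its trajectory is
# copied into the shared update buffer and its local buffer is reset.
processing_buffer.append_update_buffer(
    update_buffer, "agent_0", batch_size=None, training_length=1
)
processing_buffer["agent_0"].reset_agent()

# end_episode() resets every local buffer; clear_update_buffer() now resets the
# flat AgentBuffer itself rather than calling reset_update_buffer() on a parent Buffer.
processing_buffer.reset_local_buffers()
update_buffer.reset_agent()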

ml-agents/mlagents/trainers/sac/trainer.py (1 line changed)


self.training_buffer[agent_id]["done"][-1] = False
self.training_buffer.append_update_buffer(
self.update_buffer,
agent_id,
batch_size=None,
training_length=self.policy.sequence_length,

ml-agents/mlagents/trainers/tests/test_buffer.py (49 lines changed)


import numpy as np
from mlagents.trainers.buffer import Buffer
from mlagents.trainers.buffer import AgentProcessorBuffer, AgentBuffer
def assert_array(a, b):

def construct_fake_buffer():
b = Buffer()
b = AgentProcessorBuffer()
b[fake_agent_id]["vector_observation"].append(
b["vector_observation"].append(
[
100 * fake_agent_id + 10 * step + 1,
100 * fake_agent_id + 10 * step + 2,

b[fake_agent_id]["action"].append(
b["action"].append(
[
100 * fake_agent_id + 10 * step + 4,
100 * fake_agent_id + 10 * step + 5,

)
b[4].reset_agent()
assert len(b[4]) == 0
b.append_update_buffer(3, batch_size=None, training_length=2)
b.append_update_buffer(2, batch_size=None, training_length=2)
assert len(b.update_buffer["action"]) == 20
assert np.array(b.update_buffer["action"]).shape == (20, 2)
update_buffer = AgentBuffer()
b.append_update_buffer(update_buffer, 3, batch_size=None, training_length=2)
b.append_update_buffer(update_buffer, 2, batch_size=None, training_length=2)
assert len(update_buffer["action"]) == 20
assert np.array(update_buffer["action"]).shape == (20, 2)
c = b.update_buffer.make_mini_batch(start=0, end=1)
assert c.keys() == b.update_buffer.keys()
c = update_buffer.make_mini_batch(start=0, end=1)
assert c.keys() == update_buffer.keys()
assert np.array(c["action"]).shape == (1, 2)

def test_buffer_sample():
b = construct_fake_buffer()
b.append_update_buffer(3, batch_size=None, training_length=2)
b.append_update_buffer(2, batch_size=None, training_length=2)
update_buffer = AgentBuffer()
b.append_update_buffer(update_buffer, 3, batch_size=None, training_length=2)
b.append_update_buffer(update_buffer, 2, batch_size=None, training_length=2)
mb = b.update_buffer.sample_mini_batch(batch_size=4, sequence_length=1)
assert mb.keys() == b.update_buffer.keys()
mb = update_buffer.sample_mini_batch(batch_size=4, sequence_length=1)
assert mb.keys() == update_buffer.keys()
mb = b.update_buffer.sample_mini_batch(batch_size=20, sequence_length=19)
assert mb.keys() == b.update_buffer.keys()
mb = update_buffer.sample_mini_batch(batch_size=20, sequence_length=19)
assert mb.keys() == update_buffer.keys()
# Should only return one sequence
assert np.array(mb["action"]).shape == (19, 2)

b.append_update_buffer(3, batch_size=None, training_length=2)
b.append_update_buffer(2, batch_size=None, training_length=2)
update_buffer = AgentBuffer()
b.append_update_buffer(update_buffer, 3, batch_size=None, training_length=2)
b.append_update_buffer(update_buffer, 2, batch_size=None, training_length=2)
b.truncate_update_buffer(2)
update_buffer.truncate(2)
b.append_update_buffer(3, batch_size=None, training_length=2)
b.append_update_buffer(2, batch_size=None, training_length=2)
b.append_update_buffer(update_buffer, 3, batch_size=None, training_length=2)
b.append_update_buffer(update_buffer, 2, batch_size=None, training_length=2)
b.truncate_update_buffer(4, sequence_length=3)
assert len(b.update_buffer["action"]) == 3
update_buffer.truncate(4, sequence_length=3)
assert len(update_buffer["action"]) == 3

ml-agents/mlagents/trainers/tests/test_rl_trainer.py (10 lines changed)


import numpy as np
from mlagents.trainers.rl_trainer import RLTrainer
from mlagents.trainers.tests.test_buffer import construct_fake_buffer
from mlagents.trainers.buffer import AgentBuffer
@pytest.fixture

def test_clear_update_buffer():
trainer = create_rl_trainer()
trainer.training_buffer = construct_fake_buffer()
trainer.training_buffer.append_update_buffer(2, batch_size=None, training_length=2)
trainer.processing_buffer = construct_fake_buffer()
trainer.update_buffer = AgentBuffer()
trainer.training_buffer.append_update_buffer(
trainer.update_buffer, 2, batch_size=None, training_length=2
)
for _, arr in trainer.training_buffer.update_buffer.items():
for _, arr in trainer.update_buffer.items():
assert len(arr) == 0

ml-agents/MANIFEST.in (1 line changed)


include ../VERSION