Compare commits

...
This merge request contains changes that conflict with the target branch.
/ml-agents/mlagents/trainers/agent_processor.py
/ml-agents/mlagents/trainers/behavior_id_utils.py
/ml-agents/mlagents/trainers/ppo/trainer.py
/ml-agents/mlagents/trainers/tests/mock_brain.py
/ml-agents/mlagents/trainers/tests/test_agent_processor.py
/ml-agents/mlagents/trainers/tests/test_buffer.py
/ml-agents/mlagents/trainers/buffer.py
/ml-agents/mlagents/trainers/torch/agent_action.py
/ml-agents/mlagents/trainers/torch/utils.py
/ml-agents/mlagents/trainers/trajectory.py
/ml-agents/mlagents/trainers/tests/torch/test_agent_action.py

3 commits

15 files changed: 758 insertions(+), 96 deletions(-)
  1. config/ppo/PushBlock.yaml (0)
  2. ml-agents/mlagents/trainers/behavior_id_utils.py (9)
  3. ml-agents/mlagents/trainers/agent_processor.py (137)
  4. ml-agents/mlagents/trainers/ppo/trainer.py (2)
  5. ml-agents/mlagents/trainers/buffer.py (106)
  6. ml-agents/mlagents/trainers/trajectory.py (163)
  7. ml-agents/mlagents/trainers/tests/test_buffer.py (80)
  8. ml-agents/mlagents/trainers/tests/test_agent_processor.py (100)
  9. ml-agents/mlagents/trainers/tests/mock_brain.py (20)
  10. ml-agents/mlagents/trainers/tests/test_trajectory.py (54)
  11. ml-agents/mlagents/trainers/torch/utils.py (12)
  12. ml-agents/mlagents/trainers/torch/agent_action.py (86)
  13. com.unity.ml-agents/Runtime/MultiAgentGroupIdCounter.cs.meta (11)
  14. ml-agents/mlagents/trainers/tests/torch/test_agent_action.py (63)
  15. com.unity.ml-agents/Runtime/MultiAgentGroupIdCounter.cs.meta (11)

config/ppo/PushBlock.yaml (0)

ml-agents/mlagents/trainers/behavior_id_utils.py (9)


"""
Create an agent id that is unique across environment workers using the worker_id.
"""
return f"${worker_id}-{agent_id}"
return f"agent_{worker_id}-{agent_id}"
def get_global_group_id(worker_id: int, group_id: int) -> str:
"""
Create a group id that is unique across environment workers using the worker_id.
"""
return f"group_{worker_id}-{group_id}"

ml-agents/mlagents/trainers/agent_processor.py (137)


import sys
import numpy as np
from typing import List, Dict, TypeVar, Generic, Tuple, Any, Union
from collections import defaultdict, Counter
import queue

StatsAggregationMethod,
EnvironmentStats,
)
from mlagents.trainers.trajectory import Trajectory, AgentExperience
from mlagents.trainers.trajectory import GroupmateStatus, Trajectory, AgentExperience
from mlagents.trainers.behavior_id_utils import get_global_agent_id
from mlagents.trainers.behavior_id_utils import get_global_agent_id, get_global_group_id
T = TypeVar("T")

"""
self.experience_buffers: Dict[str, List[AgentExperience]] = defaultdict(list)
self.last_step_result: Dict[str, Tuple[DecisionStep, int]] = {}
# current_group_obs is used to collect the current, most recently seen
# obs of all the agents in the same group, and assemble the group obs.
self.current_group_obs: Dict[str, Dict[str, List[np.ndarray]]] = defaultdict(
lambda: defaultdict(list)
)
# group_status is used to collect the current, most recently seen
# group status of all the agents in the same group, and assemble the group obs.
self.group_status: Dict[str, Dict[str, GroupmateStatus]] = defaultdict(
lambda: defaultdict(None)
)
# last_take_action_outputs stores the action a_t taken before the current observation s_(t+1), while
# grabbing previous_action from the policy grabs the action PRIOR to that, a_(t-1).
self.last_take_action_outputs: Dict[str, ActionInfoOutputs] = {}

if global_id in self.last_step_result: # Don't store if agent just reset
self.last_take_action_outputs[global_id] = take_action_outputs
# Iterate over all the terminal steps
# Iterate over all the terminal steps, first gather all the group obs
# and then create the AgentExperiences/Trajectories. _add_to_group_status
# stores Group statuses in a common data structure self.group_status
for terminal_step in terminal_steps.values():
self._add_to_group_status(terminal_step, worker_id)
terminal_step, global_id, terminal_steps.agent_id_to_index[local_id]
terminal_step, worker_id, terminal_steps.agent_id_to_index[local_id]
# Iterate over all the decision steps
# Clear the last seen group obs when agents die.
self._clear_group_obs(global_id)
# Iterate over all the decision steps, first gather all the group obs
# and then create the trajectories. _add_to_group_status
# stores Group statuses in a common data structure self.group_status
for ongoing_step in decision_steps.values():
self._add_to_group_status(ongoing_step, worker_id)
global_id = get_global_agent_id(worker_id, local_id)
ongoing_step, global_id, decision_steps.agent_id_to_index[local_id]
ongoing_step, worker_id, decision_steps.agent_id_to_index[local_id]
)
for _gid in action_global_agent_ids:

[_gid], take_action_outputs["action"]
)
def _add_to_group_status(
self, step: Union[TerminalStep, DecisionStep], worker_id: int
) -> None:
"""
Takes a TerminalStep or DecisionStep and adds the information in it
to self.group_status. This information can then be retrieved
when constructing trajectories to get the status of group mates.
:param step: TerminalStep or DecisionStep
:param worker_id: Worker ID of this particular environment. Used to generate a
global group id.
"""
global_agent_id = get_global_agent_id(worker_id, step.agent_id)
stored_decision_step, idx = self.last_step_result.get(
global_agent_id, (None, None)
)
stored_take_action_outputs = self.last_take_action_outputs.get(
global_agent_id, None
)
if stored_decision_step is not None and stored_take_action_outputs is not None:
# 0, the default group_id, means that the agent doesn't belong to an agent group.
# If 0, don't add any groupmate information.
if step.group_id > 0:
global_group_id = get_global_group_id(worker_id, step.group_id)
stored_actions = stored_take_action_outputs["action"]
action_tuple = ActionTuple(
continuous=stored_actions.continuous[idx],
discrete=stored_actions.discrete[idx],
)
group_status = GroupmateStatus(
obs=stored_decision_step.obs,
reward=step.reward,
action=action_tuple,
done=isinstance(step, TerminalStep),
)
self.group_status[global_group_id][global_agent_id] = group_status
self.current_group_obs[global_group_id][global_agent_id] = step.obs
def _clear_group_obs(self, global_id: str) -> None:
self._delete_in_nested_dict(self.current_group_obs, global_id)
self._delete_in_nested_dict(self.group_status, global_id)
def _delete_in_nested_dict(self, nested_dict: Dict[str, Any], key: str) -> None:
for _manager_id in list(nested_dict.keys()):
_team_group = nested_dict[_manager_id]
self._safe_delete(_team_group, key)
if not _team_group: # if dict is empty
self._safe_delete(nested_dict, _manager_id)
self, step: Union[TerminalStep, DecisionStep], global_id: str, index: int
self, step: Union[TerminalStep, DecisionStep], worker_id: int, index: int
stored_decision_step, idx = self.last_step_result.get(global_id, (None, None))
stored_take_action_outputs = self.last_take_action_outputs.get(global_id, None)
global_agent_id = get_global_agent_id(worker_id, step.agent_id)
global_group_id = get_global_group_id(worker_id, step.group_id)
stored_decision_step, idx = self.last_step_result.get(
global_agent_id, (None, None)
)
stored_take_action_outputs = self.last_take_action_outputs.get(
global_agent_id, None
)
self.last_step_result[global_id] = (step, index)
self.last_step_result[global_agent_id] = (step, index)
memory = self.policy.retrieve_previous_memories([global_id])[0, :]
memory = self.policy.retrieve_previous_memories([global_agent_id])[0, :]
else:
memory = None
done = terminated # Since this is an ongoing step

discrete=stored_action_probs.discrete[idx],
)
action_mask = stored_decision_step.action_mask
prev_action = self.policy.retrieve_previous_action([global_id])[0, :]
prev_action = self.policy.retrieve_previous_action([global_agent_id])[0, :]
# Assemble teammate_obs. If none saved, then it will be an empty list.
group_statuses = []
for _id, _mate_status in self.group_status[global_group_id].items():
if _id != global_agent_id:
group_statuses.append(_mate_status)
experience = AgentExperience(
obs=obs,
reward=step.reward,

prev_action=prev_action,
interrupted=interrupted,
memory=memory,
group_status=group_statuses,
group_reward=step.group_reward,
self.experience_buffers[global_id].append(experience)
self.episode_rewards[global_id] += step.reward
self.experience_buffers[global_agent_id].append(experience)
self.episode_rewards[global_agent_id] += step.reward
self.episode_steps[global_id] += 1
self.episode_steps[global_agent_id] += 1
len(self.experience_buffers[global_id]) >= self.max_trajectory_length
len(self.experience_buffers[global_agent_id])
>= self.max_trajectory_length
# Make next AgentExperience
next_group_obs = []
for _id, _obs in self.current_group_obs[global_group_id].items():
if _id != global_agent_id:
next_group_obs.append(_obs)
steps=self.experience_buffers[global_id],
agent_id=global_id,
steps=self.experience_buffers[global_agent_id],
agent_id=global_agent_id,
next_group_obs=next_group_obs,
self.experience_buffers[global_id] = []
self.experience_buffers[global_agent_id] = []
"Environment/Episode Length", self.episode_steps.get(global_id, 0)
"Environment/Episode Length",
self.episode_steps.get(global_agent_id, 0),
self._clean_agent_data(global_id)
self._clean_agent_data(global_agent_id)
def _clean_agent_data(self, global_id: str) -> None:
"""

ml-agents/mlagents/trainers/ppo/trainer.py (2)


int(self.hyperparameters.batch_size / self.policy.sequence_length), 1
)
advantages = self.update_buffer[BufferKey.ADVANTAGES].get_batch()
advantages = np.array(self.update_buffer[BufferKey.ADVANTAGES].get_batch())
self.update_buffer[BufferKey.ADVANTAGES].set(
(advantages - advantages.mean()) / (advantages.std() + 1e-10)
)
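The only change here is wrapping get_batch() in np.array: AgentBufferField.get_batch now returns a plain Python list (see buffer.py below), so it must be converted before taking mean and std. A minimal sketch of the normalization with stand-in values:

import numpy as np

advantages = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)  # stand-in for get_batch() output
normalized = (advantages - advantages.mean()) / (advantages.std() + 1e-10)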

ml-agents/mlagents/trainers/buffer.py (106)


from mlagents_envs.exception import UnityException
# Elements in the buffer can be np.ndarray, or in the case of teammate obs, actions, rewards,
# a List of np.ndarray. This is done so that we don't have duplicated np.ndarrays, only references.
BufferEntry = Union[np.ndarray, List[np.ndarray]]
class BufferException(UnityException):
"""

class BufferKey(enum.Enum):
ACTION_MASK = "action_mask"
CONTINUOUS_ACTION = "continuous_action"
NEXT_CONT_ACTION = "next_continuous_action"
NEXT_DISC_ACTION = "next_discrete_action"
DISCRETE_LOG_PROBS = "discrete_log_probs"
DONE = "done"
ENVIRONMENT_REWARDS = "environment_rewards"

ADVANTAGES = "advantages"
DISCOUNTED_RETURNS = "discounted_returns"
GROUP_DONES = "group_dones"
GROUPMATE_REWARDS = "groupmate_reward"
GROUP_REWARD = "group_reward"
GROUP_CONTINUOUS_ACTION = "group_continuous_action"
GROUP_DISCRETE_ACTION = "group_discrete_aaction"
GROUP_NEXT_CONT_ACTION = "group_next_cont_action"
GROUP_NEXT_DISC_ACTION = "group_next_disc_action"
class ObservationKeyPrefix(enum.Enum):

GROUP_OBSERVATION = "group_obs"
NEXT_GROUP_OBSERVATION = "next_group_obs"
class RewardSignalKeyPrefix(enum.Enum):
# Reward signals

class AgentBufferField(list):
"""
AgentBufferField is a list of numpy arrays. When an agent collects a field, you can add it to its
AgentBufferField with the append method.
AgentBufferField is a list of numpy arrays, or List[np.ndarray] for group entries.
When an agent collects a field, you can add it to its AgentBufferField with the append method.
def __init__(self):
def __init__(self, *args, **kwargs):
super().__init__()
super().__init__(*args, **kwargs)
def __str__(self) -> str:
return f"AgentBufferField: {super().__str__()}"
def __str__(self):
return str(np.array(self).shape)
def __getitem__(self, index):
return_data = super().__getitem__(index)
if isinstance(return_data, list):
return AgentBufferField(return_data)
else:
return return_data
def append(self, element: np.ndarray, padding_value: float = 0.0) -> None:
"""

super().append(element)
self.padding_value = padding_value
def extend(self, data: np.ndarray) -> None:
"""
Adds a list of np.arrays to the end of the list of np.arrays.
:param data: The np.array list to append.
def set(self, data: List[BufferEntry]) -> None:
self += list(np.array(data, dtype=np.float32))
def set(self, data):
Sets the list of BufferEntry to the input data
:param data: The BufferEntry list to be set.
Sets the list of np.array to the input data
:param data: The np.array list to be set.
"""
# Make sure we convert incoming data to float32 if it's a float
dtype = None
if data is not None and len(data) and isinstance(data[0], float):
dtype = np.float32
self[:] = list(np.array(data, dtype=dtype))
self[:] = data
def get_batch(
self,

) -> np.ndarray:
) -> List[BufferEntry]:
"""
Retrieve the last batch_size elements of length training_length
from the list of np.array

)
if batch_size * training_length > len(self):
padding = np.array(self[-1], dtype=np.float32) * self.padding_value
return np.array(
[padding] * (training_length - leftover) + self[:], dtype=np.float32
)
return [padding] * (training_length - leftover) + self[:]
return np.array(
self[len(self) - batch_size * training_length :], dtype=np.float32
)
return self[len(self) - batch_size * training_length :]
else:
# The sequences will have overlapping elements
if batch_size is None:

tmp_list: List[np.ndarray] = []
for end in range(len(self) - batch_size + 1, len(self) + 1):
tmp_list += self[end - training_length : end]
return np.array(tmp_list, dtype=np.float32)
return tmp_list
def reset_field(self) -> None:
"""

def padded_to_batch(
self, pad_value: np.float = 0, dtype: np.dtype = np.float32
) -> Union[np.ndarray, List[np.ndarray]]:
"""
Converts this AgentBufferField (which is a List[BufferEntry]) into a numpy array
with first dimension equal to the length of this AgentBufferField. If this AgentBufferField
contains a List[List[BufferEntry]] (i.e., in the case of group observations), return a List
of numpy arrays whose length equals the maximum number of elements in any entry.
Entries with fewer elements than that are padded with pad_value.
:param pad_value: Value to pad List AgentBufferFields, when there are less than the maximum
number of agents present.
:param dtype: Dtype of output numpy array.
:return: Numpy array or List of numpy arrays representing this AgentBufferField, where the first
dimension is equal to the length of the AgentBufferField.
"""
if len(self) > 0 and not isinstance(self[0], list):
return np.asanyarray(self, dtype=dtype)
shape = None
for _entry in self:
# _entry could be an empty list if there are no group agents in this
# step. Find the first non-empty list and use that shape.
if _entry:
shape = _entry[0].shape
break
# If there were no groupmate agents in the entire batch, return an empty List.
if shape is None:
return []
# Convert to numpy array while padding with 0's
new_list = list(
map(
lambda x: np.asanyarray(x, dtype=dtype),
itertools.zip_longest(*self, fillvalue=np.full(shape, pad_value)),
)
)
return new_list
class AgentBuffer(MutableMapping):
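For group entries, padded_to_batch transposes the per-step lists with itertools.zip_longest and pads steps that have fewer groupmates. A small illustration of just that padding step, with arbitrary one-element arrays rather than the actual method:

import itertools
import numpy as np

entries = [
    [np.array([1.0]), np.array([2.0])],  # step with two groupmates
    [np.array([3.0])],                   # step with only one groupmate
]
pad = np.full((1,), 0.0)  # shape taken from the first non-empty entry, value = pad_value
padded = [
    np.asanyarray(x, dtype=np.float32)
    for x in itertools.zip_longest(*entries, fillvalue=pad)
]
# padded[0] is the first groupmate slot across the batch: [[1.], [3.]]
# padded[1] is the second slot, padded where the groupmate was missing: [[2.], [0.]]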

ml-agents/mlagents/trainers/trajectory.py (163)


from mlagents.trainers.torch.action_log_probs import LogProbsTuple
class GroupmateStatus(NamedTuple):
"""
Stores data related to an agent's teammate.
"""
obs: List[np.ndarray]
reward: float
action: ActionTuple
done: bool
class AgentExperience(NamedTuple):
obs: List[np.ndarray]
reward: float

prev_action: np.ndarray
interrupted: bool
memory: np.ndarray
group_status: List[GroupmateStatus]
group_reward: float
class ObsUtil:

return result
class GroupObsUtil:
@staticmethod
def get_name_at(index: int) -> AgentBufferKey:
"""
returns the name of the observation given the index of the observation
"""
return ObservationKeyPrefix.GROUP_OBSERVATION, index
@staticmethod
def get_name_at_next(index: int) -> AgentBufferKey:
"""
returns the name of the next team observation given the index of the observation
"""
return ObservationKeyPrefix.NEXT_GROUP_OBSERVATION, index
@staticmethod
def _transpose_list_of_lists(
list_list: List[List[np.ndarray]],
) -> List[List[np.ndarray]]:
return list(map(list, zip(*list_list)))
@staticmethod
def from_buffer(batch: AgentBuffer, num_obs: int) -> List[np.array]:
"""
Creates the list of observations from an AgentBuffer
"""
separated_obs: List[np.array] = []
for i in range(num_obs):
separated_obs.append(
batch[GroupObsUtil.get_name_at(i)].padded_to_batch(pad_value=np.nan)
)
# separated_obs contains a List(num_obs) of Lists(num_agents), we want to flip
# that and get a List(num_agents) of Lists(num_obs)
result = GroupObsUtil._transpose_list_of_lists(separated_obs)
return result
@staticmethod
def from_buffer_next(batch: AgentBuffer, num_obs: int) -> List[np.array]:
"""
Creates the list of observations from an AgentBuffer
"""
separated_obs: List[np.array] = []
for i in range(num_obs):
separated_obs.append(
batch[GroupObsUtil.get_name_at_next(i)].padded_to_batch(
pad_value=np.nan
)
)
# separated_obs contains a List(num_obs) of Lists(num_agents), we want to flip
# that and get a List(num_agents) of Lists(num_obs)
result = GroupObsUtil._transpose_list_of_lists(separated_obs)
return result
next_group_obs: List[List[np.ndarray]]
agent_id: str
behavior_id: str

agent_buffer_trajectory = AgentBuffer()
obs = self.steps[0].obs
for step, exp in enumerate(self.steps):
if step < len(self.steps) - 1:
is_last_step = step == len(self.steps) - 1
if not is_last_step:
next_obs = self.steps[step + 1].obs
else:
next_obs = self.next_obs

agent_buffer_trajectory[ObsUtil.get_name_at(i)].append(obs[i])
agent_buffer_trajectory[ObsUtil.get_name_at_next(i)].append(next_obs[i])
# Take care of teammate obs and actions
teammate_continuous_actions, teammate_discrete_actions, teammate_rewards = (
[],
[],
[],
)
for group_status in exp.group_status:
teammate_rewards.append(group_status.reward)
teammate_continuous_actions.append(group_status.action.continuous)
teammate_discrete_actions.append(group_status.action.discrete)
# Team actions
agent_buffer_trajectory[BufferKey.GROUP_CONTINUOUS_ACTION].append(
teammate_continuous_actions
)
agent_buffer_trajectory[BufferKey.GROUP_DISCRETE_ACTION].append(
teammate_discrete_actions
)
agent_buffer_trajectory[BufferKey.GROUPMATE_REWARDS].append(
teammate_rewards
)
agent_buffer_trajectory[BufferKey.GROUP_REWARD].append(exp.group_reward)
# Next actions
teammate_cont_next_actions = []
teammate_disc_next_actions = []
if not is_last_step:
next_exp = self.steps[step + 1]
for group_status in next_exp.group_status:
teammate_cont_next_actions.append(group_status.action.continuous)
teammate_disc_next_actions.append(group_status.action.discrete)
else:
for group_status in exp.group_status:
teammate_cont_next_actions.append(group_status.action.continuous)
teammate_disc_next_actions.append(group_status.action.discrete)
agent_buffer_trajectory[BufferKey.GROUP_NEXT_CONT_ACTION].append(
teammate_cont_next_actions
)
agent_buffer_trajectory[BufferKey.GROUP_NEXT_DISC_ACTION].append(
teammate_disc_next_actions
)
for i in range(num_obs):
ith_group_obs = []
for _group_status in exp.group_status:
# Assume teammates have same obs space
ith_group_obs.append(_group_status.obs[i])
agent_buffer_trajectory[GroupObsUtil.get_name_at(i)].append(
ith_group_obs
)
ith_group_obs_next = []
if is_last_step:
for _obs in self.next_group_obs:
ith_group_obs_next.append(_obs[i])
else:
next_group_status = self.steps[step + 1].group_status
for _group_status in next_group_status:
# Assume teammates have same obs space
ith_group_obs_next.append(_group_status.obs[i])
agent_buffer_trajectory[GroupObsUtil.get_name_at_next(i)].append(
ith_group_obs_next
)
agent_buffer_trajectory[BufferKey.GROUP_DONES].append(
[_status.done for _status in exp.group_status]
)
# Adds the log prob and action of continuous/discrete separately
agent_buffer_trajectory[BufferKey.CONTINUOUS_ACTION].append(

exp.action.discrete
)
cont_next_actions = np.zeros_like(exp.action.continuous)
disc_next_actions = np.zeros_like(exp.action.discrete)
if not is_last_step:
next_action = self.steps[step + 1].action
cont_next_actions = next_action.continuous
disc_next_actions = next_action.discrete
agent_buffer_trajectory[BufferKey.NEXT_CONT_ACTION].append(
cont_next_actions
)
agent_buffer_trajectory[BufferKey.NEXT_DISC_ACTION].append(
disc_next_actions
)
agent_buffer_trajectory[BufferKey.CONTINUOUS_LOG_PROBS].append(
exp.action_probs.continuous
)

Returns true if trajectory is terminated with a Done.
"""
return self.steps[-1].done
@property
def teammate_dones_reached(self) -> bool:
"""
Returns true if all teammates are done at the end of the trajectory.
Combine with done_reached to check if the whole team is done.
"""
return all(_status.done for _status in self.steps[-1].group_status)
@property
def interrupted(self) -> bool:
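GroupObsUtil.from_buffer collects one padded batch per observation index and then flips the nesting with _transpose_list_of_lists, going from a per-observation list of per-agent entries to a per-agent list of per-observation entries. A tiny sketch of that transpose with placeholder strings:

# Transpose a List(num_obs) of Lists(num_agents) into a List(num_agents) of Lists(num_obs).
obs_then_agent = [
    ["agent0_obs0", "agent1_obs0"],
    ["agent0_obs1", "agent1_obs1"],
]
agent_then_obs = list(map(list, zip(*obs_then_agent)))
assert agent_then_obs == [
    ["agent0_obs0", "agent0_obs1"],
    ["agent1_obs0", "agent1_obs1"],
]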

ml-agents/mlagents/trainers/tests/test_buffer.py (80)


b = AgentBuffer()
for step in range(9):
b[ObsUtil.get_name_at(0)].append(
[
100 * fake_agent_id + 10 * step + 1,
100 * fake_agent_id + 10 * step + 2,
100 * fake_agent_id + 10 * step + 3,
]
np.array(
[
100 * fake_agent_id + 10 * step + 1,
100 * fake_agent_id + 10 * step + 2,
100 * fake_agent_id + 10 * step + 3,
],
dtype=np.float32,
)
[100 * fake_agent_id + 10 * step + 4, 100 * fake_agent_id + 10 * step + 5]
np.array(
[
100 * fake_agent_id + 10 * step + 4,
100 * fake_agent_id + 10 * step + 5,
],
dtype=np.float32,
)
)
b[BufferKey.GROUP_CONTINUOUS_ACTION].append(
[
np.array(
[
100 * fake_agent_id + 10 * step + 4,
100 * fake_agent_id + 10 * step + 5,
],
dtype=np.float32,
)
]
* 3
)
return b

agent_2_buffer = construct_fake_buffer(2)
agent_3_buffer = construct_fake_buffer(3)
# Test get_batch
assert_array(np.array(a), np.array([[171, 172, 173], [181, 182, 183]]))
assert_array(
np.array(a), np.array([[171, 172, 173], [181, 182, 183]], dtype=np.float32)
)
# Test get_batch
a = agent_2_buffer[ObsUtil.get_name_at(0)].get_batch(
batch_size=2, training_length=3, sequential=True
)

[261, 262, 263],
[271, 272, 273],
[281, 282, 283],
]
],
dtype=np.float32,
),
)
a = agent_2_buffer[ObsUtil.get_name_at(0)].get_batch(

]
),
)
# Test group entries return Lists of Lists
a = agent_2_buffer[BufferKey.GROUP_CONTINUOUS_ACTION].get_batch(
batch_size=2, training_length=1, sequential=True
)
for _group_entry in a:
assert len(_group_entry) == 3
agent_1_buffer.reset_agent()
assert agent_1_buffer.num_experiences == 0
update_buffer = AgentBuffer()

c = update_buffer.make_mini_batch(start=0, end=1)
assert c.keys() == update_buffer.keys()
# Make sure the values of c are AgentBufferField
for val in c.values():
assert isinstance(val, AgentBufferField)
def test_agentbufferfield():
# Test constructor
a = AgentBufferField([0, 1, 2])
for i, num in enumerate(a):
assert num == i
# Test indexing
assert a[i] == num
# Test slicing
b = a[1:3]
assert b == [1, 2]
assert isinstance(b, AgentBufferField)
# Test padding
c = AgentBufferField()
for _ in range(2):
c.append([np.array(1), np.array(2)])
for _ in range(2):
c.append([np.array(1)])
padded = c.padded_to_batch(pad_value=3)
assert np.array_equal(padded[0], np.array([1, 1, 1, 1]))
assert np.array_equal(padded[1], np.array([2, 2, 3, 3]))
def fakerandint(values):

ml-agents/mlagents/trainers/tests/test_agent_processor.py (100)


from unittest import mock
import pytest
from typing import List
import mlagents.trainers.tests.mock_brain as mb
import numpy as np
from mlagents.trainers.agent_processor import (

return mock_policy
def _create_action_info(num_agents: int, agent_ids: List[str]) -> ActionInfo:
fake_action_outputs = {
"action": ActionTuple(
continuous=np.array([[0.1]] * num_agents, dtype=np.float32)
),
"entropy": np.array([1.0], dtype=np.float32),
"learning_rate": 1.0,
"log_probs": LogProbsTuple(
continuous=np.array([[0.1]] * num_agents, dtype=np.float32)
),
}
fake_action_info = ActionInfo(
action=ActionTuple(continuous=np.array([[0.1]] * num_agents, dtype=np.float32)),
env_action=ActionTuple(
continuous=np.array([[0.1]] * num_agents, dtype=np.float32)
),
outputs=fake_action_outputs,
agent_ids=agent_ids,
)
return fake_action_info
@pytest.mark.parametrize("num_vis_obs", [0, 1, 2], ids=["vec", "1 viz", "2 viz"])
def test_agentprocessor(num_vis_obs):
policy = create_mock_policy()

stats_reporter=StatsReporter("testcat"),
)
fake_action_outputs = {
"action": ActionTuple(continuous=np.array([[0.1], [0.1]], dtype=np.float32)),
"entropy": np.array([1.0], dtype=np.float32),
"learning_rate": 1.0,
"log_probs": LogProbsTuple(
continuous=np.array([[0.1], [0.1]], dtype=np.float32)
),
}
mock_decision_steps, mock_terminal_steps = mb.create_mock_steps(
num_agents=2,
observation_specs=create_observation_specs_with_shapes(

)
fake_action_info = ActionInfo(
action=ActionTuple(continuous=np.array([[0.1], [0.1]], dtype=np.float32)),
env_action=ActionTuple(continuous=np.array([[0.1], [0.1]], dtype=np.float32)),
outputs=fake_action_outputs,
agent_ids=mock_decision_steps.agent_id,
)
fake_action_info = _create_action_info(2, mock_decision_steps.agent_id)
processor.publish_trajectory_queue(tqueue)
# This is like the initial state after the env reset
processor.add_experiences(

# Assert that the trajectory is of length 5
trajectory = tqueue.put.call_args_list[0][0][0]
assert len(trajectory.steps) == 5
# Make sure ungrouped agents don't have team obs
for step in trajectory.steps:
assert len(step.group_status) == 0
# Assert that the AgentProcessor is empty
assert len(processor.experience_buffers[0]) == 0

)
# Assert that the AgentProcessor is still empty
assert len(processor.experience_buffers[0]) == 0
def test_group_statuses():
policy = create_mock_policy()
tqueue = mock.Mock()
name_behavior_id = "test_brain_name"
processor = AgentProcessor(
policy,
name_behavior_id,
max_trajectory_length=5,
stats_reporter=StatsReporter("testcat"),
)
mock_decision_steps, mock_terminal_steps = mb.create_mock_steps(
num_agents=4,
observation_specs=create_observation_specs_with_shapes([(8,)]),
action_spec=ActionSpec.create_continuous(2),
grouped=True,
)
fake_action_info = _create_action_info(4, mock_decision_steps.agent_id)
processor.publish_trajectory_queue(tqueue)
# This is like the initial state after the env reset
processor.add_experiences(
mock_decision_steps, mock_terminal_steps, 0, ActionInfo.empty()
)
for _ in range(2):
processor.add_experiences(
mock_decision_steps, mock_terminal_steps, 0, fake_action_info
)
# Make terminal steps for some dead agents
mock_decision_steps_2, mock_terminal_steps_2 = mb.create_mock_steps(
num_agents=2,
observation_specs=create_observation_specs_with_shapes([(8,)]),
action_spec=ActionSpec.create_continuous(2),
done=True,
grouped=True,
)
processor.add_experiences(
mock_decision_steps_2, mock_terminal_steps_2, 0, fake_action_info
)
fake_action_info = _create_action_info(4, mock_decision_steps.agent_id)
for _ in range(3):
processor.add_experiences(
mock_decision_steps, mock_terminal_steps, 0, fake_action_info
)
# Assert that four trajectories have been added to the Trainer
assert len(tqueue.put.call_args_list) == 4
# Last trajectory should be the longest
trajectory = tqueue.put.call_args_list[-1][0][0]
# Make sure trajectory has the right Groupmate Experiences
for step in trajectory.steps[0:3]:
assert len(step.group_status) == 3
# After 2 agents have died
for step in trajectory.steps[3:]:
assert len(step.group_status) == 1
def test_agent_deletion():

ml-agents/mlagents/trainers/tests/mock_brain.py (20)


from mlagents.trainers.buffer import AgentBuffer, AgentBufferKey
from mlagents.trainers.torch.action_log_probs import LogProbsTuple
from mlagents.trainers.trajectory import Trajectory, AgentExperience
from mlagents.trainers.trajectory import GroupmateStatus, Trajectory, AgentExperience
from mlagents_envs.base_env import (
DecisionSteps,
TerminalSteps,

observation_specs: List[ObservationSpec],
action_spec: ActionSpec,
done: bool = False,
grouped: bool = False,
) -> Tuple[DecisionSteps, TerminalSteps]:
"""
Creates a mock Tuple[DecisionSteps, TerminalSteps] with observations.

reward = np.array(num_agents * [1.0], dtype=np.float32)
interrupted = np.array(num_agents * [False], dtype=np.bool)
agent_id = np.arange(num_agents, dtype=np.int32)
group_id = np.array(num_agents * [0], dtype=np.int32)
_gid = 1 if grouped else 0
group_id = np.array(num_agents * [_gid], dtype=np.int32)
group_reward = np.array(num_agents * [0.0], dtype=np.float32)
behavior_spec = BehaviorSpec(observation_specs, action_spec)
if done:

action_spec: ActionSpec,
max_step_complete: bool = False,
memory_size: int = 10,
num_other_agents_in_group: int = 0,
) -> Trajectory:
"""
Makes a fake trajectory of length length. If max_step_complete,

memory = np.ones(memory_size, dtype=np.float32)
agent_id = "test_agent"
behavior_id = "test_brain"
group_status = []
for _ in range(num_other_agents_in_group):
group_status.append(GroupmateStatus(obs, reward, action, done))
experience = AgentExperience(
obs=obs,
reward=reward,

prev_action=prev_action,
interrupted=max_step,
memory=memory,
group_status=group_status,
group_reward=0,
)
steps_list.append(experience)
obs = []

prev_action=prev_action,
interrupted=max_step_complete,
memory=memory,
group_status=group_status,
group_reward=0,
steps=steps_list, agent_id=agent_id, behavior_id=behavior_id, next_obs=obs
steps=steps_list,
agent_id=agent_id,
behavior_id=behavior_id,
next_obs=obs,
next_group_obs=[obs] * num_other_agents_in_group,
)

ml-agents/mlagents/trainers/tests/test_trajectory.py (54)


import numpy as np
from mlagents.trainers.trajectory import GroupObsUtil
from mlagents.trainers.buffer import BufferKey, ObservationKeyPrefix
from mlagents.trainers.buffer import AgentBuffer, BufferKey, ObservationKeyPrefix
VEC_OBS_SIZE = 6
ACTION_SIZE = 4

length = 15
# These keys should be of type np.ndarray
wanted_keys = [
(ObservationKeyPrefix.OBSERVATION, 0),
(ObservationKeyPrefix.OBSERVATION, 1),

BufferKey.ACTION_MASK,
BufferKey.PREV_ACTION,
BufferKey.ENVIRONMENT_REWARDS,
BufferKey.GROUP_REWARD,
wanted_keys = set(wanted_keys)
# These keys should be of type List
wanted_group_keys = [
BufferKey.GROUPMATE_REWARDS,
BufferKey.GROUP_CONTINUOUS_ACTION,
BufferKey.GROUP_DISCRETE_ACTION,
BufferKey.GROUP_DONES,
BufferKey.GROUP_NEXT_CONT_ACTION,
BufferKey.GROUP_NEXT_DISC_ACTION,
]
wanted_keys = set(wanted_keys + wanted_group_keys)
trajectory = make_fake_trajectory(
length=length,
observation_specs=create_observation_specs_with_shapes(

num_other_agents_in_group=4,
)
agentbuffer = trajectory.to_agentbuffer()
seen_keys = set()

assert seen_keys == wanted_keys
assert seen_keys.issuperset(wanted_keys)
for _key in wanted_group_keys:
for step in agentbuffer[_key]:
assert len(step) == 4
def test_obsutil_group_from_buffer():
buff = AgentBuffer()
# Create some obs
for _ in range(3):
buff[GroupObsUtil.get_name_at(0)].append(3 * [np.ones((5,), dtype=np.float32)])
# Some agents have died
for _ in range(2):
buff[GroupObsUtil.get_name_at(0)].append(1 * [np.ones((5,), dtype=np.float32)])
# Get the group obs, which will be a List of Lists of np.ndarray, where each element is the same
# length as the AgentBuffer but contains only one agent's obs. Dead agents are padded by
# NaNs.
gobs = GroupObsUtil.from_buffer(buff, 1)
# Agent 0 is full
agent_0_obs = gobs[0]
for obs in agent_0_obs:
assert obs.shape == (buff.num_experiences, 5)
assert not np.isnan(obs).any()
agent_1_obs = gobs[1]
for obs in agent_1_obs:
assert obs.shape == (buff.num_experiences, 5)
for i, _exp_obs in enumerate(obs):
if i >= 3:
assert np.isnan(_exp_obs).all()
else:
assert not np.isnan(_exp_obs).any()

ml-agents/mlagents/trainers/torch/utils.py (12)


return torch.as_tensor(np.asanyarray(ndarray_list), dtype=dtype)
@staticmethod
def list_to_tensor_list(
ndarray_list: List[np.ndarray], dtype: Optional[torch.dtype] = torch.float32
) -> torch.Tensor:
"""
Converts a list of numpy arrays into a list of tensors. MUCH faster than
calling as_tensor on the list directly.
"""
return [
torch.as_tensor(np.asanyarray(_arr), dtype=dtype) for _arr in ndarray_list
]
@staticmethod
def to_numpy(tensor: torch.Tensor) -> np.ndarray:
"""
Converts a Torch Tensor to a numpy array. If the Tensor is on the GPU, it will
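A short usage sketch of the new list_to_tensor_list helper, assuming it lives on ModelUtils in mlagents.trainers.torch.utils as the surrounding code suggests; the input arrays here are arbitrary:

import numpy as np
from mlagents.trainers.torch.utils import ModelUtils

arrays = [np.zeros((4, 3), dtype=np.float32), np.ones((4, 3), dtype=np.float32)]
tensors = ModelUtils.list_to_tensor_list(arrays)  # one tensor per array, not a stacked tensor
assert len(tensors) == 2 and tuple(tensors[0].shape) == (4, 3)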

ml-agents/mlagents/trainers/torch/agent_action.py (86)


from typing import List, Optional, NamedTuple
import itertools
import numpy as np
from mlagents.torch_utils import torch
from mlagents.trainers.buffer import AgentBuffer, BufferKey

discrete_list: Optional[List[torch.Tensor]]
@property
def discrete_tensor(self):
def discrete_tensor(self) -> torch.Tensor:
return torch.stack(self.discrete_list, dim=-1)
if self.discrete_list is not None and len(self.discrete_list) > 0:
return torch.stack(self.discrete_list, dim=-1)
else:
return torch.empty(0)
def to_action_tuple(self, clip: bool = False) -> ActionTuple:
"""

discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1])
]
return AgentAction(continuous, discrete)
@staticmethod
def _group_agent_action_from_buffer(
buff: AgentBuffer, cont_action_key: BufferKey, disc_action_key: BufferKey
) -> List["AgentAction"]:
"""
Extracts continuous and discrete groupmate actions, as specified by BufferKey, and
returns a List of AgentActions that correspond to the groupmate's actions. List will
be of length equal to the maximum number of groupmates in the buffer. Wherever there are
fewer agents than that maximum, the actions are padded with 0's.
"""
continuous_tensors: List[torch.Tensor] = []
discrete_tensors: List[torch.Tensor] = []
if cont_action_key in buff:
padded_batch = buff[cont_action_key].padded_to_batch()
continuous_tensors = [
ModelUtils.list_to_tensor(arr) for arr in padded_batch
]
if disc_action_key in buff:
padded_batch = buff[disc_action_key].padded_to_batch(dtype=np.long)
discrete_tensors = [
ModelUtils.list_to_tensor(arr, dtype=torch.long) for arr in padded_batch
]
actions_list = []
for _cont, _disc in itertools.zip_longest(
continuous_tensors, discrete_tensors, fillvalue=None
):
if _disc is not None:
_disc = [_disc[..., i] for i in range(_disc.shape[-1])]
actions_list.append(AgentAction(_cont, _disc))
return actions_list
@staticmethod
def group_from_buffer(buff: AgentBuffer) -> List["AgentAction"]:
"""
A static method that accesses group continuous and discrete action fields in an AgentBuffer
and constructs a padded List of AgentActions that represent the group agent actions.
The List is of length equal to the maximum number of groupmate agents in the buffer, and each
AgentAction is of the same length (batch size) as the buffer. Empty spots (e.g. when agents die) are padded with 0.
:param buff: AgentBuffer of a batch or trajectory
:return: List of groupmate's AgentActions
"""
return AgentAction._group_agent_action_from_buffer(
buff, BufferKey.GROUP_CONTINUOUS_ACTION, BufferKey.GROUP_DISCRETE_ACTION
)
@staticmethod
def group_from_buffer_next(buff: AgentBuffer) -> List["AgentAction"]:
"""
A static method that accesses next group continuous and discrete action fields in an AgentBuffer
and constructs a padded List of AgentActions that represent the next group agent actions.
The List is of length equal to the maximum number of groupmate agents in the buffer, and each
AgentAction is of the same length (batch size) as the buffer. Empty spots (e.g. when agents die) are padded with 0.
:param buff: AgentBuffer of a batch or trajectory
:return: List of groupmate's AgentActions
"""
return AgentAction._group_agent_action_from_buffer(
buff, BufferKey.GROUP_NEXT_CONT_ACTION, BufferKey.GROUP_NEXT_DISC_ACTION
)
def to_flat(self, discrete_branches: List[int]) -> torch.Tensor:
"""
Flatten this AgentAction into a single torch Tensor of dimension (batch, num_continuous + num_one_hot_discrete).
Discrete actions are converted into one-hot and concatenated with continuous actions.
:param discrete_branches: List of sizes for discrete actions.
:return: Tensor of flattened actions.
"""
# if there are any discrete actions, create one-hot
if self.discrete_list is not None and self.discrete_list:
discrete_oh = ModelUtils.actions_to_onehot(
self.discrete_tensor, discrete_branches
)
discrete_oh = torch.cat(discrete_oh, dim=1)
else:
discrete_oh = torch.empty(0)
return torch.cat([self.continuous_tensor, discrete_oh], dim=-1)
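In _group_agent_action_from_buffer, itertools.zip_longest pairs the padded continuous and discrete tensors per groupmate slot; whichever action type is absent comes back as None, and an AgentAction is built from each pair. A sketch of just that pairing, with stand-in strings instead of tensors:

import itertools

continuous_tensors = ["cont_mate_0", "cont_mate_1"]
discrete_tensors: list = []  # e.g. a purely continuous action space
pairs = list(itertools.zip_longest(continuous_tensors, discrete_tensors, fillvalue=None))
assert pairs == [("cont_mate_0", None), ("cont_mate_1", None)]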

com.unity.ml-agents/Runtime/MultiAgentGroupIdCounter.cs.meta (11)


fileFormatVersion: 2
guid: 5661ffdb6c7704e84bc785572dcd5bd1
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

ml-agents/mlagents/trainers/tests/torch/test_agent_action.py (63)


import numpy as np
from mlagents.torch_utils import torch
from mlagents.trainers.buffer import AgentBuffer, BufferKey
from mlagents.trainers.torch.agent_action import AgentAction
def test_agent_action_group_from_buffer():
buff = AgentBuffer()
# Create some actions
for _ in range(3):
buff[BufferKey.GROUP_CONTINUOUS_ACTION].append(
3 * [np.ones((5,), dtype=np.float32)]
)
buff[BufferKey.GROUP_DISCRETE_ACTION].append(
3 * [np.ones((4,), dtype=np.float32)]
)
# Some agents have died
for _ in range(2):
buff[BufferKey.GROUP_CONTINUOUS_ACTION].append(
1 * [np.ones((5,), dtype=np.float32)]
)
buff[BufferKey.GROUP_DISCRETE_ACTION].append(
1 * [np.ones((4,), dtype=np.float32)]
)
# Get the group actions, which will be a List of AgentAction, where each element is the same
# length as the AgentBuffer but contains only one agent's actions. Dead agents are padded with
# 0's.
gact = AgentAction.group_from_buffer(buff)
# Agent 0 is full
agent_0_act = gact[0]
assert agent_0_act.continuous_tensor.shape == (buff.num_experiences, 5)
assert agent_0_act.discrete_tensor.shape == (buff.num_experiences, 4)
agent_1_act = gact[1]
assert agent_1_act.continuous_tensor.shape == (buff.num_experiences, 5)
assert agent_1_act.discrete_tensor.shape == (buff.num_experiences, 4)
assert (agent_1_act.continuous_tensor[0:3] > 0).all()
assert (agent_1_act.continuous_tensor[3:] == 0).all()
assert (agent_1_act.discrete_tensor[0:3] > 0).all()
assert (agent_1_act.discrete_tensor[3:] == 0).all()
def test_to_flat():
# Both continuous and discrete
aa = AgentAction(
torch.tensor([[1.0, 1.0, 1.0]]), [torch.tensor([2]), torch.tensor([1])]
)
flattened_actions = aa.to_flat([3, 3])
assert torch.eq(
flattened_actions, torch.tensor([[1, 1, 1, 0, 0, 1, 0, 1, 0]])
).all()
# Just continuous
aa = AgentAction(torch.tensor([[1.0, 1.0, 1.0]]), None)
flattened_actions = aa.to_flat([])
assert torch.eq(flattened_actions, torch.tensor([1, 1, 1])).all()
# Just discrete
aa = AgentAction(torch.tensor([]), [torch.tensor([2]), torch.tensor([1])])
flattened_actions = aa.to_flat([3, 3])
assert torch.eq(flattened_actions, torch.tensor([0, 0, 1, 0, 1, 0])).all()
