浏览代码

MultiAgentGroup Interface (#4923)

* add SimpleMultiAgentGroup

* add group reward field to agent and proto
/develop/gail-srl-hack
GitHub 3 年前
当前提交
ddb01eb2
共有 20 个文件被更改,包括 618 次插入28 次删除
  1. 65
      com.unity.ml-agents/Runtime/Agent.cs
  2. 2
      com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
  3. 67
      com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs
  4. 7
      gym-unity/gym_unity/tests/test_gym.py
  5. 22
      ml-agents-envs/mlagents_envs/base_env.py
  6. 18
      ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py
  7. 8
      ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi
  8. 30
      ml-agents-envs/mlagents_envs/rpc_utils.py
  9. 4
      ml-agents-envs/mlagents_envs/tests/test_steps.py
  10. 10
      ml-agents/mlagents/trainers/tests/mock_brain.py
  11. 65
      ml-agents/mlagents/trainers/tests/simple_test_envs.py
  12. 2
      protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_info.proto
  13. 26
      com.unity.ml-agents/Runtime/IMultiAgentGroup.cs
  14. 11
      com.unity.ml-agents/Runtime/IMultiAgentGroup.cs.meta
  15. 13
      com.unity.ml-agents/Runtime/MultiAgentGroupIdCounter.cs
  16. 11
      com.unity.ml-agents/Runtime/MultiAgentGroupIdCounter.cs.meta
  17. 143
      com.unity.ml-agents/Runtime/SimpleMultiAgentGroup.cs
  18. 11
      com.unity.ml-agents/Runtime/SimpleMultiAgentGroup.cs.meta
  19. 120
      com.unity.ml-agents/Tests/Editor/MultiAgentGroupTests.cs
  20. 11
      com.unity.ml-agents/Tests/Editor/MultiAgentGroupTests.cs.meta

65
com.unity.ml-agents/Runtime/Agent.cs


public float reward;
/// <summary>
/// The current group reward received by the agent.
/// </summary>
public float groupReward;
/// <summary>
/// Whether the agent is done or not.
/// </summary>
public bool done;

/// to separate between different agents in the environment.
/// </summary>
public int episodeId;
/// <summary>
/// MultiAgentGroup identifier.
/// </summary>
public int groupId;
public void ClearActions()
{

/// Additionally, the magnitude of the reward should not exceed 1.0
float m_Reward;
/// Represents the group reward the agent accumulated during the current step.
float m_GroupReward;
/// Keeps track of the cumulative reward in this episode.
float m_CumulativeReward;

/// </summary>
float[] m_LegacyHeuristicCache;
/// Currect MultiAgentGroup ID. Default to 0 (meaning no group)
int m_GroupId;
/// Delegate for the agent to unregister itself from the MultiAgentGroup without cyclic reference
/// between agent and the group
internal event Action<Agent> OnAgentDisabled;
/// <summary>
/// Called when the attached [GameObject] becomes enabled and active.
/// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html

new int[m_ActuatorManager.NumDiscreteActions]
);
m_Info.groupId = m_GroupId;
// The first time the Academy resets, all Agents in the scene will be
// forced to reset through the <see cref="AgentForceReset"/> event.
// To avoid the Agent resetting twice, the Agents will not begin their

NotifyAgentDone(DoneReason.Disabled);
}
m_Brain?.Dispose();
OnAgentDisabled?.Invoke(this);
m_Initialized = false;
}

}
m_Info.episodeId = m_EpisodeId;
m_Info.reward = m_Reward;
m_Info.groupReward = m_GroupReward;
m_Info.groupId = m_GroupId;
if (collectObservationsSensor != null)
{
// Make sure the latest observations are being passed to training.

}
m_Reward = 0f;
m_GroupReward = 0f;
m_CumulativeReward = 0f;
m_RequestAction = false;
m_RequestDecision = false;

m_CumulativeReward += increment;
}
internal void SetGroupReward(float reward)
{
#if DEBUG
Utilities.DebugCheckNanAndInfinity(reward, nameof(reward), nameof(SetGroupReward));
#endif
m_GroupReward = reward;
}
internal void AddGroupReward(float increment)
{
#if DEBUG
Utilities.DebugCheckNanAndInfinity(increment, nameof(increment), nameof(AddGroupReward));
#endif
m_GroupReward += increment;
}
/// <summary>
/// Retrieves the episode reward for the Agent.
/// </summary>

m_Info.discreteActionMasks = m_ActuatorManager.DiscreteActionMask?.GetMask();
m_Info.reward = m_Reward;
m_Info.groupReward = m_GroupReward;
m_Info.groupId = m_GroupId;
using (TimerStack.Instance.Scoped("RequestDecision"))
{

{
SendInfoToBrain();
m_Reward = 0f;
m_GroupReward = 0f;
m_RequestDecision = false;
}
}

var actions = m_Brain?.DecideAction() ?? new ActionBuffers();
m_Info.CopyActions(actions);
m_ActuatorManager.UpdateActions(actions);
}
internal void SetMultiAgentGroup(IMultiAgentGroup multiAgentGroup)
{
if (multiAgentGroup == null)
{
m_GroupId = 0;
}
else
{
var newGroupId = multiAgentGroup.GetId();
if (m_GroupId == 0 || m_GroupId == newGroupId)
{
m_GroupId = newGroupId;
}
else
{
throw new UnityAgentsException("Agent is already registered with a group. Unregister it first.");
}
}
}
}
}

2
com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs


var agentInfoProto = new AgentInfoProto
{
Reward = ai.reward,
GroupReward = ai.groupReward,
GroupId = ai.groupId,
};
if (ai.discreteActionMasks != null)

67
com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs


string.Concat(
"CjNtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2lu",
"Zm8ucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGjRtbGFnZW50c19lbnZz",
"L2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0aW9uLnByb3RvItEBCg5B",
"L2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0aW9uLnByb3RvIvkBCg5B",
"YXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG9KBAgBEAJKBAgCEANKBAgD",
"EARKBAgEEAVKBAgFEAZKBAgGEAdKBAgMEA1CJaoCIlVuaXR5Lk1MQWdlbnRz",
"LkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw=="));
"YXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG8SEAoIZ3JvdXBfaWQYDiAB",
"KAUSFAoMZ3JvdXBfcmV3YXJkGA8gASgCSgQIARACSgQIAhADSgQIAxAESgQI",
"BBAFSgQIBRAGSgQIBhAHSgQIDBANQiWqAiJVbml0eS5NTEFnZW50cy5Db21t",
"dW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto), global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "Observations" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto), global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "Observations", "GroupId", "GroupReward" }, null, null, null)
}));
}
#endregion

id_ = other.id_;
actionMask_ = other.actionMask_.Clone();
observations_ = other.observations_.Clone();
groupId_ = other.groupId_;
groupReward_ = other.groupReward_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

get { return observations_; }
}
/// <summary>Field number for the "group_id" field.</summary>
public const int GroupIdFieldNumber = 14;
private int groupId_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int GroupId {
get { return groupId_; }
set {
groupId_ = value;
}
}
/// <summary>Field number for the "group_reward" field.</summary>
public const int GroupRewardFieldNumber = 15;
private float groupReward_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public float GroupReward {
get { return groupReward_; }
set {
groupReward_ = value;
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as AgentInfoProto);

if (Id != other.Id) return false;
if(!actionMask_.Equals(other.actionMask_)) return false;
if(!observations_.Equals(other.observations_)) return false;
if (GroupId != other.GroupId) return false;
if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(GroupReward, other.GroupReward)) return false;
return Equals(_unknownFields, other._unknownFields);
}

if (Id != 0) hash ^= Id.GetHashCode();
hash ^= actionMask_.GetHashCode();
hash ^= observations_.GetHashCode();
if (GroupId != 0) hash ^= GroupId.GetHashCode();
if (GroupReward != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(GroupReward);
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}

}
actionMask_.WriteTo(output, _repeated_actionMask_codec);
observations_.WriteTo(output, _repeated_observations_codec);
if (GroupId != 0) {
output.WriteRawTag(112);
output.WriteInt32(GroupId);
}
if (GroupReward != 0F) {
output.WriteRawTag(125);
output.WriteFloat(GroupReward);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}

}
size += actionMask_.CalculateSize(_repeated_actionMask_codec);
size += observations_.CalculateSize(_repeated_observations_codec);
if (GroupId != 0) {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(GroupId);
}
if (GroupReward != 0F) {
size += 1 + 4;
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}

}
actionMask_.Add(other.actionMask_);
observations_.Add(other.observations_);
if (other.GroupId != 0) {
GroupId = other.GroupId;
}
if (other.GroupReward != 0F) {
GroupReward = other.GroupReward;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

}
case 106: {
observations_.AddEntriesFrom(input, _repeated_observations_codec);
break;
}
case 112: {
GroupId = input.ReadInt32();
break;
}
case 125: {
GroupReward = input.ReadFloat();
break;
}
}

7
gym-unity/gym_unity/tests/test_gym.py


] * number_visual_observations
rewards = np.array(num_agents * [1.0])
agents = np.array(range(0, num_agents))
return DecisionSteps(obs, rewards, agents, None), TerminalSteps.empty(specs)
group_id = np.array(num_agents * [0])
group_rewards = np.array(num_agents * [0.0])
return (
DecisionSteps(obs, rewards, agents, None, group_id, group_rewards),
TerminalSteps.empty(specs),
)
def setup_mock_unityenvironment(mock_env, mock_spec, mock_decision, mock_termination):

22
ml-agents-envs/mlagents_envs/base_env.py


reward: float
agent_id: AgentId
action_mask: Optional[List[np.ndarray]]
group_id: int
group_reward: float
class DecisionSteps(Mapping):

this simulation step.
"""
def __init__(self, obs, reward, agent_id, action_mask):
def __init__(self, obs, reward, agent_id, action_mask, group_id, group_reward):
self.group_id: np.ndarray = group_id
self.group_reward: np.ndarray = group_reward
self._agent_id_to_index: Optional[Dict[AgentId, int]] = None
@property

agent_mask = []
for mask in self.action_mask:
agent_mask.append(mask[agent_index])
group_id = self.group_id[agent_index]
group_id=group_id,
group_reward=self.group_reward[agent_index],
)
def __iter__(self) -> Iterator[Any]:

reward=np.zeros(0, dtype=np.float32),
agent_id=np.zeros(0, dtype=np.int32),
action_mask=None,
group_id=np.zeros(0, dtype=np.int32),
group_reward=np.zeros(0, dtype=np.float32),
)

reward: float
interrupted: bool
agent_id: AgentId
group_id: int
group_reward: float
class TerminalSteps(Mapping):

across simulation steps.
"""
def __init__(self, obs, reward, interrupted, agent_id):
def __init__(self, obs, reward, interrupted, agent_id, group_id, group_reward):
self.group_id: np.ndarray = group_id
self.group_reward: np.ndarray = group_reward
self._agent_id_to_index: Optional[Dict[AgentId, int]] = None
@property

agent_obs = []
for batched_obs in self.obs:
agent_obs.append(batched_obs[agent_index])
group_id = self.group_id[agent_index]
group_id=group_id,
group_reward=self.group_reward[agent_index],
)
def __iter__(self) -> Iterator[Any]:

reward=np.zeros(0, dtype=np.float32),
interrupted=np.zeros(0, dtype=np.bool),
agent_id=np.zeros(0, dtype=np.int32),
group_id=np.zeros(0, dtype=np.int32),
group_reward=np.zeros(0, dtype=np.float32),
)

18
ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py


name='mlagents_envs/communicator_objects/agent_info.proto',
package='communicator_objects',
syntax='proto3',
serialized_pb=_b('\n3mlagents_envs/communicator_objects/agent_info.proto\x12\x14\x63ommunicator_objects\x1a\x34mlagents_envs/communicator_objects/observation.proto\"\xd1\x01\n\x0e\x41gentInfoProto\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12<\n\x0cobservations\x18\r \x03(\x0b\x32&.communicator_objects.ObservationProtoJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05J\x04\x08\x05\x10\x06J\x04\x08\x06\x10\x07J\x04\x08\x0c\x10\rB%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
serialized_pb=_b('\n3mlagents_envs/communicator_objects/agent_info.proto\x12\x14\x63ommunicator_objects\x1a\x34mlagents_envs/communicator_objects/observation.proto\"\xf9\x01\n\x0e\x41gentInfoProto\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12<\n\x0cobservations\x18\r \x03(\x0b\x32&.communicator_objects.ObservationProto\x12\x10\n\x08group_id\x18\x0e \x01(\x05\x12\x14\n\x0cgroup_reward\x18\x0f \x01(\x02J\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05J\x04\x08\x05\x10\x06J\x04\x08\x06\x10\x07J\x04\x08\x0c\x10\rB%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
,
dependencies=[mlagents__envs_dot_communicator__objects_dot_observation__pb2.DESCRIPTOR,])

message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='group_id', full_name='communicator_objects.AgentInfoProto.group_id', index=6,
number=14, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='group_reward', full_name='communicator_objects.AgentInfoProto.group_reward', index=7,
number=15, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
],
extensions=[
],

oneofs=[
],
serialized_start=132,
serialized_end=341,
serialized_end=381,
)
_AGENTINFOPROTO.fields_by_name['observations'].message_type = mlagents__envs_dot_communicator__objects_dot_observation__pb2._OBSERVATIONPROTO

8
ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi


max_step_reached = ... # type: builtin___bool
id = ... # type: builtin___int
action_mask = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___bool]
group_id = ... # type: builtin___int
group_reward = ... # type: builtin___float
@property
def observations(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[mlagents_envs___communicator_objects___observation_pb2___ObservationProto]: ...

id : typing___Optional[builtin___int] = None,
action_mask : typing___Optional[typing___Iterable[builtin___bool]] = None,
observations : typing___Optional[typing___Iterable[mlagents_envs___communicator_objects___observation_pb2___ObservationProto]] = None,
group_id : typing___Optional[builtin___int] = None,
group_reward : typing___Optional[builtin___float] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> AgentInfoProto: ...

def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",u"done",u"id",u"max_step_reached",u"observations",u"reward"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",u"done",u"group_id",u"group_reward",u"id",u"max_step_reached",u"observations",u"reward"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",b"action_mask",u"done",b"done",u"id",b"id",u"max_step_reached",b"max_step_reached",u"observations",b"observations",u"reward",b"reward"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",b"action_mask",u"done",b"done",u"group_id",b"group_id",u"group_reward",b"group_reward",u"id",b"id",u"max_step_reached",b"max_step_reached",u"observations",b"observations",u"reward",b"reward"]) -> None: ...

30
ml-agents-envs/mlagents_envs/rpc_utils.py


[agent_info.reward for agent_info in terminal_agent_info_list], dtype=np.float32
)
decision_group_rewards = np.array(
[agent_info.group_reward for agent_info in decision_agent_info_list],
dtype=np.float32,
)
terminal_group_rewards = np.array(
[agent_info.group_reward for agent_info in terminal_agent_info_list],
dtype=np.float32,
)
_raise_on_nan_and_inf(decision_group_rewards, "group_rewards")
_raise_on_nan_and_inf(terminal_group_rewards, "group_rewards")
decision_group_id = [agent_info.group_id for agent_info in decision_agent_info_list]
terminal_group_id = [agent_info.group_id for agent_info in terminal_agent_info_list]
max_step = np.array(
[agent_info.max_step_reached for agent_info in terminal_agent_info_list],

action_mask = np.split(action_mask, indices, axis=1)
return (
DecisionSteps(
decision_obs_list, decision_rewards, decision_agent_id, action_mask
decision_obs_list,
decision_rewards,
decision_agent_id,
action_mask,
decision_group_id,
decision_group_rewards,
TerminalSteps(terminal_obs_list, terminal_rewards, max_step, terminal_agent_id),
TerminalSteps(
terminal_obs_list,
terminal_rewards,
max_step,
terminal_agent_id,
terminal_group_id,
terminal_group_rewards,
),
)

4
ml-agents-envs/mlagents_envs/tests/test_steps.py


reward=np.array(range(3), dtype=np.float32),
agent_id=np.array(range(10, 13), dtype=np.int32),
action_mask=[np.zeros((3, 4), dtype=np.bool)],
group_id=np.array(range(3), dtype=np.int32),
group_reward=np.array(range(3), dtype=np.float32),
)
assert ds.agent_id_to_index[10] == 0

reward=np.array(range(3), dtype=np.float32),
agent_id=np.array(range(10, 13), dtype=np.int32),
interrupted=np.array([1, 0, 1], dtype=np.bool),
group_id=np.array(range(3), dtype=np.int32),
group_reward=np.array(range(3), dtype=np.float32),
)
assert ts.agent_id_to_index[10] == 0

10
ml-agents/mlagents/trainers/tests/mock_brain.py


reward = np.array(num_agents * [1.0], dtype=np.float32)
interrupted = np.array(num_agents * [False], dtype=np.bool)
agent_id = np.arange(num_agents, dtype=np.int32)
group_id = np.array(num_agents * [0], dtype=np.int32)
group_reward = np.array(num_agents * [0.0], dtype=np.float32)
TerminalSteps(obs_list, reward, interrupted, agent_id),
TerminalSteps(
obs_list, reward, interrupted, agent_id, group_id, group_reward
),
DecisionSteps(obs_list, reward, agent_id, action_mask),
DecisionSteps(
obs_list, reward, agent_id, action_mask, group_id, group_reward
),
TerminalSteps.empty(behavior_spec),
)

65
ml-agents/mlagents/trainers/tests/simple_test_envs.py


self.agent_id[name] = self.agent_id[name] + 1
def _make_batched_step(
self, name: str, done: bool, reward: float
self, name: str, done: bool, reward: float, group_reward: float
m_group_id = np.array([0], dtype=np.int32)
m_group_reward = np.array([group_reward], dtype=np.float32)
decision_step = DecisionSteps(m_vector_obs, m_reward, m_agent_id, action_mask)
decision_step = DecisionSteps(
m_vector_obs, m_reward, m_agent_id, action_mask, m_group_id, m_group_reward
)
terminal_step = TerminalSteps.empty(self.behavior_spec)
if done:
self.final_rewards[name].append(self.rewards[name])

new_done,
new_agent_id,
new_action_mask,
new_group_id,
new_group_reward,
new_vector_obs, new_reward, new_agent_id, new_action_mask
new_vector_obs,
new_reward,
new_agent_id,
new_action_mask,
new_group_id,
new_group_reward,
m_vector_obs, m_reward, np.array([False], dtype=np.bool), m_agent_id
m_vector_obs,
m_reward,
np.array([False], dtype=np.bool),
m_agent_id,
m_group_id,
m_group_reward,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
return new_reward, new_done, new_agent_id, new_action_mask
new_group_id = np.array([0], dtype=np.int32)
new_group_reward = np.array([0.0], dtype=np.float32)
return (
new_reward,
new_done,
new_agent_id,
new_action_mask,
new_group_id,
new_group_reward,
)
def step(self) -> None:
assert all(action is not None for action in self.action.values())

reward = self._compute_reward(name, done)
self.rewards[name] += reward
self.step_result[name] = self._make_batched_step(name, done, reward)
self.step_result[name] = self._make_batched_step(name, done, reward, 0.0)
self.step_result[name] = self._make_batched_step(name, False, 0.0)
self.step_result[name] = self._make_batched_step(name, False, 0.0, 0.0)
@property
def reset_parameters(self) -> Dict[str, str]:

self.num_show_steps = 2
def _make_batched_step(
self, name: str, done: bool, reward: float
self, name: str, done: bool, reward: float, group_reward: float
) -> Tuple[DecisionSteps, TerminalSteps]:
recurrent_obs_val = (
self.goal[name] if self.step_count[name] <= self.num_show_steps else 0

m_agent_id = np.array([self.agent_id[name]], dtype=np.int32)
m_group_id = np.array([0], dtype=np.int32)
m_group_reward = np.array([group_reward], dtype=np.float32)
decision_step = DecisionSteps(m_vector_obs, m_reward, m_agent_id, action_mask)
decision_step = DecisionSteps(
m_vector_obs, m_reward, m_agent_id, action_mask, m_group_id, m_group_reward
)
terminal_step = TerminalSteps.empty(self.behavior_spec)
if done:
self.final_rewards[name].append(self.rewards[name])

new_done,
new_agent_id,
new_action_mask,
new_group_id,
new_group_reward,
new_vector_obs, new_reward, new_agent_id, new_action_mask
new_vector_obs,
new_reward,
new_agent_id,
new_action_mask,
new_group_id,
new_group_reward,
m_vector_obs, m_reward, np.array([False], dtype=np.bool), m_agent_id
m_vector_obs,
m_reward,
np.array([False], dtype=np.bool),
m_agent_id,
m_group_id,
m_group_reward,
)
return (decision_step, terminal_step)

2
protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_info.proto


repeated bool action_mask = 11;
reserved 12; // deprecated CustomObservationProto custom_observation = 12;
repeated ObservationProto observations = 13;
int32 group_id = 14;
float group_reward = 15;
}

26
com.unity.ml-agents/Runtime/IMultiAgentGroup.cs


namespace Unity.MLAgents
{
/// <summary>
/// MultiAgentGroup interface for grouping agents to support multi-agent training.
/// </summary>
public interface IMultiAgentGroup
{
/// <summary>
/// Get the ID of MultiAgentGroup.
/// </summary>
/// <returns>
/// MultiAgentGroup ID.
/// </returns>
int GetId();
/// <summary>
/// Register agent to the MultiAgentGroup.
/// </summary>
void RegisterAgent(Agent agent);
/// <summary>
/// Unregister agent from the MultiAgentGroup.
/// </summary>
void UnregisterAgent(Agent agent);
}
}

11
com.unity.ml-agents/Runtime/IMultiAgentGroup.cs.meta


fileFormatVersion: 2
guid: 3744ac27d956e43e1a39c7ba2550ab82
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

13
com.unity.ml-agents/Runtime/MultiAgentGroupIdCounter.cs


using System.Threading;
namespace Unity.MLAgents
{
internal static class MultiAgentGroupIdCounter
{
static int s_Counter;
public static int GetGroupId()
{
return Interlocked.Increment(ref s_Counter); ;
}
}
}

11
com.unity.ml-agents/Runtime/MultiAgentGroupIdCounter.cs.meta


fileFormatVersion: 2
guid: 5661ffdb6c7704e84bc785572dcd5bd1
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

143
com.unity.ml-agents/Runtime/SimpleMultiAgentGroup.cs


using System;
using System.Linq;
using System.Collections.Generic;
namespace Unity.MLAgents
{
/// <summary>
/// A basic class implementation of MultiAgentGroup.
/// </summary>
internal class SimpleMultiAgentGroup : IMultiAgentGroup, IDisposable
{
readonly int m_Id = MultiAgentGroupIdCounter.GetGroupId();
HashSet<Agent> m_Agents = new HashSet<Agent>();
public virtual void Dispose()
{
while (m_Agents.Count > 0)
{
UnregisterAgent(m_Agents.First());
}
}
/// <inheritdoc />
public virtual void RegisterAgent(Agent agent)
{
if (!m_Agents.Contains(agent))
{
agent.SetMultiAgentGroup(this);
m_Agents.Add(agent);
agent.OnAgentDisabled += UnregisterAgent;
}
}
/// <inheritdoc />
public virtual void UnregisterAgent(Agent agent)
{
if (m_Agents.Contains(agent))
{
agent.SetMultiAgentGroup(null);
m_Agents.Remove(agent);
agent.OnAgentDisabled -= UnregisterAgent;
}
}
/// <inheritdoc />
public int GetId()
{
return m_Id;
}
/// <summary>
/// Get list of all agents currently registered to this MultiAgentGroup.
/// </summary>
/// <returns>
/// List of agents registered to the MultiAgentGroup.
/// </returns>
public IReadOnlyCollection<Agent> GetRegisteredAgents()
{
return (IReadOnlyCollection<Agent>)m_Agents;
}
/// <summary>
/// Increments the group rewards for all agents in this MultiAgentGroup.
/// </summary>
/// <remarks>
/// This function increases or decreases the group rewards by a given amount for all agents
/// in the group. Use <see cref="SetGroupReward(float)"/> to set the group reward assigned
/// to the current step with a specific value rather than increasing or decreasing it.
///
/// A positive group reward indicates the whole group's accomplishments or desired behaviors.
/// Every agent in the group will receive the same group reward no matter whether the
/// agent's act directly leads to the reward. Group rewards are meant to reinforce agents
/// to act in the group's best interest instead of individual ones.
/// Group rewards are treated differently than individual agent rewards during training, so
/// calling AddGroupReward() is not equivalent to calling agent.AddReward() on each agent in the group.
/// </remarks>
/// <param name="reward">Incremental group reward value.</param>
public void AddGroupReward(float reward)
{
foreach (var agent in m_Agents)
{
agent.AddGroupReward(reward);
}
}
/// <summary>
/// Set the group rewards for all agents in this MultiAgentGroup.
/// </summary>
/// <remarks>
/// This function replaces any group rewards given during the current step for all agents in the group.
/// Use <see cref="AddGroupReward(float)"/> to incrementally change the group reward rather than
/// overriding it.
///
/// A positive group reward indicates the whole group's accomplishments or desired behaviors.
/// Every agent in the group will receive the same group reward no matter whether the
/// agent's act directly leads to the reward. Group rewards are meant to reinforce agents
/// to act in the group's best interest instead of indivisual ones.
/// Group rewards are treated differently than individual agent rewards during training, so
/// calling SetGroupReward() is not equivalent to calling agent.SetReward() on each agent in the group.
/// </remarks>
/// <param name="reward">The new value of the group reward.</param>
public void SetGroupReward(float reward)
{
foreach (var agent in m_Agents)
{
agent.SetGroupReward(reward);
}
}
/// <summary>
/// End episodes for all agents in this MultiAgentGroup.
/// </summary>
/// <remarks>
/// This should be used when the episode can no longer continue, such as when the group
/// reaches the goal or fails at the task.
/// </remarks>
public void EndGroupEpisode()
{
foreach (var agent in m_Agents)
{
agent.EndEpisode();
}
}
/// <summary>
/// Indicate that the episode is over but not due to the "fault" of the group.
/// This has the same end result as calling <see cref="EndGroupEpisode"/>, but has a
/// slightly different effect on training.
/// </summary>
/// <remarks>
/// This should be used when the episode could continue, but has gone on for
/// a sufficient number of steps, such as if the environment hits some maximum number of steps.
/// </remarks>
public void GroupEpisodeInterrupted()
{
foreach (var agent in m_Agents)
{
agent.EpisodeInterrupted();
}
}
}
}

11
com.unity.ml-agents/Runtime/SimpleMultiAgentGroup.cs.meta


fileFormatVersion: 2
guid: 3454e3c3c70964dca93b63ee4b650095
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

120
com.unity.ml-agents/Tests/Editor/MultiAgentGroupTests.cs


using Unity.MLAgents;
using System;
using System.Reflection;
using NUnit.Framework;
using UnityEngine;
using Unity;
namespace Unity.MLAgents.Tests
{
public class MultiAgentGroupTests
{
class TestAgent : Agent
{
internal int _GroupId
{
get
{
return (int)typeof(Agent).GetField("m_GroupId", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(this);
}
}
internal float _GroupReward
{
get
{
return (float)typeof(Agent).GetField("m_GroupReward", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(this);
}
}
internal Action<Agent> _OnAgentDisabledActions
{
get
{
return (Action<Agent>)typeof(Agent).GetField("OnAgentDisabled", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(this);
}
}
}
[Test]
public void TestRegisteredAgentGroupId()
{
var agentGo = new GameObject("TestAgent");
agentGo.AddComponent<TestAgent>();
var agent = agentGo.GetComponent<TestAgent>();
// test register
SimpleMultiAgentGroup agentGroup1 = new SimpleMultiAgentGroup();
agentGroup1.RegisterAgent(agent);
Assert.AreEqual(agentGroup1.GetId(), agent._GroupId);
Assert.IsNotNull(agent._OnAgentDisabledActions);
// should not be able to registered to multiple groups
SimpleMultiAgentGroup agentGroup2 = new SimpleMultiAgentGroup();
Assert.Throws<UnityAgentsException>(
() => agentGroup2.RegisterAgent(agent));
Assert.AreEqual(agentGroup1.GetId(), agent._GroupId);
// test unregister
agentGroup1.UnregisterAgent(agent);
Assert.AreEqual(0, agent._GroupId);
Assert.IsNull(agent._OnAgentDisabledActions);
// test register to another group after unregister
agentGroup2.RegisterAgent(agent);
Assert.AreEqual(agentGroup2.GetId(), agent._GroupId);
Assert.IsNotNull(agent._OnAgentDisabledActions);
}
[Test]
public void TestRegisterMultipleAgent()
{
var agentGo1 = new GameObject("TestAgent");
agentGo1.AddComponent<TestAgent>();
var agent1 = agentGo1.GetComponent<TestAgent>();
var agentGo2 = new GameObject("TestAgent");
agentGo2.AddComponent<TestAgent>();
var agent2 = agentGo2.GetComponent<TestAgent>();
SimpleMultiAgentGroup agentGroup = new SimpleMultiAgentGroup();
agentGroup.RegisterAgent(agent1); // register
Assert.AreEqual(agentGroup.GetRegisteredAgents().Count, 1);
agentGroup.UnregisterAgent(agent2); // unregister non-member agent
Assert.AreEqual(agentGroup.GetRegisteredAgents().Count, 1);
agentGroup.UnregisterAgent(agent1); // unregister
Assert.AreEqual(agentGroup.GetRegisteredAgents().Count, 0);
agentGroup.RegisterAgent(agent1);
agentGroup.RegisterAgent(agent1); // duplicated register
Assert.AreEqual(agentGroup.GetRegisteredAgents().Count, 1);
agentGroup.RegisterAgent(agent2); // register another
Assert.AreEqual(agentGroup.GetRegisteredAgents().Count, 2);
// test add/set group rewards
agentGroup.AddGroupReward(0.1f);
Assert.AreEqual(0.1f, agent1._GroupReward);
agentGroup.AddGroupReward(0.5f);
Assert.AreEqual(0.6f, agent1._GroupReward);
agentGroup.SetGroupReward(0.3f);
Assert.AreEqual(0.3f, agent1._GroupReward);
// unregistered agent should not receive group reward
agentGroup.UnregisterAgent(agent1);
agentGroup.AddGroupReward(0.2f);
Assert.AreEqual(0.3f, agent1._GroupReward);
Assert.AreEqual(0.5f, agent2._GroupReward);
// dispose group should automatically unregister all
agentGroup.Dispose();
Assert.AreEqual(0, agent1._GroupId);
Assert.AreEqual(0, agent2._GroupId);
}
[Test]
public void TestGroupIdCounter()
{
SimpleMultiAgentGroup group1 = new SimpleMultiAgentGroup();
SimpleMultiAgentGroup group2 = new SimpleMultiAgentGroup();
// id should be unique
Assert.AreNotEqual(group1.GetId(), group2.GetId());
}
}
}

11
com.unity.ml-agents/Tests/Editor/MultiAgentGroupTests.cs.meta


fileFormatVersion: 2
guid: ef0158fde748d478ca5ee3bbe22a4c9e
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:
正在加载...
取消
保存