浏览代码

change name TeamManager to MultiAgentGroup

/develop/superpush/int
Ruo-Ping Dong 3 年前
当前提交
918c2dcd
共有 21 个文件被更改,包括 290 次插入133 次删除
  1. 52
      com.unity.ml-agents/Runtime/Agent.cs
  2. 4
      com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
  3. 74
      com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs
  4. 2
      com.unity.ml-agents/Runtime/IMultiAgentGroup.cs
  5. 2
      com.unity.ml-agents/Runtime/MultiAgentGroupIdCounter.cs.meta
  6. 44
      ml-agents-envs/mlagents_envs/base_env.py
  7. 8
      ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py
  8. 12
      ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi
  9. 28
      ml-agents-envs/mlagents_envs/rpc_utils.py
  10. 4
      protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_info.proto
  11. 2
      com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs.meta
  12. 13
      com.unity.ml-agents/Runtime/MultiAgentGroupIdCounter.cs
  13. 165
      com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs
  14. 13
      com.unity.ml-agents/Runtime/TeamManagerIdCounter.cs
  15. 0
      /com.unity.ml-agents.extensions/Runtime/MultiAgent.meta
  16. 0
      /com.unity.ml-agents/Runtime/IMultiAgentGroup.cs
  17. 0
      /com.unity.ml-agents/Runtime/MultiAgentGroupIdCounter.cs.meta
  18. 0
      /com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs.meta
  19. 0
      /com.unity.ml-agents/Runtime/IMultiAgentGroup.cs.meta

52
com.unity.ml-agents/Runtime/Agent.cs


public float reward;
/// <summary>
/// The current team reward received by the agent.
/// The current group reward received by the agent.
public float teamReward;
public float groupReward;
/// <summary>
/// Whether the agent is done or not.

public int episodeId;
/// <summary>
/// Team Manager identifier.
/// MultiAgentGroup identifier.
public int teamManagerId;
public int groupId;
public void ClearActions()
{

/// Additionally, the magnitude of the reward should not exceed 1.0
float m_Reward;
/// Represents the team reward the agent accumulated during the current step.
float m_TeamReward;
/// Represents the group reward the agent accumulated during the current step.
float m_GroupReward;
/// Keeps track of the cumulative reward in this episode.
float m_CumulativeReward;

/// </summary>
float[] m_LegacyHeuristicCache;
int m_TeamManagerID;
int m_GroupId;
internal event Action<Agent> UnregisterFromTeamManager;
internal event Action<Agent> UnregisterFromGroup;
/// <summary>
/// Called when the attached [GameObject] becomes enabled and active.

new int[m_ActuatorManager.NumDiscreteActions]
);
m_Info.teamManagerId = m_TeamManagerID;
m_Info.groupId = m_GroupId;
// The first time the Academy resets, all Agents in the scene will be
// forced to reset through the <see cref="AgentForceReset"/> event.

NotifyAgentDone(DoneReason.Disabled);
}
m_Brain?.Dispose();
UnregisterFromTeamManager?.Invoke(this);
UnregisterFromGroup?.Invoke(this);
m_Initialized = false;
}

}
m_Info.episodeId = m_EpisodeId;
m_Info.reward = m_Reward;
m_Info.teamReward = m_TeamReward;
m_Info.groupReward = m_GroupReward;
m_Info.teamManagerId = m_TeamManagerID;
m_Info.groupId = m_GroupId;
if (collectObservationsSensor != null)
{
// Make sure the latest observations are being passed to training.

}
m_Reward = 0f;
m_TeamReward = 0f;
m_GroupReward = 0f;
m_CumulativeReward = 0f;
m_RequestAction = false;
m_RequestDecision = false;

m_CumulativeReward += increment;
}
internal void SetTeamReward(float reward)
internal void SetGroupReward(float reward)
Utilities.DebugCheckNanAndInfinity(reward, nameof(reward), nameof(SetTeamReward));
Utilities.DebugCheckNanAndInfinity(reward, nameof(reward), nameof(SetGroupReward));
m_TeamReward = reward;
m_GroupReward = reward;
internal void AddTeamReward(float increment)
internal void AddGroupReward(float increment)
Utilities.DebugCheckNanAndInfinity(increment, nameof(increment), nameof(AddTeamReward));
Utilities.DebugCheckNanAndInfinity(increment, nameof(increment), nameof(AddGroupReward));
m_TeamReward += increment;
m_GroupReward += increment;
}
/// <summary>

m_Info.discreteActionMasks = m_ActuatorManager.DiscreteActionMask?.GetMask();
m_Info.reward = m_Reward;
m_Info.teamReward = m_TeamReward;
m_Info.groupReward = m_GroupReward;
m_Info.teamManagerId = m_TeamManagerID;
m_Info.groupId = m_GroupId;
using (TimerStack.Instance.Scoped("RequestDecision"))
{

{
SendInfoToBrain();
m_Reward = 0f;
m_TeamReward = 0f;
m_GroupReward = 0f;
m_RequestDecision = false;
}
}

m_ActuatorManager.UpdateActions(actions);
}
internal void SetTeamManager(ITeamManager teamManager)
internal void SetMultiAgentGroup(IMultiAgentGroup multiAgentGroup)
// unregister current TeamManager if this agent has been assigned one before
UnregisterFromTeamManager?.Invoke(this);
// unregister from current group if this agent has been assigned one before
UnregisterFromGroup?.Invoke(this);
m_TeamManagerID = teamManager.GetId();
m_GroupId = multiAgentGroup.GetId();
}
}
}

4
com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs


var agentInfoProto = new AgentInfoProto
{
Reward = ai.reward,
TeamReward = ai.teamReward,
GroupReward = ai.groupReward,
TeamManagerId = ai.teamManagerId,
GroupId = ai.groupId,
};
if (ai.discreteActionMasks != null)

74
com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs


string.Concat(
"CjNtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2lu",
"Zm8ucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGjRtbGFnZW50c19lbnZz",
"L2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0aW9uLnByb3RvIv8BCg5B",
"L2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0aW9uLnByb3RvIvkBCg5B",
"YXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG8SFwoPdGVhbV9tYW5hZ2Vy",
"X2lkGA4gASgFEhMKC3RlYW1fcmV3YXJkGA8gASgCSgQIARACSgQIAhADSgQI",
"AxAESgQIBBAFSgQIBRAGSgQIBhAHSgQIDBANQiWqAiJVbml0eS5NTEFnZW50",
"cy5Db21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
"YXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG8SEAoIZ3JvdXBfaWQYDiAB",
"KAUSFAoMZ3JvdXBfcmV3YXJkGA8gASgCSgQIARACSgQIAhADSgQIAxAESgQI",
"BBAFSgQIBRAGSgQIBhAHSgQIDBANQiWqAiJVbml0eS5NTEFnZW50cy5Db21t",
"dW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto), global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "Observations", "TeamManagerId", "TeamReward" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto), global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "Observations", "GroupId", "GroupReward" }, null, null, null)
}));
}
#endregion

id_ = other.id_;
actionMask_ = other.actionMask_.Clone();
observations_ = other.observations_.Clone();
teamManagerId_ = other.teamManagerId_;
teamReward_ = other.teamReward_;
groupId_ = other.groupId_;
groupReward_ = other.groupReward_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

get { return observations_; }
}
/// <summary>Field number for the "team_manager_id" field.</summary>
public const int TeamManagerIdFieldNumber = 14;
private int teamManagerId_;
/// <summary>Field number for the "group_id" field.</summary>
public const int GroupIdFieldNumber = 14;
private int groupId_;
public int TeamManagerId {
get { return teamManagerId_; }
public int GroupId {
get { return groupId_; }
teamManagerId_ = value;
groupId_ = value;
/// <summary>Field number for the "team_reward" field.</summary>
public const int TeamRewardFieldNumber = 15;
private float teamReward_;
/// <summary>Field number for the "group_reward" field.</summary>
public const int GroupRewardFieldNumber = 15;
private float groupReward_;
public float TeamReward {
get { return teamReward_; }
public float GroupReward {
get { return groupReward_; }
teamReward_ = value;
groupReward_ = value;
}
}

if (Id != other.Id) return false;
if(!actionMask_.Equals(other.actionMask_)) return false;
if(!observations_.Equals(other.observations_)) return false;
if (TeamManagerId != other.TeamManagerId) return false;
if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(TeamReward, other.TeamReward)) return false;
if (GroupId != other.GroupId) return false;
if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(GroupReward, other.GroupReward)) return false;
return Equals(_unknownFields, other._unknownFields);
}

if (Id != 0) hash ^= Id.GetHashCode();
hash ^= actionMask_.GetHashCode();
hash ^= observations_.GetHashCode();
if (TeamManagerId != 0) hash ^= TeamManagerId.GetHashCode();
if (TeamReward != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(TeamReward);
if (GroupId != 0) hash ^= GroupId.GetHashCode();
if (GroupReward != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(GroupReward);
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}

}
actionMask_.WriteTo(output, _repeated_actionMask_codec);
observations_.WriteTo(output, _repeated_observations_codec);
if (TeamManagerId != 0) {
if (GroupId != 0) {
output.WriteInt32(TeamManagerId);
output.WriteInt32(GroupId);
if (TeamReward != 0F) {
if (GroupReward != 0F) {
output.WriteFloat(TeamReward);
output.WriteFloat(GroupReward);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);

}
size += actionMask_.CalculateSize(_repeated_actionMask_codec);
size += observations_.CalculateSize(_repeated_observations_codec);
if (TeamManagerId != 0) {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(TeamManagerId);
if (GroupId != 0) {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(GroupId);
if (TeamReward != 0F) {
if (GroupReward != 0F) {
size += 1 + 4;
}
if (_unknownFields != null) {

}
actionMask_.Add(other.actionMask_);
observations_.Add(other.observations_);
if (other.TeamManagerId != 0) {
TeamManagerId = other.TeamManagerId;
if (other.GroupId != 0) {
GroupId = other.GroupId;
if (other.TeamReward != 0F) {
TeamReward = other.TeamReward;
if (other.GroupReward != 0F) {
GroupReward = other.GroupReward;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

break;
}
case 112: {
TeamManagerId = input.ReadInt32();
GroupId = input.ReadInt32();
TeamReward = input.ReadFloat();
GroupReward = input.ReadFloat();
break;
}
}

2
com.unity.ml-agents/Runtime/IMultiAgentGroup.cs


namespace Unity.MLAgents
{
public interface ITeamManager
public interface IMultiAgentGroup
{
int GetId();

2
com.unity.ml-agents/Runtime/MultiAgentGroupIdCounter.cs.meta


fileFormatVersion: 2
guid: 8b061f82569af4ffba715297f77a95ab
guid: 5661ffdb6c7704e84bc785572dcd5bd1
MonoImporter:
externalObjects: {}
serializedVersion: 2

44
ml-agents-envs/mlagents_envs/base_env.py


obs: List[np.ndarray]
reward: float
team_reward: float
group_reward: float
team_manager_id: int
group_id: int
class DecisionSteps(Mapping):

this simulation step.
"""
def __init__(
self, obs, reward, team_reward, agent_id, action_mask, team_manager_id
):
def __init__(self, obs, reward, group_reward, agent_id, action_mask, group_id):
self.team_reward: np.ndarray = team_reward
self.group_reward: np.ndarray = group_reward
self.team_manager_id: np.ndarray = team_manager_id
self.group_id: np.ndarray = group_id
self._agent_id_to_index: Optional[Dict[AgentId, int]] = None
@property

agent_mask = []
for mask in self.action_mask:
agent_mask.append(mask[agent_index])
team_manager_id = self.team_manager_id[agent_index]
group_id = self.group_id[agent_index]
team_reward=self.team_reward[agent_index],
group_reward=self.group_reward[agent_index],
team_manager_id=team_manager_id,
group_id=group_id,
)
def __iter__(self) -> Iterator[Any]:

return DecisionSteps(
obs=obs,
reward=np.zeros(0, dtype=np.float32),
team_reward=np.zeros(0, dtype=np.float32),
group_reward=np.zeros(0, dtype=np.float32),
team_manager_id=np.zeros(0, dtype=np.int32),
group_id=np.zeros(0, dtype=np.int32),
)

obs: List[np.ndarray]
reward: float
team_reward: float
group_reward: float
team_manager_id: int
group_id: int
class TerminalSteps(Mapping):

across simulation steps.
"""
def __init__(
self, obs, reward, team_reward, interrupted, agent_id, team_manager_id
):
def __init__(self, obs, reward, group_reward, interrupted, agent_id, group_id):
self.team_reward: np.ndarray = team_reward
self.group_reward: np.ndarray = group_reward
self.team_manager_id: np.ndarray = team_manager_id
self.group_id: np.ndarray = group_id
self._agent_id_to_index: Optional[Dict[AgentId, int]] = None
@property

agent_obs = []
for batched_obs in self.obs:
agent_obs.append(batched_obs[agent_index])
team_manager_id = self.team_manager_id[agent_index]
group_id = self.group_id[agent_index]
team_reward=self.team_reward[agent_index],
group_reward=self.group_reward[agent_index],
team_manager_id=team_manager_id,
group_id=group_id,
)
def __iter__(self) -> Iterator[Any]:

return TerminalSteps(
obs=obs,
reward=np.zeros(0, dtype=np.float32),
team_reward=np.zeros(0, dtype=np.float32),
group_reward=np.zeros(0, dtype=np.float32),
team_manager_id=np.zeros(0, dtype=np.int32),
group_id=np.zeros(0, dtype=np.int32),
)

8
ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py


name='mlagents_envs/communicator_objects/agent_info.proto',
package='communicator_objects',
syntax='proto3',
serialized_pb=_b('\n3mlagents_envs/communicator_objects/agent_info.proto\x12\x14\x63ommunicator_objects\x1a\x34mlagents_envs/communicator_objects/observation.proto\"\xff\x01\n\x0e\x41gentInfoProto\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12<\n\x0cobservations\x18\r \x03(\x0b\x32&.communicator_objects.ObservationProto\x12\x17\n\x0fteam_manager_id\x18\x0e \x01(\x05\x12\x13\n\x0bteam_reward\x18\x0f \x01(\x02J\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05J\x04\x08\x05\x10\x06J\x04\x08\x06\x10\x07J\x04\x08\x0c\x10\rB%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
serialized_pb=_b('\n3mlagents_envs/communicator_objects/agent_info.proto\x12\x14\x63ommunicator_objects\x1a\x34mlagents_envs/communicator_objects/observation.proto\"\xf9\x01\n\x0e\x41gentInfoProto\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12<\n\x0cobservations\x18\r \x03(\x0b\x32&.communicator_objects.ObservationProto\x12\x10\n\x08group_id\x18\x0e \x01(\x05\x12\x14\n\x0cgroup_reward\x18\x0f \x01(\x02J\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05J\x04\x08\x05\x10\x06J\x04\x08\x06\x10\x07J\x04\x08\x0c\x10\rB%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
,
dependencies=[mlagents__envs_dot_communicator__objects_dot_observation__pb2.DESCRIPTOR,])

is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='team_manager_id', full_name='communicator_objects.AgentInfoProto.team_manager_id', index=6,
name='group_id', full_name='communicator_objects.AgentInfoProto.group_id', index=6,
number=14, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,

name='team_reward', full_name='communicator_objects.AgentInfoProto.team_reward', index=7,
name='group_reward', full_name='communicator_objects.AgentInfoProto.group_reward', index=7,
number=15, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None,

oneofs=[
],
serialized_start=132,
serialized_end=387,
serialized_end=381,
)
_AGENTINFOPROTO.fields_by_name['observations'].message_type = mlagents__envs_dot_communicator__objects_dot_observation__pb2._OBSERVATIONPROTO

12
ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi


max_step_reached = ... # type: builtin___bool
id = ... # type: builtin___int
action_mask = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___bool]
team_manager_id = ... # type: builtin___int
team_reward = ... # type: builtin___float
group_id = ... # type: builtin___int
group_reward = ... # type: builtin___float
@property
def observations(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[mlagents_envs___communicator_objects___observation_pb2___ObservationProto]: ...

id : typing___Optional[builtin___int] = None,
action_mask : typing___Optional[typing___Iterable[builtin___bool]] = None,
observations : typing___Optional[typing___Iterable[mlagents_envs___communicator_objects___observation_pb2___ObservationProto]] = None,
team_manager_id : typing___Optional[builtin___int] = None,
team_reward : typing___Optional[builtin___float] = None,
group_id : typing___Optional[builtin___int] = None,
group_reward : typing___Optional[builtin___float] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> AgentInfoProto: ...

def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",u"done",u"id",u"max_step_reached",u"observations",u"reward",u"team_manager_id",u"team_reward"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",u"done",u"group_id",u"group_reward",u"id",u"max_step_reached",u"observations",u"reward"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",b"action_mask",u"done",b"done",u"id",b"id",u"max_step_reached",b"max_step_reached",u"observations",b"observations",u"reward",b"reward",u"team_manager_id",b"team_manager_id",u"team_reward",b"team_reward"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",b"action_mask",u"done",b"done",u"group_id",b"group_id",u"group_reward",b"group_reward",u"id",b"id",u"max_step_reached",b"max_step_reached",u"observations",b"observations",u"reward",b"reward"]) -> None: ...

28
ml-agents-envs/mlagents_envs/rpc_utils.py


[agent_info.reward for agent_info in terminal_agent_info_list], dtype=np.float32
)
decision_team_rewards = np.array(
[agent_info.team_reward for agent_info in decision_agent_info_list],
decision_group_rewards = np.array(
[agent_info.group_reward for agent_info in decision_agent_info_list],
terminal_team_rewards = np.array(
[agent_info.team_reward for agent_info in terminal_agent_info_list],
terminal_group_rewards = np.array(
[agent_info.group_reward for agent_info in terminal_agent_info_list],
_raise_on_nan_and_inf(decision_team_rewards, "team_rewards")
_raise_on_nan_and_inf(terminal_team_rewards, "team_rewards")
_raise_on_nan_and_inf(decision_group_rewards, "group_rewards")
_raise_on_nan_and_inf(terminal_group_rewards, "group_rewards")
decision_team_managers = [
agent_info.team_manager_id for agent_info in decision_agent_info_list
]
terminal_team_managers = [
agent_info.team_manager_id for agent_info in terminal_agent_info_list
]
decision_group_id = [agent_info.group_id for agent_info in decision_agent_info_list]
terminal_group_id = [agent_info.group_id for agent_info in terminal_agent_info_list]
max_step = np.array(
[agent_info.max_step_reached for agent_info in terminal_agent_info_list],

DecisionSteps(
decision_obs_list,
decision_rewards,
decision_team_rewards,
decision_group_rewards,
decision_team_managers,
decision_group_id,
terminal_team_rewards,
terminal_group_rewards,
terminal_team_managers,
terminal_group_id,
),
)

4
protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_info.proto


repeated bool action_mask = 11;
reserved 12; // deprecated CustomObservationProto custom_observation = 12;
repeated ObservationProto observations = 13;
int32 team_manager_id = 14;
float team_reward = 15;
int32 group_id = 14;
float group_reward = 15;
}

2
com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs.meta


fileFormatVersion: 2
guid: 06456db1475d84371b35bae4855db3c6
guid: cb62896b855f44d7f8a7c3fb96f7ab76
MonoImporter:
externalObjects: {}
serializedVersion: 2

13
com.unity.ml-agents/Runtime/MultiAgentGroupIdCounter.cs


using System.Threading;
namespace Unity.MLAgents
{
internal static class MultiAgentGroupIdCounter
{
static int s_Counter;
public static int GetGroupId()
{
return Interlocked.Increment(ref s_Counter); ;
}
}
}

165
com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs


using System;
using System.Collections.Generic;
using UnityEngine;
namespace Unity.MLAgents.Extensions.MultiAgent
{
public class BaseMultiAgentGroup : IMultiAgentGroup, IDisposable
{
int m_StepCount;
int m_GroupMaxStep;
readonly int m_Id = MultiAgentGroupIdCounter.GetGroupId();
List<Agent> m_Agents = new List<Agent> { };
public BaseMultiAgentGroup()
{
Academy.Instance.PostAgentAct += _ManagerStep;
}
public void Dispose()
{
Academy.Instance.PostAgentAct -= _ManagerStep;
while (m_Agents.Count > 0)
{
UnregisterAgent(m_Agents[0]);
}
}
void _ManagerStep()
{
m_StepCount += 1;
if ((m_StepCount >= m_GroupMaxStep) && (m_GroupMaxStep > 0))
{
foreach (var agent in m_Agents)
{
if (agent.enabled)
{
agent.EpisodeInterrupted();
}
}
Reset();
}
}
/// <summary>
/// Register the agent to the MultiAgentGroup.
/// Registered agents will be able to receive group rewards from the MultiAgentGroup
/// and share observations during training.
/// </summary>
public virtual void RegisterAgent(Agent agent)
{
if (!m_Agents.Contains(agent))
{
agent.SetMultiAgentGroup(this);
m_Agents.Add(agent);
agent.UnregisterFromGroup += UnregisterAgent;
}
}
/// <summary>
/// Remove the agent from the MultiAgentGroup.
/// </summary>
public virtual void UnregisterAgent(Agent agent)
{
if (m_Agents.Contains(agent))
{
m_Agents.Remove(agent);
agent.UnregisterFromGroup -= UnregisterAgent;
}
}
/// <summary>
/// Get the ID of the MultiAgentGroup.
/// </summary>
/// <returns>
/// MultiAgentGroup ID.
/// </returns>
public int GetId()
{
return m_Id;
}
/// <summary>
/// Get list of all agents registered to this MultiAgentGroup.
/// </summary>
/// <returns>
/// List of agents belongs to the MultiAgentGroup.
/// </returns>
public List<Agent> GetRegisteredAgents()
{
return m_Agents;
}
/// <summary>
/// Add group reward for all agents under this MultiAgentGroup.
/// Disabled agent will not receive this reward.
/// </summary>
public void AddGroupReward(float reward)
{
foreach (var agent in m_Agents)
{
if (agent.enabled)
{
agent.AddGroupReward(reward);
}
}
}
/// <summary>
/// Set group reward for all agents under this MultiAgentGroup.
/// Disabled agent will not receive this reward.
/// </summary>
public void SetGroupReward(float reward)
{
foreach (var agent in m_Agents)
{
if (agent.enabled)
{
agent.SetGroupReward(reward);
}
}
}
/// <summary>
/// Returns the current step counter (within the current episode).
/// </summary>
/// <returns>
/// Current step count.
/// </returns>
public int StepCount
{
get { return m_StepCount; }
}
public int GroupMaxStep
{
get { return m_GroupMaxStep; }
}
public void SetGroupMaxStep(int maxStep)
{
m_GroupMaxStep = maxStep;
}
/// <summary>
/// End Episode for all agents under this MultiAgentGroup.
/// </summary>
public void EndGroupEpisode()
{
foreach (var agent in m_Agents)
{
if (agent.enabled)
{
agent.EndEpisode();
}
}
Reset();
}
void Reset()
{
m_StepCount = 0;
}
}
}

13
com.unity.ml-agents/Runtime/TeamManagerIdCounter.cs


using System.Threading;
namespace Unity.MLAgents
{
internal static class TeamManagerIdCounter
{
static int s_Counter;
public static int GetTeamManagerId()
{
return Interlocked.Increment(ref s_Counter); ;
}
}
}

/com.unity.ml-agents.extensions/Runtime/Teams.meta → /com.unity.ml-agents.extensions/Runtime/MultiAgent.meta

/com.unity.ml-agents/Runtime/ITeamManager.cs → /com.unity.ml-agents/Runtime/IMultiAgentGroup.cs

/com.unity.ml-agents/Runtime/ITeamManager.cs.meta → /com.unity.ml-agents/Runtime/MultiAgentGroupIdCounter.cs.meta

/com.unity.ml-agents/Runtime/TeamManagerIdCounter.cs.meta → /com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs.meta

/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs.meta → /com.unity.ml-agents/Runtime/IMultiAgentGroup.cs.meta

正在加载...
取消
保存