浏览代码

add base team manager

/MLA-1734-demo-provider
Ruo-Ping Dong 4 年前
当前提交
aad7d342
共有 16 个文件被更改,包括 178 次插入13 次删除
  1. 17
      com.unity.ml-agents/Runtime/Agent.cs
  2. 1
      com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
  3. 39
      com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs
  4. 14
      ml-agents-envs/mlagents_envs/base_env.py
  5. 11
      ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py
  6. 6
      ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi
  7. 21
      ml-agents-envs/mlagents_envs/rpc_utils.py
  8. 1
      protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_info.proto
  9. 8
      com.unity.ml-agents.extensions/Runtime/Teams.meta
  10. 12
      com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs
  11. 11
      com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs.meta
  12. 13
      com.unity.ml-agents/Runtime/TeamManagerIdCounter.cs
  13. 11
      com.unity.ml-agents/Runtime/TeamManagerIdCounter.cs.meta
  14. 11
      com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs.meta
  15. 15
      com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs

17
com.unity.ml-agents/Runtime/Agent.cs


/// </summary>
public int episodeId;
/// <summary>
/// Team Manager identifier.
/// </summary>
public int teamManagerId;
public void ClearActions()
{
storedActions.Clear();

/// </summary>
float[] m_LegacyHeuristicCache;
ITeamManager m_TeamManager;
/// <summary>
/// Called when the attached [GameObject] becomes enabled and active.
/// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html

new float[m_ActuatorManager.NumContinuousActions],
new int[m_ActuatorManager.NumDiscreteActions]
);
m_Info.teamManagerId = m_TeamManager == null ? -1 : m_TeamManager.GetId();
// The first time the Academy resets, all Agents in the scene will be
// forced to reset through the <see cref="AgentForceReset"/> event.

m_Info.reward = m_Reward;
m_Info.done = true;
m_Info.maxStepReached = doneReason == DoneReason.MaxStepReached;
m_Info.teamManagerId = m_TeamManager == null ? -1 : m_TeamManager.GetId();
if (collectObservationsSensor != null)
{
// Make sure the latest observations are being passed to training.

m_Info.done = false;
m_Info.maxStepReached = false;
m_Info.episodeId = m_EpisodeId;
m_Info.teamManagerId = m_TeamManager == null ? -1 : m_TeamManager.GetId();
using (TimerStack.Instance.Scoped("RequestDecision"))
{

var actions = m_Brain?.DecideAction() ?? new ActionBuffers();
m_Info.CopyActions(actions);
m_ActuatorManager.UpdateActions(actions);
}
public void SetTeamManager(ITeamManager teamManager)
{
m_TeamManager = teamManager;
teamManager?.RegisterAgent(this);
}
}
}

1
com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs


MaxStepReached = ai.maxStepReached,
Done = ai.done,
Id = ai.episodeId,
TeamManagerId = ai.teamManagerId,
};
if (ai.discreteActionMasks != null)

39
com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs


string.Concat(
"CjNtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2lu",
"Zm8ucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGjRtbGFnZW50c19lbnZz",
"L2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0aW9uLnByb3RvItEBCg5B",
"L2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0aW9uLnByb3RvIuoBCg5B",
"YXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG9KBAgBEAJKBAgCEANKBAgD",
"EARKBAgEEAVKBAgFEAZKBAgGEAdKBAgMEA1CJaoCIlVuaXR5Lk1MQWdlbnRz",
"LkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw=="));
"YXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG8SFwoPdGVhbV9tYW5hZ2Vy",
"X2lkGA4gASgFSgQIARACSgQIAhADSgQIAxAESgQIBBAFSgQIBRAGSgQIBhAH",
"SgQIDBANQiWqAiJVbml0eS5NTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3Rz",
"YgZwcm90bzM="));
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto), global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "Observations" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto), global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "Observations", "TeamManagerId" }, null, null, null)
}));
}
#endregion

id_ = other.id_;
actionMask_ = other.actionMask_.Clone();
observations_ = other.observations_.Clone();
teamManagerId_ = other.teamManagerId_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

get { return observations_; }
}
/// <summary>Field number for the "team_manager_id" field.</summary>
public const int TeamManagerIdFieldNumber = 14;
private int teamManagerId_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int TeamManagerId {
get { return teamManagerId_; }
set {
teamManagerId_ = value;
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as AgentInfoProto);

if (Id != other.Id) return false;
if(!actionMask_.Equals(other.actionMask_)) return false;
if(!observations_.Equals(other.observations_)) return false;
if (TeamManagerId != other.TeamManagerId) return false;
return Equals(_unknownFields, other._unknownFields);
}

if (Id != 0) hash ^= Id.GetHashCode();
hash ^= actionMask_.GetHashCode();
hash ^= observations_.GetHashCode();
if (TeamManagerId != 0) hash ^= TeamManagerId.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}

}
actionMask_.WriteTo(output, _repeated_actionMask_codec);
observations_.WriteTo(output, _repeated_observations_codec);
if (TeamManagerId != 0) {
output.WriteRawTag(112);
output.WriteInt32(TeamManagerId);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}

}
size += actionMask_.CalculateSize(_repeated_actionMask_codec);
size += observations_.CalculateSize(_repeated_observations_codec);
if (TeamManagerId != 0) {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(TeamManagerId);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}

}
actionMask_.Add(other.actionMask_);
observations_.Add(other.observations_);
if (other.TeamManagerId != 0) {
TeamManagerId = other.TeamManagerId;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

}
case 106: {
observations_.AddEntriesFrom(input, _repeated_observations_codec);
break;
}
case 112: {
TeamManagerId = input.ReadInt32();
break;
}
}

14
ml-agents-envs/mlagents_envs/base_env.py


reward: float
agent_id: AgentId
action_mask: Optional[List[np.ndarray]]
team_manager_id: int
class DecisionSteps(Mapping):

this simulation step.
"""
def __init__(self, obs, reward, agent_id, action_mask):
def __init__(self, obs, reward, agent_id, action_mask, team_manager_id):
self.team_manager_id: np.ndarray = team_manager_id
self._agent_id_to_index: Optional[Dict[AgentId, int]] = None
@property

agent_mask = []
for mask in self.action_mask:
agent_mask.append(mask[agent_index])
team_manager_id = self.team_manager_id[agent_index]
team_manager_id=team_manager_id,
)
def __iter__(self) -> Iterator[Any]:

reward=np.zeros(0, dtype=np.float32),
agent_id=np.zeros(0, dtype=np.int32),
action_mask=None,
team_manager_id=np.zeros(0, dtype=np.int32),
)

reward: float
interrupted: bool
agent_id: AgentId
team_manager_id: int
class TerminalSteps(Mapping):

across simulation steps.
"""
def __init__(self, obs, reward, interrupted, agent_id):
def __init__(self, obs, reward, interrupted, agent_id, team_manager_id):
self.team_manager_id: np.ndarray = team_manager_id
self._agent_id_to_index: Optional[Dict[AgentId, int]] = None
@property

agent_obs = []
for batched_obs in self.obs:
agent_obs.append(batched_obs[agent_index])
team_manager_id = self.team_manager_id[agent_index]
team_manager_id=team_manager_id,
)
def __iter__(self) -> Iterator[Any]:

reward=np.zeros(0, dtype=np.float32),
interrupted=np.zeros(0, dtype=np.bool),
agent_id=np.zeros(0, dtype=np.int32),
team_manager_id=np.zeros(0, dtype=np.int32),
)

11
ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py


name='mlagents_envs/communicator_objects/agent_info.proto',
package='communicator_objects',
syntax='proto3',
serialized_pb=_b('\n3mlagents_envs/communicator_objects/agent_info.proto\x12\x14\x63ommunicator_objects\x1a\x34mlagents_envs/communicator_objects/observation.proto\"\xd1\x01\n\x0e\x41gentInfoProto\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12<\n\x0cobservations\x18\r \x03(\x0b\x32&.communicator_objects.ObservationProtoJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05J\x04\x08\x05\x10\x06J\x04\x08\x06\x10\x07J\x04\x08\x0c\x10\rB%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
serialized_pb=_b('\n3mlagents_envs/communicator_objects/agent_info.proto\x12\x14\x63ommunicator_objects\x1a\x34mlagents_envs/communicator_objects/observation.proto\"\xea\x01\n\x0e\x41gentInfoProto\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12<\n\x0cobservations\x18\r \x03(\x0b\x32&.communicator_objects.ObservationProto\x12\x17\n\x0fteam_manager_id\x18\x0e \x01(\x05J\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05J\x04\x08\x05\x10\x06J\x04\x08\x06\x10\x07J\x04\x08\x0c\x10\rB%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
,
dependencies=[mlagents__envs_dot_communicator__objects_dot_observation__pb2.DESCRIPTOR,])

message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='team_manager_id', full_name='communicator_objects.AgentInfoProto.team_manager_id', index=6,
number=14, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
],
extensions=[
],

oneofs=[
],
serialized_start=132,
serialized_end=341,
serialized_end=366,
)
_AGENTINFOPROTO.fields_by_name['observations'].message_type = mlagents__envs_dot_communicator__objects_dot_observation__pb2._OBSERVATIONPROTO

6
ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi


max_step_reached = ... # type: builtin___bool
id = ... # type: builtin___int
action_mask = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___bool]
team_manager_id = ... # type: builtin___int
@property
def observations(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[mlagents_envs___communicator_objects___observation_pb2___ObservationProto]: ...

id : typing___Optional[builtin___int] = None,
action_mask : typing___Optional[typing___Iterable[builtin___bool]] = None,
observations : typing___Optional[typing___Iterable[mlagents_envs___communicator_objects___observation_pb2___ObservationProto]] = None,
team_manager_id : typing___Optional[builtin___int] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> AgentInfoProto: ...

def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",u"done",u"id",u"max_step_reached",u"observations",u"reward"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",u"done",u"id",u"max_step_reached",u"observations",u"reward",u"team_manager_id"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",b"action_mask",u"done",b"done",u"id",b"id",u"max_step_reached",b"max_step_reached",u"observations",b"observations",u"reward",b"reward"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",b"action_mask",u"done",b"done",u"id",b"id",u"max_step_reached",b"max_step_reached",u"observations",b"observations",u"reward",b"reward",u"team_manager_id",b"team_manager_id"]) -> None: ...

21
ml-agents-envs/mlagents_envs/rpc_utils.py


[agent_info.reward for agent_info in terminal_agent_info_list], dtype=np.float32
)
decision_team_managers = [
agent_info.team_manager_id for agent_info in decision_agent_info_list
]
terminal_team_managers = [
agent_info.team_manager_id for agent_info in terminal_agent_info_list
]
_raise_on_nan_and_inf(decision_rewards, "rewards")
_raise_on_nan_and_inf(terminal_rewards, "rewards")

action_mask = np.split(action_mask, indices, axis=1)
return (
DecisionSteps(
decision_obs_list, decision_rewards, decision_agent_id, action_mask
decision_obs_list,
decision_rewards,
decision_agent_id,
action_mask,
decision_team_managers,
TerminalSteps(terminal_obs_list, terminal_rewards, max_step, terminal_agent_id),
TerminalSteps(
terminal_obs_list,
terminal_rewards,
max_step,
terminal_agent_id,
terminal_team_managers,
),
)

1
protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_info.proto


repeated bool action_mask = 11;
reserved 12; // deprecated CustomObservationProto custom_observation = 12;
repeated ObservationProto observations = 13;
int32 team_manager_id = 14;
}

8
com.unity.ml-agents.extensions/Runtime/Teams.meta


fileFormatVersion: 2
guid: 77124df6c18c4f669052016b3116147e
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

12
com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs


using System.Collections.Generic;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents
{
public interface ITeamManager
{
int GetId();
void RegisterAgent(Agent agent);
}
}

11
com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs.meta


fileFormatVersion: 2
guid: 8b061f82569af4ffba715297f77a95ab
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

13
com.unity.ml-agents/Runtime/TeamManagerIdCounter.cs


using System.Threading;
namespace Unity.MLAgents
{
internal static class TeamManagerIdCounter
{
static int s_Counter;
public static int GetTeamManagerId()
{
return Interlocked.Increment(ref s_Counter); ;
}
}
}

11
com.unity.ml-agents/Runtime/TeamManagerIdCounter.cs.meta


fileFormatVersion: 2
guid: 06456db1475d84371b35bae4855db3c6
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

11
com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs.meta


fileFormatVersion: 2
guid: b2967f9c3bd4449a98ad309085094769
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

15
com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs


namespace Unity.MLAgents.Extensions.Teams
{
public class BaseTeamManager : ITeamManager
{
readonly int m_Id = TeamManagerIdCounter.GetTeamManagerId();
public virtual void RegisterAgent(Agent agent) { }
public int GetId()
{
return m_Id;
}
}
}
正在加载...
取消
保存