浏览代码

change manager id from string to int

/develop/teammanager
Ruo-Ping Dong 3 年前
当前提交
6d1dcb15
共有 15 个文件被更改,包括 90 次插入73 次删除
  1. 9
      com.unity.ml-agents.extensions/Runtime/Teams.meta
  2. 11
      com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs
  3. 12
      com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs.meta
  4. 12
      com.unity.ml-agents/Runtime/Agent.cs
  5. 6
      com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
  6. 26
      com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs
  7. 2
      com.unity.ml-agents/Runtime/ITeamManager.cs
  8. 23
      ml-agents-envs/mlagents_envs/base_env.py
  9. 6
      ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py
  10. 5
      ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi
  11. 23
      ml-agents-envs/mlagents_envs/rpc_utils.py
  12. 2
      ml-agents/mlagents/trainers/agent_processor.py
  13. 2
      protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_info.proto
  14. 13
      com.unity.ml-agents/Runtime/TeamManagerIdCounter.cs
  15. 11
      com.unity.ml-agents/Runtime/TeamManagerIdCounter.cs.meta

9
com.unity.ml-agents.extensions/Runtime/Teams.meta


fileFormatVersion: 2
fileFormatVersion: 2
timeCreated: 1610064454
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

11
com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs


{
public class BaseTeamManager : ITeamManager
{
private readonly string m_Id = System.Guid.NewGuid().ToString();
public virtual void RegisterAgent(Agent agent)
{
}
readonly int m_Id = TeamManagerIdCounter.GetTeamManagerId();
public virtual void OnAgentDone(Agent agent, Agent.DoneReason doneReason, List<ISensor> sensors)
{

agent.SendDoneToTrainer();
}
public string GetId()
public virtual void RegisterAgent(Agent agent) { }
public int GetId()
{
return m_Id;
}

12
com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs.meta


fileFormatVersion: 2
fileFormatVersion: 2
timeCreated: 1610064493
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

12
com.unity.ml-agents/Runtime/Agent.cs


/// <summary>
/// Team Manager identifier.
/// </summary>
public string teamManagerId;
public int teamManagerId;
public void ClearActions()
{

/// </summary>
float[] m_LegacyActionCache;
private ITeamManager m_TeamManager;
ITeamManager m_TeamManager;
/// <summary>
/// Called when the attached [GameObject] becomes enabled and active.

new int[m_ActuatorManager.NumDiscreteActions]
);
if (m_TeamManager != null)
{
m_Info.teamManagerId = m_TeamManager.GetId();
}
m_Info.teamManagerId = m_TeamManager == null ? -1 : m_TeamManager.GetId();
// The first time the Academy resets, all Agents in the scene will be
// forced to reset through the <see cref="AgentForceReset"/> event.

m_Info.reward = m_Reward;
m_Info.done = true;
m_Info.maxStepReached = doneReason == DoneReason.MaxStepReached;
m_Info.teamManagerId = m_TeamManager == null ? -1 : m_TeamManager.GetId();
if (collectObservationsSensor != null)
{
// Make sure the latest observations are being passed to training.

m_Info.done = false;
m_Info.maxStepReached = false;
m_Info.episodeId = m_EpisodeId;
m_Info.teamManagerId = m_TeamManager == null ? -1 : m_TeamManager.GetId();
using (TimerStack.Instance.Scoped("RequestDecision"))
{

public void SetTeamManager(ITeamManager teamManager)
{
m_TeamManager = teamManager;
m_Info.teamManagerId = teamManager?.GetId();
teamManager?.RegisterAgent(this);
}
}

6
com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs


MaxStepReached = ai.maxStepReached,
Done = ai.done,
Id = ai.episodeId,
TeamManagerId = ai.teamManagerId,
}
if (ai.teamManagerId != null)
{
agentInfoProto.TeamManagerId = ai.teamManagerId;
}
return agentInfoProto;

26
com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs


"ChBtYXhfc3RlcF9yZWFjaGVkGAkgASgIEgoKAmlkGAogASgFEhMKC2FjdGlv",
"bl9tYXNrGAsgAygIEjwKDG9ic2VydmF0aW9ucxgNIAMoCzImLmNvbW11bmlj",
"YXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG8SFwoPdGVhbV9tYW5hZ2Vy",
"X2lkGA4gASgJSgQIARACSgQIAhADSgQIAxAESgQIBBAFSgQIBRAGSgQIBhAH",
"X2lkGA4gASgFSgQIARACSgQIAhADSgQIAxAESgQIBBAFSgQIBRAGSgQIBhAH",
"SgQIDBANQiWqAiJVbml0eS5NTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3Rz",
"YgZwcm90bzM="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,

/// <summary>Field number for the "team_manager_id" field.</summary>
public const int TeamManagerIdFieldNumber = 14;
private string teamManagerId_ = "";
private int teamManagerId_;
public string TeamManagerId {
public int TeamManagerId {
teamManagerId_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
teamManagerId_ = value;
}
}

if (Id != 0) hash ^= Id.GetHashCode();
hash ^= actionMask_.GetHashCode();
hash ^= observations_.GetHashCode();
if (TeamManagerId.Length != 0) hash ^= TeamManagerId.GetHashCode();
if (TeamManagerId != 0) hash ^= TeamManagerId.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}

}
actionMask_.WriteTo(output, _repeated_actionMask_codec);
observations_.WriteTo(output, _repeated_observations_codec);
if (TeamManagerId.Length != 0) {
output.WriteRawTag(114);
output.WriteString(TeamManagerId);
if (TeamManagerId != 0) {
output.WriteRawTag(112);
output.WriteInt32(TeamManagerId);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);

}
size += actionMask_.CalculateSize(_repeated_actionMask_codec);
size += observations_.CalculateSize(_repeated_observations_codec);
if (TeamManagerId.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(TeamManagerId);
if (TeamManagerId != 0) {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(TeamManagerId);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();

}
actionMask_.Add(other.actionMask_);
observations_.Add(other.observations_);
if (other.TeamManagerId.Length != 0) {
if (other.TeamManagerId != 0) {
TeamManagerId = other.TeamManagerId;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);

observations_.AddEntriesFrom(input, _repeated_observations_codec);
break;
}
case 114: {
TeamManagerId = input.ReadString();
case 112: {
TeamManagerId = input.ReadInt32();
break;
}
}

2
com.unity.ml-agents/Runtime/ITeamManager.cs


{
public interface ITeamManager
{
string GetId();
int GetId();
void RegisterAgent(Agent agent);
// TODO not sure this is all the info we need, maybe pass a class/struct instead.

23
ml-agents-envs/mlagents_envs/base_env.py


reward: float
agent_id: AgentId
action_mask: Optional[List[np.ndarray]]
team_manager_id: Optional[str]
team_manager_id: int
class DecisionSteps(Mapping):

this simulation step.
"""
def __init__(self, obs, reward, agent_id, action_mask, team_manager_id=None):
def __init__(self, obs, reward, agent_id, action_mask, team_manager_id):
self.team_manager_id: Optional[List[str]] = team_manager_id
self.team_manager_id: np.ndarray = team_manager_id
self._agent_id_to_index: Optional[Dict[AgentId, int]] = None
@property

agent_mask = []
for mask in self.action_mask:
agent_mask.append(mask[agent_index])
team_manager_id = None
if self.team_manager_id is not None and self.team_manager_id != "":
team_manager_id = self.team_manager_id[agent_index]
team_manager_id = self.team_manager_id[agent_index]
return DecisionStep(
obs=agent_obs,
reward=self.reward[agent_index],

reward=np.zeros(0, dtype=np.float32),
agent_id=np.zeros(0, dtype=np.int32),
action_mask=None,
team_manager_id=None,
team_manager_id=np.zeros(0, dtype=np.int32),
)

reward: float
interrupted: bool
agent_id: AgentId
team_manager_id: Optional[str]
team_manager_id: int
class TerminalSteps(Mapping):

across simulation steps.
"""
def __init__(self, obs, reward, interrupted, agent_id, team_manager_id=None):
def __init__(self, obs, reward, interrupted, agent_id, team_manager_id):
self.team_manager_id: np.ndarray = team_manager_id
self._agent_id_to_index: Optional[Dict[AgentId, int]] = None
self.team_manager_id: Optional[List[str]] = team_manager_id

agent_obs = []
for batched_obs in self.obs:
agent_obs.append(batched_obs[agent_index])
team_manager_id = None
if self.team_manager_id is not None and self.team_manager_id != "":
team_manager_id = self.team_manager_id[agent_index]
team_manager_id = self.team_manager_id[agent_index]
return TerminalStep(
obs=agent_obs,
reward=self.reward[agent_index],

reward=np.zeros(0, dtype=np.float32),
interrupted=np.zeros(0, dtype=np.bool),
agent_id=np.zeros(0, dtype=np.int32),
team_manager_id=None,
team_manager_id=np.zeros(0, dtype=np.int32),
)

6
ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py


name='mlagents_envs/communicator_objects/agent_info.proto',
package='communicator_objects',
syntax='proto3',
serialized_pb=_b('\n3mlagents_envs/communicator_objects/agent_info.proto\x12\x14\x63ommunicator_objects\x1a\x34mlagents_envs/communicator_objects/observation.proto\"\xea\x01\n\x0e\x41gentInfoProto\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12<\n\x0cobservations\x18\r \x03(\x0b\x32&.communicator_objects.ObservationProto\x12\x17\n\x0fteam_manager_id\x18\x0e \x01(\tJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05J\x04\x08\x05\x10\x06J\x04\x08\x06\x10\x07J\x04\x08\x0c\x10\rB%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
serialized_pb=_b('\n3mlagents_envs/communicator_objects/agent_info.proto\x12\x14\x63ommunicator_objects\x1a\x34mlagents_envs/communicator_objects/observation.proto\"\xea\x01\n\x0e\x41gentInfoProto\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12<\n\x0cobservations\x18\r \x03(\x0b\x32&.communicator_objects.ObservationProto\x12\x17\n\x0fteam_manager_id\x18\x0e \x01(\x05J\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05J\x04\x08\x05\x10\x06J\x04\x08\x06\x10\x07J\x04\x08\x0c\x10\rB%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
,
dependencies=[mlagents__envs_dot_communicator__objects_dot_observation__pb2.DESCRIPTOR,])

options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='team_manager_id', full_name='communicator_objects.AgentInfoProto.team_manager_id', index=6,
number=14, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
number=14, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),

5
ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi


from typing import (
Iterable as typing___Iterable,
Optional as typing___Optional,
Text as typing___Text,
)
from typing_extensions import (

max_step_reached = ... # type: builtin___bool
id = ... # type: builtin___int
action_mask = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___bool]
team_manager_id = ... # type: typing___Text
team_manager_id = ... # type: builtin___int
@property
def observations(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[mlagents_envs___communicator_objects___observation_pb2___ObservationProto]: ...

id : typing___Optional[builtin___int] = None,
action_mask : typing___Optional[typing___Iterable[builtin___bool]] = None,
observations : typing___Optional[typing___Iterable[mlagents_envs___communicator_objects___observation_pb2___ObservationProto]] = None,
team_manager_id : typing___Optional[typing___Text] = None,
team_manager_id : typing___Optional[builtin___int] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> AgentInfoProto: ...

23
ml-agents-envs/mlagents_envs/rpc_utils.py


decision_rewards = np.array(
[agent_info.reward for agent_info in decision_agent_info_list], dtype=np.float32
)
decision_team_manager = [
agent_info.team_manager_id
for agent_info in decision_agent_info_list
if agent_info.team_manager_id is not None
]
if len(decision_team_manager) == 0:
decision_team_manager = None
terminal_team_manager = [
agent_info.team_manager_id
for agent_info in terminal_agent_info_list
if agent_info.team_manager_id is not None
decision_team_managers = [
agent_info.team_manager_id for agent_info in decision_agent_info_list
]
terminal_team_managers = [
agent_info.team_manager_id for agent_info in terminal_agent_info_list
if len(terminal_team_manager) == 0:
terminal_team_manager = None
_raise_on_nan_and_inf(decision_rewards, "rewards")
_raise_on_nan_and_inf(terminal_rewards, "rewards")

decision_rewards,
decision_agent_id,
action_mask,
decision_team_manager,
decision_team_managers,
),
TerminalSteps(
terminal_obs_list,

terminal_team_manager,
terminal_team_managers,
),
)

2
ml-agents/mlagents/trainers/agent_processor.py


) -> None:
stored_decision_step, idx = self.last_step_result.get(global_id, (None, None))
if stored_decision_step is not None:
if step.team_manager_id is not None:
if step.team_manager_id >= 0:
self.last_group_obs[step.team_manager_id][
global_id
] = stored_decision_step.obs

2
protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_info.proto


repeated bool action_mask = 11;
reserved 12; // deprecated CustomObservationProto custom_observation = 12;
repeated ObservationProto observations = 13;
string team_manager_id = 14;
int32 team_manager_id = 14;
}

13
com.unity.ml-agents/Runtime/TeamManagerIdCounter.cs


using System.Threading;
namespace Unity.MLAgents
{
internal static class TeamManagerIdCounter
{
static int s_Counter;
public static int GetTeamManagerId()
{
return Interlocked.Increment(ref s_Counter); ;
}
}
}

11
com.unity.ml-agents/Runtime/TeamManagerIdCounter.cs.meta


fileFormatVersion: 2
guid: 06456db1475d84371b35bae4855db3c6
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:
正在加载...
取消
保存