浏览代码

add base team manager

/develop/superpush/int
Ruo-Ping Dong 3 年前
当前提交
6f0bb2a4
共有 15 个文件被更改,包括 96 次插入112 次删除
  1. 25
      com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs
  2. 32
      com.unity.ml-agents/Runtime/Agent.cs
  3. 1
      com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
  4. 26
      com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs
  5. 24
      ml-agents-envs/mlagents_envs/base_env.py
  6. 6
      ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py
  7. 5
      ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi
  8. 23
      ml-agents-envs/mlagents_envs/rpc_utils.py
  9. 2
      protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_info.proto
  10. 12
      com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs
  11. 11
      com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs.meta
  12. 13
      com.unity.ml-agents/Runtime/TeamManagerIdCounter.cs
  13. 11
      com.unity.ml-agents/Runtime/TeamManagerIdCounter.cs.meta
  14. 14
      com.unity.ml-agents/Runtime/ITeamManager.cs
  15. 3
      com.unity.ml-agents/Runtime/ITeamManager.cs.meta

25
com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs


using System.Collections.Generic;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
private readonly string m_Id = System.Guid.NewGuid().ToString();
readonly int m_Id = TeamManagerIdCounter.GetTeamManagerId();
public virtual void RegisterAgent(Agent agent)
{
}
public virtual void RegisterAgent(Agent agent) { }
public virtual void OnAgentDone(Agent agent, Agent.DoneReason doneReason, List<ISensor> sensors)
{
// Possible implementation - save reference to Agent's IPolicy so that we can repeatedly
// call IPolicy.RequestDecision on behalf of the Agent after it's dead
// If so, we'll need dummy sensor impls with the same shape as the originals.
agent.SendDoneToTrainer();
}
public virtual void AddTeamReward(float reward)
{
}
public string GetId()
public int GetId()
{
return m_Id;
}

32
com.unity.ml-agents/Runtime/Agent.cs


/// <summary>
/// Team Manager identifier.
/// </summary>
public string teamManagerId;
public int teamManagerId;
public void ClearActions()
{

/// </summary>
float[] m_LegacyActionCache;
private ITeamManager m_TeamManager;
ITeamManager m_TeamManager;
/// <summary>
/// Called when the attached [GameObject] becomes enabled and active.

new int[m_ActuatorManager.NumDiscreteActions]
);
if (m_TeamManager != null)
{
m_Info.teamManagerId = m_TeamManager.GetId();
}
m_Info.teamManagerId = m_TeamManager == null ? -1 : m_TeamManager.GetId();
// The first time the Academy resets, all Agents in the scene will be
// forced to reset through the <see cref="AgentForceReset"/> event.

m_Info.reward = m_Reward;
m_Info.done = true;
m_Info.maxStepReached = doneReason == DoneReason.MaxStepReached;
m_Info.teamManagerId = m_TeamManager == null ? -1 : m_TeamManager.GetId();
if (collectObservationsSensor != null)
{
// Make sure the latest observations are being passed to training.

}
}
// Request the last decision with no callbacks
if (m_TeamManager != null)
{
// Send final observations to TeamManager if it exists.
// The TeamManager is responsible to keeping track of the Agent after it's
// done, including propagating any "posthumous" rewards.
m_TeamManager.OnAgentDone(this, doneReason, sensors);
}
else
{
SendDoneToTrainer();
}
// We request a decision so Python knows the Agent is done immediately
m_Brain?.RequestDecision(m_Info, sensors);
ResetSensors();
// We also have to write any to any DemonstationStores so that they get the "done" flag.
foreach (var demoWriter in DemonstrationWriters)

m_Info.storedActions.Clear();
}
public void SendDoneToTrainer()
{
// We request a decision so Python knows the Agent is done immediately
m_Brain?.RequestDecision(m_Info, sensors);
ResetSensors();
}
/// <summary>
/// Updates the Model assigned to this Agent instance.

m_Info.done = false;
m_Info.maxStepReached = false;
m_Info.episodeId = m_EpisodeId;
m_Info.teamManagerId = m_TeamManager == null ? -1 : m_TeamManager.GetId();
using (TimerStack.Instance.Scoped("RequestDecision"))
{

public void SetTeamManager(ITeamManager teamManager)
{
m_TeamManager = teamManager;
m_Info.teamManagerId = teamManager?.GetId();
teamManager?.RegisterAgent(this);
}
}

1
com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs


MaxStepReached = ai.maxStepReached,
Done = ai.done,
Id = ai.episodeId,
TeamManagerId = ai.teamManagerId,
};
if (ai.discreteActionMasks != null)

26
com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs


"ChBtYXhfc3RlcF9yZWFjaGVkGAkgASgIEgoKAmlkGAogASgFEhMKC2FjdGlv",
"bl9tYXNrGAsgAygIEjwKDG9ic2VydmF0aW9ucxgNIAMoCzImLmNvbW11bmlj",
"YXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG8SFwoPdGVhbV9tYW5hZ2Vy",
"X2lkGA4gASgJSgQIARACSgQIAhADSgQIAxAESgQIBBAFSgQIBRAGSgQIBhAH",
"X2lkGA4gASgFSgQIARACSgQIAhADSgQIAxAESgQIBBAFSgQIBRAGSgQIBhAH",
"SgQIDBANQiWqAiJVbml0eS5NTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3Rz",
"YgZwcm90bzM="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,

/// <summary>Field number for the "team_manager_id" field.</summary>
public const int TeamManagerIdFieldNumber = 14;
private string teamManagerId_ = "";
private int teamManagerId_;
public string TeamManagerId {
public int TeamManagerId {
teamManagerId_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
teamManagerId_ = value;
}
}

if (Id != 0) hash ^= Id.GetHashCode();
hash ^= actionMask_.GetHashCode();
hash ^= observations_.GetHashCode();
if (TeamManagerId.Length != 0) hash ^= TeamManagerId.GetHashCode();
if (TeamManagerId != 0) hash ^= TeamManagerId.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}

}
actionMask_.WriteTo(output, _repeated_actionMask_codec);
observations_.WriteTo(output, _repeated_observations_codec);
if (TeamManagerId.Length != 0) {
output.WriteRawTag(114);
output.WriteString(TeamManagerId);
if (TeamManagerId != 0) {
output.WriteRawTag(112);
output.WriteInt32(TeamManagerId);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);

}
size += actionMask_.CalculateSize(_repeated_actionMask_codec);
size += observations_.CalculateSize(_repeated_observations_codec);
if (TeamManagerId.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(TeamManagerId);
if (TeamManagerId != 0) {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(TeamManagerId);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();

}
actionMask_.Add(other.actionMask_);
observations_.Add(other.observations_);
if (other.TeamManagerId.Length != 0) {
if (other.TeamManagerId != 0) {
TeamManagerId = other.TeamManagerId;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);

observations_.AddEntriesFrom(input, _repeated_observations_codec);
break;
}
case 114: {
TeamManagerId = input.ReadString();
case 112: {
TeamManagerId = input.ReadInt32();
break;
}
}

24
ml-agents-envs/mlagents_envs/base_env.py


reward: float
agent_id: AgentId
action_mask: Optional[List[np.ndarray]]
team_manager_id: Optional[str]
team_manager_id: int
class DecisionSteps(Mapping):

this simulation step.
"""
def __init__(self, obs, reward, agent_id, action_mask, team_manager_id=None):
def __init__(self, obs, reward, agent_id, action_mask, team_manager_id):
self.team_manager_id: Optional[List[str]] = team_manager_id
self.team_manager_id: np.ndarray = team_manager_id
self._agent_id_to_index: Optional[Dict[AgentId, int]] = None
@property

agent_mask = []
for mask in self.action_mask:
agent_mask.append(mask[agent_index])
team_manager_id = None
if self.team_manager_id is not None and self.team_manager_id != "":
team_manager_id = self.team_manager_id[agent_index]
team_manager_id = self.team_manager_id[agent_index]
return DecisionStep(
obs=agent_obs,
reward=self.reward[agent_index],

reward=np.zeros(0, dtype=np.float32),
agent_id=np.zeros(0, dtype=np.int32),
action_mask=None,
team_manager_id=None,
team_manager_id=np.zeros(0, dtype=np.int32),
)

reward: float
interrupted: bool
agent_id: AgentId
team_manager_id: Optional[str]
team_manager_id: int
class TerminalSteps(Mapping):

across simulation steps.
"""
def __init__(self, obs, reward, interrupted, agent_id, team_manager_id=None):
def __init__(self, obs, reward, interrupted, agent_id, team_manager_id):
self.team_manager_id: np.ndarray = team_manager_id
self.team_manager_id: Optional[List[str]] = team_manager_id
@property
def agent_id_to_index(self) -> Dict[AgentId, int]:

agent_obs = []
for batched_obs in self.obs:
agent_obs.append(batched_obs[agent_index])
team_manager_id = None
if self.team_manager_id is not None and self.team_manager_id != "":
team_manager_id = self.team_manager_id[agent_index]
team_manager_id = self.team_manager_id[agent_index]
return TerminalStep(
obs=agent_obs,
reward=self.reward[agent_index],

reward=np.zeros(0, dtype=np.float32),
interrupted=np.zeros(0, dtype=np.bool),
agent_id=np.zeros(0, dtype=np.int32),
team_manager_id=None,
team_manager_id=np.zeros(0, dtype=np.int32),
)

6
ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py


name='mlagents_envs/communicator_objects/agent_info.proto',
package='communicator_objects',
syntax='proto3',
serialized_pb=_b('\n3mlagents_envs/communicator_objects/agent_info.proto\x12\x14\x63ommunicator_objects\x1a\x34mlagents_envs/communicator_objects/observation.proto\"\xea\x01\n\x0e\x41gentInfoProto\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12<\n\x0cobservations\x18\r \x03(\x0b\x32&.communicator_objects.ObservationProto\x12\x17\n\x0fteam_manager_id\x18\x0e \x01(\tJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05J\x04\x08\x05\x10\x06J\x04\x08\x06\x10\x07J\x04\x08\x0c\x10\rB%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
serialized_pb=_b('\n3mlagents_envs/communicator_objects/agent_info.proto\x12\x14\x63ommunicator_objects\x1a\x34mlagents_envs/communicator_objects/observation.proto\"\xea\x01\n\x0e\x41gentInfoProto\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12<\n\x0cobservations\x18\r \x03(\x0b\x32&.communicator_objects.ObservationProto\x12\x17\n\x0fteam_manager_id\x18\x0e \x01(\x05J\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05J\x04\x08\x05\x10\x06J\x04\x08\x06\x10\x07J\x04\x08\x0c\x10\rB%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
,
dependencies=[mlagents__envs_dot_communicator__objects_dot_observation__pb2.DESCRIPTOR,])

options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='team_manager_id', full_name='communicator_objects.AgentInfoProto.team_manager_id', index=6,
number=14, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
number=14, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),

5
ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi


from typing import (
Iterable as typing___Iterable,
Optional as typing___Optional,
Text as typing___Text,
)
from typing_extensions import (

max_step_reached = ... # type: builtin___bool
id = ... # type: builtin___int
action_mask = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___bool]
team_manager_id = ... # type: typing___Text
team_manager_id = ... # type: builtin___int
@property
def observations(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[mlagents_envs___communicator_objects___observation_pb2___ObservationProto]: ...

id : typing___Optional[builtin___int] = None,
action_mask : typing___Optional[typing___Iterable[builtin___bool]] = None,
observations : typing___Optional[typing___Iterable[mlagents_envs___communicator_objects___observation_pb2___ObservationProto]] = None,
team_manager_id : typing___Optional[typing___Text] = None,
team_manager_id : typing___Optional[builtin___int] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> AgentInfoProto: ...

23
ml-agents-envs/mlagents_envs/rpc_utils.py


decision_rewards = np.array(
[agent_info.reward for agent_info in decision_agent_info_list], dtype=np.float32
)
decision_team_manager = [
agent_info.team_manager_id
for agent_info in decision_agent_info_list
if agent_info.team_manager_id is not None
]
if len(decision_team_manager) == 0:
decision_team_manager = None
terminal_team_manager = [
agent_info.team_manager_id
for agent_info in terminal_agent_info_list
if agent_info.team_manager_id is not None
decision_team_managers = [
agent_info.team_manager_id for agent_info in decision_agent_info_list
]
terminal_team_managers = [
agent_info.team_manager_id for agent_info in terminal_agent_info_list
if len(terminal_team_manager) == 0:
terminal_team_manager = None
_raise_on_nan_and_inf(decision_rewards, "rewards")
_raise_on_nan_and_inf(terminal_rewards, "rewards")

decision_rewards,
decision_agent_id,
action_mask,
decision_team_manager,
decision_team_managers,
),
TerminalSteps(
terminal_obs_list,

terminal_team_manager,
terminal_team_managers,
),
)

2
protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_info.proto


repeated bool action_mask = 11;
reserved 12; // deprecated CustomObservationProto custom_observation = 12;
repeated ObservationProto observations = 13;
string team_manager_id = 14;
int32 team_manager_id = 14;
}

12
com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs


using System.Collections.Generic;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents
{
public interface ITeamManager
{
int GetId();
void RegisterAgent(Agent agent);
}
}

11
com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs.meta


fileFormatVersion: 2
guid: 8b061f82569af4ffba715297f77a95ab
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

13
com.unity.ml-agents/Runtime/TeamManagerIdCounter.cs


using System.Threading;
namespace Unity.MLAgents
{
internal static class TeamManagerIdCounter
{
static int s_Counter;
public static int GetTeamManagerId()
{
return Interlocked.Increment(ref s_Counter); ;
}
}
}

11
com.unity.ml-agents/Runtime/TeamManagerIdCounter.cs.meta


fileFormatVersion: 2
guid: 06456db1475d84371b35bae4855db3c6
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

14
com.unity.ml-agents/Runtime/ITeamManager.cs


using System.Collections.Generic;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents
{
public interface ITeamManager
{
string GetId();
void RegisterAgent(Agent agent);
// TODO not sure this is all the info we need, maybe pass a class/struct instead.
void OnAgentDone(Agent agent, Agent.DoneReason doneReason, List<ISensor> sensors);
}
}

3
com.unity.ml-agents/Runtime/ITeamManager.cs.meta


fileFormatVersion: 2
guid: 75810d91665e4477977eb78c9b15aeb3
timeCreated: 1610057818
正在加载...
取消
保存