浏览代码

Merge branch 'develop-base-teammanager' into develop-agentprocessor-teammanager

/develop/coma2/samenet
Ervin Teng 3 年前
当前提交
3fbed6dc
共有 12 个文件被更改,包括 260 次插入42 次删除
  1. 155
      com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs
  2. 3
      com.unity.ml-agents/Runtime/Academy.cs
  3. 46
      com.unity.ml-agents/Runtime/Agent.cs
  4. 1
      com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
  5. 38
      com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs
  6. 6
      com.unity.ml-agents/Runtime/ITeamManager.cs
  7. 12
      ml-agents-envs/mlagents_envs/base_env.py
  8. 11
      ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py
  9. 6
      ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi
  10. 1
      protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_info.proto
  11. 12
      com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs
  12. 11
      com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs.meta

155
com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs


using System;
using System.Collections.Generic;
public class BaseTeamManager : ITeamManager
public class BaseTeamManager : ITeamManager, IDisposable
int m_StepCount;
int m_TeamMaxStep;
List<Agent> m_Agents = new List<Agent> { };
public virtual void RegisterAgent(Agent agent) { }
/// <summary>
/// Creates the team manager and subscribes its per-step handler to the
/// Academy's PostAgentAct event.
/// </summary>
public BaseTeamManager()
{
// The matching unsubscription happens in Dispose(); callers must dispose
// this object or the Academy event keeps a reference to it.
Academy.Instance.PostAgentAct += _ManagerStep;
}
/// <summary>
/// Unsubscribes the per-step handler from the Academy's PostAgentAct event.
/// </summary>
public void Dispose()
{
// NOTE(review): this touches Academy.Instance; if that accessor can lazily
// re-create the Academy after shutdown, guard this call — confirm.
Academy.Instance.PostAgentAct -= _ManagerStep;
}
/// <summary>
/// Per-step handler: advances the team step counter and, once the team
/// max-step limit is reached, interrupts the episode of every enabled
/// registered agent and resets the team.
/// </summary>
void _ManagerStep()
{
    m_StepCount += 1;

    // A limit of zero or less means "no limit"; otherwise wait until the
    // counter has caught up with it.
    var limitEnabled = m_TeamMaxStep > 0;
    if (!limitEnabled || m_StepCount < m_TeamMaxStep)
    {
        return;
    }

    foreach (var member in m_Agents)
    {
        if (!member.enabled)
        {
            continue;
        }
        member.EpisodeInterrupted();
    }
    Reset();
}
/// <summary>
/// Adds an agent to this TeamManager's roster.
/// Agents on the roster receive team rewards from the TeamManager
/// and share observations during training.
/// Registering an agent that is already on the roster has no effect.
/// </summary>
/// <param name="agent">Agent to add.</param>
public virtual void RegisterAgent(Agent agent)
{
    var alreadyRegistered = m_Agents.Contains(agent);
    if (alreadyRegistered)
    {
        return;
    }
    m_Agents.Add(agent);
}
/// <summary>
/// Remove the agent from the TeamManager.
/// Does nothing if the agent is not currently registered.
/// </summary>
/// <param name="agent">Agent to remove.</param>
public virtual void UnregisterAgent(Agent agent)
{
    // List<T>.Remove returns false (without throwing) when the item is
    // absent, so a preceding Contains check only adds a second list scan.
    m_Agents.Remove(agent);
}
/// <summary>
/// Get the ID of the TeamManager.
/// </summary>
/// <returns>
/// TeamManager ID.
/// </returns>
}
/// <summary>
/// Gets all agents currently registered to this TeamManager.
/// </summary>
/// <returns>
/// The list of agents that belong to this TeamManager. This is the
/// manager's own list instance, not a copy.
/// </returns>
public List<Agent> GetRegisteredAgents() => m_Agents;
/// <summary>
/// Adds a team reward to every enabled agent registered to this TeamManager.
/// Disabled agents are skipped and do not receive the reward.
/// </summary>
/// <param name="reward">Amount added to each enabled agent's team reward.</param>
public void AddTeamReward(float reward)
{
    foreach (var member in m_Agents)
    {
        if (!member.enabled)
        {
            continue;
        }
        member.AddTeamReward(reward);
    }
}
/// <summary>
/// Sets the team reward of every enabled agent registered to this TeamManager.
/// Disabled agents are skipped and keep their current team reward.
/// </summary>
/// <param name="reward">Value assigned as each enabled agent's team reward.</param>
public void SetTeamReward(float reward)
{
    foreach (var member in m_Agents)
    {
        if (!member.enabled)
        {
            continue;
        }
        member.SetTeamReward(reward);
    }
}
/// <summary>
/// Number of steps counted so far in the current team episode.
/// </summary>
/// <returns>
/// The current value of the team step counter.
/// </returns>
public int StepCount => m_StepCount;
/// <summary>
/// Step limit for a team episode, as last set through SetTeamMaxStep.
/// A value of zero or less means no limit is enforced.
/// </summary>
public int TeamMaxStep => m_TeamMaxStep;
/// <summary>
/// Sets the step limit for a team episode.
/// </summary>
/// <param name="maxStep">
/// Number of steps after which the team episode is interrupted.
/// The per-step handler only enforces the limit when this value is greater than zero.
/// </param>
public void SetTeamMaxStep(int maxStep)
{
m_TeamMaxStep = maxStep;
}
/// <summary>
/// Ends the episode of every enabled agent registered to this TeamManager,
/// then resets the team (step counter cleared, OnTeamEpisodeBegin invoked).
/// </summary>
public void EndTeamEpisode()
{
    foreach (var member in m_Agents)
    {
        if (!member.enabled)
        {
            continue;
        }
        member.EndEpisode();
    }
    Reset();
}
/// <summary>
/// Hook invoked whenever the team is reset, right after the team step
/// counter has been set back to zero (i.e. after EndTeamEpisode or after the
/// team max-step limit is reached). The base implementation does nothing;
/// override it to add team-level logic for the start of a new episode.
/// </summary>
public virtual void OnTeamEpisodeBegin()
{
}
/// <summary>
/// Clears the team step counter and then notifies subclasses through
/// OnTeamEpisodeBegin.
/// </summary>
void Reset()
{
m_StepCount = 0;
OnTeamEpisodeBegin();
}
}
}

3
com.unity.ml-agents/Runtime/Academy.cs


// This will mark the Agent as Done if it has reached its maxSteps.
internal event Action AgentIncrementStep;
internal event Action PostAgentAct;
/// <summary>
/// Signals to all of the <see cref="Agent"/>s that their step is about to begin.

{
AgentAct?.Invoke();
}
PostAgentAct?.Invoke();
}
}

46
com.unity.ml-agents/Runtime/Agent.cs


public float reward;
/// <summary>
/// The current team reward received by the agent.
/// </summary>
public float teamReward;
/// <summary>
/// Whether the agent is done or not.
/// </summary>
public bool done;

/// Additionally, the magnitude of the reward should not exceed 1.0
float m_Reward;
/// Represents the team reward the agent accumulated during the current step.
float m_TeamReward;
/// Keeps track of the cumulative reward in this episode.
float m_CumulativeReward;

new int[m_ActuatorManager.NumDiscreteActions]
);
m_Info.teamManagerId = m_TeamManager == null ? -1 : m_TeamManager.GetId();
m_Info.teamManagerId = m_TeamManager == null ? 0 : m_TeamManager.GetId();
// The first time the Academy resets, all Agents in the scene will be
// forced to reset through the <see cref="AgentForceReset"/> event.

m_Initialized = false;
}
/// <summary>
/// When this agent is destroyed, removes it from its team manager (if it has one).
/// </summary>
void OnDestroy()
{
    // No team manager assigned means there is nothing to unregister from.
    m_TeamManager?.UnregisterAgent(this);
}
void NotifyAgentDone(DoneReason doneReason)
{
if (m_Info.done)

}
m_Info.episodeId = m_EpisodeId;
m_Info.reward = m_Reward;
m_Info.teamReward = m_TeamReward;
m_Info.teamManagerId = m_TeamManager == null ? -1 : m_TeamManager.GetId();
m_Info.teamManagerId = m_TeamManager == null ? 0 : m_TeamManager.GetId();
if (collectObservationsSensor != null)
{
// Make sure the latest observations are being passed to training.

}
m_Reward = 0f;
m_TeamReward = 0f;
m_CumulativeReward = 0f;
m_RequestAction = false;
m_RequestDecision = false;

m_CumulativeReward += increment;
}
/// <summary>
/// Overwrites the team reward stored for the current step with
/// <paramref name="reward"/>.
/// </summary>
/// <param name="reward">New team reward value.</param>
internal void SetTeamReward(float reward)
{
#if DEBUG
// Debug-build-only check of the incoming value via Utilities.DebugCheckNanAndInfinity.
Utilities.DebugCheckNanAndInfinity(reward, nameof(reward), nameof(SetTeamReward));
#endif
m_TeamReward = reward;
}
/// <summary>
/// Adds <paramref name="increment"/> to the team reward stored for the current step.
/// </summary>
/// <param name="increment">Amount to add to the team reward.</param>
internal void AddTeamReward(float increment)
{
#if DEBUG
// Debug-build-only check of the incoming value via Utilities.DebugCheckNanAndInfinity.
Utilities.DebugCheckNanAndInfinity(increment, nameof(increment), nameof(AddTeamReward));
#endif
m_TeamReward += increment;
}
/// <summary>
/// Retrieves the episode reward for the Agent.
/// </summary>

m_Info.discreteActionMasks = m_ActuatorManager.DiscreteActionMask?.GetMask();
m_Info.reward = m_Reward;
m_Info.teamReward = m_TeamReward;
m_Info.teamManagerId = m_TeamManager == null ? -1 : m_TeamManager.GetId();
m_Info.teamManagerId = m_TeamManager == null ? 0 : m_TeamManager.GetId();
using (TimerStack.Instance.Scoped("RequestDecision"))
{

{
SendInfoToBrain();
m_Reward = 0f;
m_TeamReward = 0f;
m_RequestDecision = false;
}
}

/// <summary>
/// Assigns this agent to a team manager, first unregistering it from the
/// manager it previously belonged to (if any).
/// </summary>
/// <param name="teamManager">
/// The new team manager, or null to leave the agent without one.
/// </param>
public void SetTeamManager(ITeamManager teamManager)
{
    // Leave the previous team before joining the new one.
    m_TeamManager?.UnregisterAgent(this);

    m_TeamManager = teamManager;
    if (teamManager != null)
    {
        teamManager.RegisterAgent(this);
    }
}

1
com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs


var agentInfoProto = new AgentInfoProto
{
Reward = ai.reward,
TeamReward = ai.teamReward,
MaxStepReached = ai.maxStepReached,
Done = ai.done,
Id = ai.episodeId,

38
com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs


string.Concat(
"CjNtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2lu",
"Zm8ucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGjRtbGFnZW50c19lbnZz",
"L2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0aW9uLnByb3RvIuoBCg5B",
"L2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0aW9uLnByb3RvIv8BCg5B",
"X2lkGA4gASgFSgQIARACSgQIAhADSgQIAxAESgQIBBAFSgQIBRAGSgQIBhAH",
"SgQIDBANQiWqAiJVbml0eS5NTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3Rz",
"YgZwcm90bzM="));
"X2lkGA4gASgFEhMKC3RlYW1fcmV3YXJkGA8gASgCSgQIARACSgQIAhADSgQI",
"AxAESgQIBBAFSgQIBRAGSgQIBhAHSgQIDBANQiWqAiJVbml0eS5NTEFnZW50",
"cy5Db21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto), global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "Observations", "TeamManagerId" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto), global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "Observations", "TeamManagerId", "TeamReward" }, null, null, null)
}));
}
#endregion

actionMask_ = other.actionMask_.Clone();
observations_ = other.observations_.Clone();
teamManagerId_ = other.teamManagerId_;
teamReward_ = other.teamReward_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

}
}
/// <summary>Field number for the "team_reward" field.</summary>
public const int TeamRewardFieldNumber = 15;
private float teamReward_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public float TeamReward {
get { return teamReward_; }
set {
teamReward_ = value;
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as AgentInfoProto);

if(!actionMask_.Equals(other.actionMask_)) return false;
if(!observations_.Equals(other.observations_)) return false;
if (TeamManagerId != other.TeamManagerId) return false;
if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(TeamReward, other.TeamReward)) return false;
return Equals(_unknownFields, other._unknownFields);
}

hash ^= actionMask_.GetHashCode();
hash ^= observations_.GetHashCode();
if (TeamManagerId != 0) hash ^= TeamManagerId.GetHashCode();
if (TeamReward != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(TeamReward);
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}

output.WriteRawTag(112);
output.WriteInt32(TeamManagerId);
}
if (TeamReward != 0F) {
output.WriteRawTag(125);
output.WriteFloat(TeamReward);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}

if (TeamManagerId != 0) {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(TeamManagerId);
}
if (TeamReward != 0F) {
size += 1 + 4;
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}

observations_.Add(other.observations_);
if (other.TeamManagerId != 0) {
TeamManagerId = other.TeamManagerId;
}
if (other.TeamReward != 0F) {
TeamReward = other.TeamReward;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

}
case 112: {
TeamManagerId = input.ReadInt32();
break;
}
case 125: {
TeamReward = input.ReadFloat();
break;
}
}

6
com.unity.ml-agents/Runtime/ITeamManager.cs


{
public interface ITeamManager
{
string GetId();
int GetId();
// TODO not sure this is all the info we need, maybe pass a class/struct instead.
void OnAgentDone(Agent agent, Agent.DoneReason doneReason, List<ISensor> sensors);
void UnregisterAgent(Agent agent);
}
}

12
ml-agents-envs/mlagents_envs/base_env.py


this simulation step.
"""
def __init__(self, obs, reward, agent_id, action_mask, team_manager_id):
def __init__(
self, obs, reward, team_reward, agent_id, action_mask, team_manager_id
):
self.team_reward: np.ndarray = team_reward
self.agent_id: np.ndarray = agent_id
self.team_manager_id: Optional[List[str]] = team_manager_id
self.action_mask: Optional[List[np.ndarray]] = action_mask

return DecisionSteps(
obs=obs,
reward=np.zeros(0, dtype=np.float32),
team_reward=np.zeros(0, dtype=np.float32),
agent_id=np.zeros(0, dtype=np.int32),
action_mask=None,
team_manager_id=np.zeros(0, dtype=np.int32),

across simulation steps.
"""
def __init__(self, obs, reward, interrupted, agent_id, team_manager_id):
def __init__(
self, obs, reward, team_reward, interrupted, agent_id, team_manager_id
):
self.team_reward: np.ndarray = team_reward
self.interrupted: np.ndarray = interrupted
self.agent_id: np.ndarray = agent_id
self.team_manager_id: np.ndarray = team_manager_id

return TerminalSteps(
obs=obs,
reward=np.zeros(0, dtype=np.float32),
team_reward=np.zeros(0, dtype=np.float32),
interrupted=np.zeros(0, dtype=np.bool),
agent_id=np.zeros(0, dtype=np.int32),
team_manager_id=np.zeros(0, dtype=np.int32),

11
ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py


name='mlagents_envs/communicator_objects/agent_info.proto',
package='communicator_objects',
syntax='proto3',
serialized_pb=_b('\n3mlagents_envs/communicator_objects/agent_info.proto\x12\x14\x63ommunicator_objects\x1a\x34mlagents_envs/communicator_objects/observation.proto\"\xea\x01\n\x0e\x41gentInfoProto\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12<\n\x0cobservations\x18\r \x03(\x0b\x32&.communicator_objects.ObservationProto\x12\x17\n\x0fteam_manager_id\x18\x0e \x01(\x05J\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05J\x04\x08\x05\x10\x06J\x04\x08\x06\x10\x07J\x04\x08\x0c\x10\rB%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
serialized_pb=_b('\n3mlagents_envs/communicator_objects/agent_info.proto\x12\x14\x63ommunicator_objects\x1a\x34mlagents_envs/communicator_objects/observation.proto\"\xff\x01\n\x0e\x41gentInfoProto\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12<\n\x0cobservations\x18\r \x03(\x0b\x32&.communicator_objects.ObservationProto\x12\x17\n\x0fteam_manager_id\x18\x0e \x01(\x05\x12\x13\n\x0bteam_reward\x18\x0f \x01(\x02J\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05J\x04\x08\x05\x10\x06J\x04\x08\x06\x10\x07J\x04\x08\x0c\x10\rB%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
,
dependencies=[mlagents__envs_dot_communicator__objects_dot_observation__pb2.DESCRIPTOR,])

message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='team_reward', full_name='communicator_objects.AgentInfoProto.team_reward', index=7,
number=15, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
],
extensions=[
],

oneofs=[
],
serialized_start=132,
serialized_end=366,
serialized_end=387,
)
_AGENTINFOPROTO.fields_by_name['observations'].message_type = mlagents__envs_dot_communicator__objects_dot_observation__pb2._OBSERVATIONPROTO

6
ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi


id = ... # type: builtin___int
action_mask = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___bool]
team_manager_id = ... # type: builtin___int
team_reward = ... # type: builtin___float
@property
def observations(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[mlagents_envs___communicator_objects___observation_pb2___ObservationProto]: ...

action_mask : typing___Optional[typing___Iterable[builtin___bool]] = None,
observations : typing___Optional[typing___Iterable[mlagents_envs___communicator_objects___observation_pb2___ObservationProto]] = None,
team_manager_id : typing___Optional[builtin___int] = None,
team_reward : typing___Optional[builtin___float] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> AgentInfoProto: ...

def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",u"done",u"id",u"max_step_reached",u"observations",u"reward",u"team_manager_id"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",u"done",u"id",u"max_step_reached",u"observations",u"reward",u"team_manager_id",u"team_reward"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",b"action_mask",u"done",b"done",u"id",b"id",u"max_step_reached",b"max_step_reached",u"observations",b"observations",u"reward",b"reward",u"team_manager_id",b"team_manager_id"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",b"action_mask",u"done",b"done",u"id",b"id",u"max_step_reached",b"max_step_reached",u"observations",b"observations",u"reward",b"reward",u"team_manager_id",b"team_manager_id",u"team_reward",b"team_reward"]) -> None: ...

1
protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_info.proto


reserved 12; // deprecated CustomObservationProto custom_observation = 12;
repeated ObservationProto observations = 13;
int32 team_manager_id = 14;
float team_reward = 15;
}

12
com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs


using System.Collections.Generic;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents
{
/// <summary>
/// Contract for an object that agents can be registered with as a team.
/// </summary>
// NOTE(review): Agent also calls UnregisterAgent on its ITeamManager, and another
// ITeamManager declaration (Runtime/ITeamManager.cs) exists in this change set —
// confirm which declaration is authoritative and that the two do not conflict.
public interface ITeamManager
{
/// <summary>Returns the integer identifier of this team manager.</summary>
int GetId();
/// <summary>Registers <paramref name="agent"/> with this team manager.</summary>
void RegisterAgent(Agent agent);
}
}

11
com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs.meta


fileFormatVersion: 2
guid: 8b061f82569af4ffba715297f77a95ab
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:
正在加载...
取消
保存