
Refactor ICommunicator API (#2675)

- Push (almost) all references to protobuf objects into the RpcCommunicator.
- Simplify the passing around of Agents and Agent Infos.
- Delete all references to the Batcher.
- Simplify the Environment Step by removing all of the reset and message counting logic.
- Finishes MLA-27 and MLA-28
/develop-gpu-test
GitHub · 5 years ago
Current commit: 2d92a49b
28 files changed, with 755 insertions and 865 deletions
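For orientation, here is the shape of the Brain-to-communicator hand-off before and after this change. This is a condensed sketch stitched together from the hunks below; the variable names are illustrative and the two fragments are not compilable on their own.

// Before #2675: Agents pushed a protobuf-backed AgentInfo into the Brain,
// which forwarded batches through a Batcher.
brain.SetBatcher(batcher);                   // Academy wiring
brain.SendState(agent, agentInfo);           // Agent -> Brain
batcher.SendBrainInfo(name, agentInfos);     // Brain -> Batcher -> gRPC

// After #2675: Agents subscribe for a decision, the Brain hands plain Agent
// references to the communicator, and all protobuf conversion happens
// inside RpcCommunicator.
brain.SetCommunicator(communicator);         // Academy wiring
brain.SubscribeAgentForDecision(agent);      // Agent -> Brain
communicator.PutObservations(name, agents);  // Brain -> ICommunicator
var actions = communicator.GetActions(name); // actions keyed by Agent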
  1. UnitySDK/Assets/ML-Agents/Editor/DemonstrationImporter.cs (2 changes)
  2. UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorApplier.cs (15 changes)
  3. UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorGenerator.cs (26 changes)
  4. UnitySDK/Assets/ML-Agents/Editor/Tests/MLAgentsEditModeTest.cs (33 changes)
  5. UnitySDK/Assets/ML-Agents/Editor/Tests/TimerTest.cs (1 change)
  6. UnitySDK/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs (2 changes)
  7. UnitySDK/Assets/ML-Agents/Scripts/Academy.cs (202 changes)
  8. UnitySDK/Assets/ML-Agents/Scripts/Agent.cs (46 changes)
  9. UnitySDK/Assets/ML-Agents/Scripts/Brain.cs (24 changes)
  10. UnitySDK/Assets/ML-Agents/Scripts/BrainParameters.cs (41 changes)
  11. UnitySDK/Assets/ML-Agents/Scripts/Demonstration.cs (1 change)
  12. UnitySDK/Assets/ML-Agents/Scripts/Grpc/GrpcExtensions.cs (100 changes)
  13. UnitySDK/Assets/ML-Agents/Scripts/Grpc/RpcCommunicator.cs (305 changes)
  14. UnitySDK/Assets/ML-Agents/Scripts/HeuristicBrain.cs (27 changes)
  15. UnitySDK/Assets/ML-Agents/Scripts/ICommunicator.cs (123 changes)
  16. UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ApplierImpl.cs (23 changes)
  17. UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/GeneratorImpl.cs (49 changes)
  18. UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorApplier.cs (14 changes)
  19. UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorGenerator.cs (10 changes)
  20. UnitySDK/Assets/ML-Agents/Scripts/LearningBrain.cs (26 changes)
  21. UnitySDK/Assets/ML-Agents/Scripts/PlayerBrain.cs (5 changes)
  22. UnitySDK/Assets/ML-Agents/Scripts/ResetParameters.cs (25 changes)
  23. UnitySDK/Assets/ML-Agents/Scripts/Timer.cs (21 changes)
  24. UnitySDK/UnitySDK.sln.DotSettings (1 change)
  25. UnitySDK/Assets/ML-Agents/Scripts/Batcher.cs.meta (13 changes)
  26. UnitySDK/Assets/ML-Agents/Scripts/SocketCommunicator.cs (181 changes)
  27. UnitySDK/Assets/ML-Agents/Scripts/SocketCommunicator.cs.meta (13 changes)
  28. UnitySDK/Assets/ML-Agents/Scripts/Batcher.cs (291 changes)

UnitySDK/Assets/ML-Agents/Editor/DemonstrationImporter.cs (2 changes)


reader.Seek(DemonstrationStore.MetaDataBytes + 1, 0);
var brainParamsProto = BrainParametersProto.Parser.ParseDelimitedFrom(reader);
var brainParameters = new BrainParameters(brainParamsProto);
var brainParameters = brainParamsProto.ToBrainParameters();
reader.Close();

UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorApplier.cs (15 changes)


using System.Collections.Generic;
using System.Linq;
using NUnit.Framework;
using UnityEngine;
using System.Reflection;

}
}
private Dictionary<Agent, AgentInfo> GetFakeAgentInfos()
private List<Agent> GetFakeAgentInfos()
var infoA = new AgentInfo();
var infoB = new AgentInfo();
return new Dictionary<Agent, AgentInfo>(){{agentA, infoA}, {agentB, infoB}};
return new List<Agent> {agentA, agentB};
}
[Test]

var applier = new ContinuousActionOutputApplier();
applier.Apply(inputTensor, agentInfos);
var agents = agentInfos.Keys.ToList();
var agents = agentInfos;
var agent = agents[0] as TestAgent;
Assert.NotNull(agent);

var alloc = new TensorCachingAllocator();
var applier = new DiscreteActionOutputApplier(new[] {2, 3}, 0, alloc);
applier.Apply(inputTensor, agentInfos);
var agents = agentInfos.Keys.ToList();
var agents = agentInfos;
var agent = agents[0] as TestAgent;
Assert.NotNull(agent);

var applier = new MemoryOutputApplier();
applier.Apply(inputTensor, agentInfos);
var agents = agentInfos.Keys.ToList();
var agents = agentInfos;
var agent = agents[0] as TestAgent;
Assert.NotNull(agent);

var applier = new ValueEstimateApplier();
applier.Apply(inputTensor, agentInfos);
var agents = agentInfos.Keys.ToList();
var agents = agentInfos;
var agent = agents[0] as TestAgent;
Assert.NotNull(agent);

UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorGenerator.cs (26 changes)


{
}
private Dictionary<Agent, AgentInfo> GetFakeAgentInfos()
private static IEnumerable<Agent> GetFakeAgentInfos()
var infoA = new AgentInfo()
var infoA = new AgentInfo
stackedVectorObservation = (new[] {1f, 2f, 3f}).ToList(),
stackedVectorObservation = new[] {1f, 2f, 3f}.ToList(),
actionMasks = null,
actionMasks = null
var infoB = new AgentInfo()
var infoB = new AgentInfo
stackedVectorObservation = (new[] {4f, 5f, 6f}).ToList(),
memories = (new[] {1f, 1f, 1f}).ToList(),
stackedVectorObservation = new[] {4f, 5f, 6f}.ToList(),
memories = new[] {1f, 1f, 1f}.ToList(),
agentA.Info = infoA;
agentB.Info = infoB;
return new Dictionary<Agent, AgentInfo>(){{agentA, infoA}, {agentB, infoB}};
return new List<Agent> {agentA, agentB};
}
[Test]

[Test]
public void GenerateVectorObservation()
{
var inputTensor = new TensorProxy()
var inputTensor = new TensorProxy
{
shape = new long[] {2, 3}
};

[Test]
public void GenerateRecurrentInput()
{
var inputTensor = new TensorProxy()
var inputTensor = new TensorProxy
{
shape = new long[] {2, 5}
};

[Test]
public void GeneratePreviousActionInput()
{
var inputTensor = new TensorProxy()
var inputTensor = new TensorProxy
{
shape = new long[] {2, 2},
valueType = TensorProxy.TensorType.Integer

[Test]
public void GenerateActionMaskInput()
{
var inputTensor = new TensorProxy()
var inputTensor = new TensorProxy
{
shape = new long[] {2, 5},
valueType = TensorProxy.TensorType.FloatingPoint

UnitySDK/Assets/ML-Agents/Editor/Tests/MLAgentsEditModeTest.cs (33 changes)


public override void AcademyReset()
{
}
public override void AcademyStep()

protected override void DecideAction()
{
numberOfCallsToDecideAction++;
m_AgentInfos.Clear();
m_Agents.Clear();
}
}

//This will call the method even though it is private
var academyInitializeMethod = typeof(Academy).GetMethod("InitializeEnvironment",
BindingFlags.Instance | BindingFlags.NonPublic);
academyInitializeMethod?.Invoke(aca, new object[] { });
academyInitializeMethod?.Invoke(aca, new object[] {});
Assert.AreEqual(1, aca.initializeAcademyCalls);
Assert.AreEqual(0, aca.GetEpisodeCount());
Assert.AreEqual(0, aca.GetStepCount());

agentEnableMethod?.Invoke(agent2, new object[] { aca });
academyInitializeMethod?.Invoke(aca, new object[] { });
academyInitializeMethod?.Invoke(aca, new object[] {});
agentEnableMethod?.Invoke(agent1, new object[] { aca });
Assert.AreEqual(false, agent1.IsDone());

var aca = acaGo.GetComponent<TestAcademy>();
var academyInitializeMethod = typeof(Academy).GetMethod("InitializeEnvironment",
BindingFlags.Instance | BindingFlags.NonPublic);
academyInitializeMethod?.Invoke(aca, new object[] { });
academyInitializeMethod?.Invoke(aca, new object[] {});
var academyStepMethod = typeof(Academy).GetMethod("EnvironmentStep",
BindingFlags.Instance | BindingFlags.NonPublic);

{
numberReset += 1;
}
academyStepMethod?.Invoke(aca, new object[] { });
academyStepMethod?.Invoke(aca, new object[] {});
}
}

agent2.GiveBrain(brain);
agentEnableMethod?.Invoke(agent1, new object[] { aca });
academyInitializeMethod?.Invoke(aca, new object[] { });
academyInitializeMethod?.Invoke(aca, new object[] {});
var academyStepMethod = typeof(Academy).GetMethod(
"EnvironmentStep", BindingFlags.Instance | BindingFlags.NonPublic);

requestAction += 1;
agent2.RequestAction();
}
academyStepMethod?.Invoke(aca, new object[] { });
academyStepMethod?.Invoke(aca, new object[] {});
}
}
}

var aca = acaGo.GetComponent<TestAcademy>();
var academyInitializeMethod = typeof(Academy).GetMethod(
"InitializeEnvironment", BindingFlags.Instance | BindingFlags.NonPublic);
academyInitializeMethod?.Invoke(aca, new object[] { });
academyInitializeMethod?.Invoke(aca, new object[] {});
var academyStepMethod = typeof(Academy).GetMethod(
"EnvironmentStep", BindingFlags.Instance | BindingFlags.NonPublic);

}
stepsSinceReset += 1;
academyStepMethod.Invoke((object)aca, new object[] { });
academyStepMethod.Invoke(aca, new object[] {});
}
}

agent2.GiveBrain(brain);
agentEnableMethod?.Invoke(agent2, new object[] { aca });
academyInitializeMethod?.Invoke(aca, new object[] { });
academyInitializeMethod?.Invoke(aca, new object[] {});
var numberAgent1Reset = 0;
var numberAgent2Reset = 0;

agent2StepSinceReset += 1;
//Agent 1 is only initialized at step 2
if (i < 2)
{ }
academyStepMethod?.Invoke(aca, new object[] { });
{}
academyStepMethod?.Invoke(aca, new object[] {});
}
}
}

agent2.GiveBrain(brain);
agentEnableMethod?.Invoke(agent2, new object[] { aca });
academyInitializeMethod?.Invoke(aca, new object[] { });
academyInitializeMethod?.Invoke(aca, new object[] {});
agentEnableMethod?.Invoke(agent1, new object[] { aca });
var agent1ResetOnDone = 0;

}
academyStepMethod?.Invoke(aca, new object[] { });
academyStepMethod?.Invoke(aca, new object[] {});
}
}

agent2.GiveBrain(brain);
agentEnableMethod?.Invoke(agent2, new object[] { aca });
academyInitializeMethod?.Invoke(aca, new object[] { });
academyInitializeMethod?.Invoke(aca, new object[] {});
agentEnableMethod?.Invoke(agent1, new object[] { aca });

Assert.LessOrEqual(Mathf.Abs(i * 0.1f - agent2.GetCumulativeReward()), 0.05f);
academyStepMethod?.Invoke(aca, new object[] { });
academyStepMethod?.Invoke(aca, new object[] {});
agent1.AddReward(10f);
if ((i % 21 == 0) && (i > 0))

UnitySDK/Assets/ML-Agents/Editor/Tests/TimerTest.cs (1 change)


{
public class TimerTests
{
[Test]
public void TestNested()
{

UnitySDK/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs (2 changes)


public Transform ground;
public bool detectTargets;
public bool targetIsStatic = false;
public bool targetIsStatic;
public bool respawnTargetWhenTouched;
public float targetSpawnRadius;

UnitySDK/Assets/ML-Agents/Scripts/Academy.cs (202 changes)


/// The mode is determined by the presence or absence of a Communicator. In
/// the presence of a communicator, the academy is run in training mode where
/// the states and observations of each agent are sent through the
/// communicator. In the absence of a communciator, the academy is run in
/// communicator. In the absence of a communicator, the academy is run in
/// inference mode where the agent behavior is determined by the brain
/// attached to it (which may be internal, heuristic or player).
/// </remarks>

private Vector3 m_OriginalGravity;
/// Temporary storage for global fixedDeltaTime value
/// Used to restore oringal value when deriving Academy modifies it
/// Used to restore original value when deriving Academy modifies it
/// Used to restore oringal value when deriving Academy modifies it
/// Used to restore original value when deriving Academy modifies it
private float m_OriginalMaximumDeltaTime;
// Fields provided in the Inspector

/// </summary>
/// <remarks>
/// Default reset parameters are specified in the academy Editor, and can
/// be modified when training with an external Brain by passinga config
/// be modified when training with an external Brain by passing a config
/// dictionary at reset.
/// </remarks>
[SerializeField]

// Fields not provided in the Inspector.
/// Boolean flag indicating whether a communicator is accessible by the
/// environment. This also specifies whether the environment is in
/// Training or Inference mode.
bool m_IsCommunicatorOn;
/// Keeps track of the id of the last communicator message received.
/// Remains 0 if there are no communicators. Is used to ensure that
/// the same message is not used multiple times.
private ulong m_LastCommunicatorMessageNumber;
/// <summary>
/// Returns whether or not the communicator is on.
/// </summary>
/// <returns>
/// <c>true</c>, if communicator is on, <c>false</c> otherwise.
/// </returns>
bool IsCommunicatorOn
{
get { return m_Communicator != null; }
}
/// If true, the Academy will use inference settings. This field is
/// initialized in <see cref="Awake"/> depending on the presence

/// each time the environment is reset.
int m_EpisodeCount;
/// The number of steps completed within the current episide. Incremented
/// The number of steps completed within the current episode. Incremented
/// each time a step is taken in the environment. Is reset to 0 during
/// <see cref="AcademyReset"/>.
int m_StepCount;

/// engine settings at the next environment step.
bool m_ModeSwitched;
/// Pointer to the batcher currently in use by the Academy.
Batcher m_BrainBatcher;
/// Pointer to the communicator currently in use by the Academy.
ICommunicator m_Communicator;
// Flag used to keep track of the first time the Academy is reset.
bool m_FirstAcademyReset;

// they have requested a decision.
public event System.Action AgentAct;
// Sigals to all the agents each time the Academy force resets.
// Signals to all the agents each time the Academy force resets.
/// Monobehavior function called at the very beginning of environment
/// MonoBehavior function called at the very beginning of environment
/// creation. Academy uses this time to initialize internal data
/// structures, initialize the environment and check for the existence
/// of a communicator.

}
// Used to read Python-provided environment parameters
private int ReadArgs()
private static int ReadArgs()
{
var args = System.Environment.GetCommandLineArgs();
var inputPort = "";

m_OriginalMaximumDeltaTime = Time.maximumDeltaTime;
InitializeAcademy();
ICommunicator communicator;
var controlledBrains = broadcastHub.brainsToControl.Where(x => x != null).ToList();

communicator = new RpcCommunicator(
new CommunicatorParameters
m_Communicator = new RpcCommunicator(
new CommunicatorInitParameters
// and if Unity is in Editor mode
// If there are not, there is no need for a communicator and it is set
// to null
communicator = null;
if (controlledBrains.ToList().Count > 0)
m_Communicator = null;
if (controlledBrains.Any())
communicator = new RpcCommunicator(
new CommunicatorParameters
m_Communicator = new RpcCommunicator(
new CommunicatorInitParameters
{
port = 5005
});

m_BrainBatcher = new Batcher(communicator);
if (communicator != null)
foreach (var trainingBrain in controlledBrains)
trainingBrain.SetCommunicator(m_Communicator);
}
foreach (var trainingBrain in controlledBrains)
{
trainingBrain.SetBatcher(m_BrainBatcher);
}
m_IsCommunicatorOn = true;
if (m_Communicator != null)
{
m_Communicator.QuitCommandReceived += OnQuitCommandReceived;
m_Communicator.ResetCommandReceived += OnResetCommand;
m_Communicator.RLInputReceived += OnRLInputReceived;
var academyParameters =
new CommunicatorObjects.UnityRLInitializationOutputProto();
academyParameters.Name = gameObject.name;
academyParameters.Version = k_ApiVersion;
foreach (var brain in controlledBrains)
{
var bp = brain.brainParameters;
academyParameters.BrainParameters.Add(
bp.ToProto(brain.name, true));
}
academyParameters.EnvironmentParameters =
new CommunicatorObjects.EnvironmentParametersProto();
foreach (var key in resetParameters.Keys)
{
academyParameters.EnvironmentParameters.FloatParameters.Add(
key, resetParameters[key]
);
}
var pythonParameters = m_BrainBatcher.SendAcademyParameters(academyParameters);
Random.InitState(pythonParameters.Seed);
var unityRLInitParameters = m_Communicator.Initialize(
new CommunicatorInitParameters
{
version = k_ApiVersion,
name = gameObject.name,
brains = controlledBrains,
environmentResetParameters = new EnvironmentResetParameters
{
resetParameters = resetParameters,
customResetParameters = customResetParameters
}
}, broadcastHub);
Random.InitState(unityRLInitParameters.seed);
communicator = null;
m_BrainBatcher = new Batcher(null);
m_IsCommunicatorOn = false;
foreach (var trainingBrain in controlledBrains)
m_Communicator = null;
foreach (var brain in controlledBrains)
trainingBrain.SetBatcher(null);
brain.SetCommunicator(null);
}
}
}

// in inference mode.
m_IsInference = !m_IsCommunicatorOn;
SetIsInference(!IsCommunicatorOn);
BrainDecideAction += () => { };
DestroyAction += () => { };
AgentSetStatus += (i) => { };
AgentResetIfDone += () => { };
AgentSendState += () => { };
AgentAct += () => { };
AgentForceReset += () => { };
BrainDecideAction += () => {};
DestroyAction += () => {};
AgentSetStatus += i => {};
AgentResetIfDone += () => {};
AgentSendState += () => {};
AgentAct += () => {};
AgentForceReset += () => {};
// Configure the environment using the configurations provided by
// the developer in the Editor.
SetIsInference(!m_BrainBatcher.GetIsTraining());
private void UpdateResetParameters()
static void OnQuitCommandReceived()
{
#if UNITY_EDITOR
EditorApplication.isPlaying = false;
#endif
Application.Quit();
}
private void OnResetCommand(EnvironmentResetParameters newResetParameters)
{
UpdateResetParameters(newResetParameters);
ForcedFullReset();
}
void OnRLInputReceived(UnityRLInputParameters inputParams)
{
m_IsInference = !inputParams.isTraining;
}
private void UpdateResetParameters(EnvironmentResetParameters newResetParameters)
var newResetParameters = m_BrainBatcher?.GetEnvironmentParameters();
if (newResetParameters != null)
if (newResetParameters.resetParameters != null)
foreach (var kv in newResetParameters.FloatParameters)
foreach (var kv in newResetParameters.resetParameters)
customResetParameters = newResetParameters.CustomResetParameters;
customResetParameters = newResetParameters.customResetParameters;
}
/// <summary>

// This signals to the academy that at the next environment step
// the engine configurations need updating to the respective mode
// (i.e. training vs inference) configuraiton.
// (i.e. training vs inference) configuration.
m_ModeSwitched = true;
}
}

}
/// <summary>
/// Returns whether or not the communicator is on.
/// </summary>
/// <returns>
/// <c>true</c>, if communicator is on, <c>false</c> otherwise.
/// </returns>
public bool IsCommunicatorOn()
{
return m_IsCommunicatorOn;
}
/// <summary>
/// Forces a full reset. The done flags are not affected. Called on the
/// first reset in inference mode and on every external reset in training
/// mode.

m_ModeSwitched = false;
}
if ((m_IsCommunicatorOn) &&
(m_LastCommunicatorMessageNumber != m_BrainBatcher.GetNumberMessageReceived()))
{
m_LastCommunicatorMessageNumber = m_BrainBatcher.GetNumberMessageReceived();
if (m_BrainBatcher.GetCommand() ==
CommunicatorObjects.CommandProto.Reset)
{
UpdateResetParameters();
SetIsInference(!m_BrainBatcher.GetIsTraining());
ForcedFullReset();
}
if (m_BrainBatcher.GetCommand() ==
CommunicatorObjects.CommandProto.Quit)
{
#if UNITY_EDITOR
EditorApplication.isPlaying = false;
#endif
Application.Quit();
return;
}
}
else if (!m_FirstAcademyReset)
if (!m_FirstAcademyReset)
UpdateResetParameters();
ForcedFullReset();
}

}
/// <summary>
/// Monobehavior function that dictates each environment step.
/// MonoBehaviour function that dictates each environment step.
/// </summary>
void FixedUpdate()
{

UnitySDK/Assets/ML-Agents/Scripts/Agent.cs (46 changes)


using System.Collections.Generic;
using MLAgents.CommunicatorObjects;
using UnityEngine;

/// <summary>
/// User-customizable object for sending structured output from Unity to Python in response
/// to an action in addition to a scalar reward.
/// TODO(cgoy): All references to protobuf objects should be removed.
public CustomObservationProto customObservation;
public CommunicatorObjects.CustomObservationProto customObservation;
/// <summary>
/// Remove the visual observations from memory. Call at each timestep

public string textActions;
public List<float> memories;
public float value;
public CustomActionProto customAction;
/// TODO(cgoy): All references to protobuf objects should be removed.
public CommunicatorObjects.CustomActionProto customAction;
}
/// <summary>

/// Current Agent information (message sent to Brain).
AgentInfo m_Info;
public AgentInfo Info
{
get { return m_Info; }
set { m_Info = value; }
}
/// Current Agent action (message sent from Brain).
AgentAction m_Action;

m_Info.maxStepReached = m_MaxStepReached;
m_Info.id = m_Id;
brain.SendState(this, m_Info);
brain.SubscribeAgentForDecision(this);
if (m_Recorder != null && m_Recorder.record && Application.isEditor)
{

m_Info.textObservation = "";
}
public void ClearVisualObservations()
{
m_Info.ClearVisualObs();
}
/// <summary>

/// A custom action, defined by the user as custom protobuf message. Useful if the action is hard to encode
/// as either a flat vector or a single string.
/// </param>
public virtual void AgentAction(float[] vectorAction, string textAction, CustomActionProto customAction)
public virtual void AgentAction(float[] vectorAction, string textAction, CommunicatorObjects.CustomActionProto customAction)
{
// We fall back to not using the custom action if the subclassed Agent doesn't override this method.
AgentAction(vectorAction, textAction);

AgentReset();
}
public void UpdateAgentAction(AgentAction action)
{
m_Action = action;
}
/// <summary>
/// Updates the vector action.
/// </summary>

public List<float> GetMemoriesAction()
{
return m_Action.memories;
}
/// <summary>
/// Updates the text action.
/// </summary>
/// <param name="textActions">Text actions.</param>
public void UpdateTextAction(string textActions)
{
m_Action.textActions = textActions;
}
/// <summary>
/// Updates the custom action.
/// </summary>
/// <param name="customAction">Custom action.</param>
public void UpdateCustomAction(CustomActionProto customAction)
{
m_Action.customAction = customAction;
}
/// <summary>

/// Sets the custom observation for the agent for this episode.
/// </summary>
/// <param name="customObservation">New value of the agent's custom observation.</param>
public void SetCustomObservation(CustomObservationProto customObservation)
public void SetCustomObservation(CommunicatorObjects.CustomObservationProto customObservation)
{
m_Info.customObservation = customObservation;
}

UnitySDK/Assets/ML-Agents/Scripts/Brain.cs (24 changes)


using System;
using System.Collections.Generic;
using UnityEngine;

/// Brain receive data from Agents through calls to SendState. The brain then updates the
/// Brain receive data from Agents through calls to SubscribeAgentForDecision. The brain then updates the
/// actions of the agents at each FixedUpdate.
/// The Brain encapsulates the decision making process. Every Agent must be assigned a Brain,
/// but you can use the same Brain with more than one Agent. You can also create several

{
[SerializeField] public BrainParameters brainParameters;
protected Dictionary<Agent, AgentInfo> m_AgentInfos =
new Dictionary<Agent, AgentInfo>(1024);
/// <summary>
/// List of agents subscribed for decisions.
/// </summary>
protected List<Agent> m_Agents = new List<Agent>(1024);
[System.NonSerialized]
[NonSerialized]
/// Adds the data of an agent to the current batch so it will be processed in DecideAction.
/// Registers an agent to current batch so it will be processed in DecideAction.
/// <param name="info"></param>
public void SendState(Agent agent, AgentInfo info)
public void SubscribeAgentForDecision(Agent agent)
m_AgentInfos[agent] = info;
m_Agents.Add(agent);
}
/// <summary>

private void LazyInitialize()
protected void LazyInitialize()
{
if (!m_IsInitialized)
{

{
if (m_IsInitialized)
{
m_AgentInfos.Clear();
m_Agents.Clear();
m_IsInitialized = false;
}
}

private void BrainDecideAction()
{
DecideAction();
// Clear the agent Decision subscription collection for the next update cycle.
m_Agents.Clear();
}
/// <summary>

UnitySDK/Assets/ML-Agents/Scripts/BrainParameters.cs (41 changes)


using System;
using UnityEngine;
using System.Linq;
namespace MLAgents
{

Continuous
};
}
/// <summary>
/// The resolution of a camera used by an agent.

/// <summary>Defines if the action is discrete or continuous</summary>
public SpaceType vectorActionSpaceType = SpaceType.Discrete;
public BrainParameters()
{
}
/// <summary>
/// Converts Resolution protobuf array to C# Resolution array.
/// </summary>
private static Resolution[] ResolutionProtoToNative(
CommunicatorObjects.ResolutionProto[] resolutionProtos)
{
var localCameraResolutions = new Resolution[resolutionProtos.Length];
for (var i = 0; i < resolutionProtos.Length; i++)
{
localCameraResolutions[i] = new Resolution
{
height = resolutionProtos[i].Height,
width = resolutionProtos[i].Width,
blackAndWhite = resolutionProtos[i].GrayScale
};
}
return localCameraResolutions;
}
public BrainParameters(CommunicatorObjects.BrainParametersProto brainParametersProto)
{
vectorObservationSize = brainParametersProto.VectorObservationSize;
cameraResolutions = ResolutionProtoToNative(
brainParametersProto.CameraResolutions.ToArray()
);
numStackedVectorObservations = brainParametersProto.NumStackedVectorObservations;
vectorActionSize = brainParametersProto.VectorActionSize.ToArray();
vectorActionDescriptions = brainParametersProto.VectorActionDescriptions.ToArray();
vectorActionSpaceType = (SpaceType)brainParametersProto.VectorActionSpaceType;
}
/// <summary>
/// Deep clones the BrainParameter object
/// </summary>

return new BrainParameters()
return new BrainParameters
{
vectorObservationSize = vectorObservationSize,
numStackedVectorObservations = numStackedVectorObservations,

UnitySDK/Assets/ML-Agents/Scripts/Demonstration.cs (1 change)


public float meanReward;
public string demonstrationName;
public const int ApiVersion = 1;
}
}

UnitySDK/Assets/ML-Agents/Scripts/Grpc/GrpcExtensions.cs (100 changes)


using System;
using System.Collections.Generic;
using System.Linq;
using Google.Protobuf.Collections;
using MLAgents.CommunicatorObjects;
using UnityEngine;

/// Converts a BrainParameters object into a protobuf BrainParametersProto so it can be sent.
/// </summary>
/// <returns>The BrainParametersProto generated.</returns>
/// <param name="bp">The instance of BrainParameter to extend.</param>
/// <param name="name">The name of the brain.</param>
/// <param name="isTraining">Whether or not the Brain is training.</param>
public static BrainParametersProto ToProto(this BrainParameters bp, string name, bool isTraining)

VectorObservationSize = bp.vectorObservationSize,
NumStackedVectorObservations = bp.numStackedVectorObservations,
VectorActionSize = {bp.vectorActionSize},
VectorActionSize = { bp.vectorActionSize },
VectorActionSpaceType =
(SpaceTypeProto)bp.vectorActionSpaceType,
BrainName = name,

throw new Exception("API versions of demonstration are incompatible.");
}
return dm;
}
/// <summary>
/// Converts Resolution protobuf array to C# Resolution array.
/// </summary>
private static Resolution[] ResolutionProtoToNative(IReadOnlyList<ResolutionProto> resolutionProtos)
{
var localCameraResolutions = new Resolution[resolutionProtos.Count];
for (var i = 0; i < resolutionProtos.Count; i++)
{
localCameraResolutions[i] = new Resolution
{
height = resolutionProtos[i].Height,
width = resolutionProtos[i].Width,
blackAndWhite = resolutionProtos[i].GrayScale
};
}
return localCameraResolutions;
}
/// <summary>
/// Convert a BrainParametersProto to a BrainParameters struct.
/// </summary>
/// <param name="bpp">An instance of a brain parameters protobuf object.</param>
/// <returns>A BrainParameters struct.</returns>
public static BrainParameters ToBrainParameters(this BrainParametersProto bpp)
{
var bp = new BrainParameters
{
vectorObservationSize = bpp.VectorObservationSize,
cameraResolutions = ResolutionProtoToNative(
bpp.CameraResolutions
),
numStackedVectorObservations = bpp.NumStackedVectorObservations,
vectorActionSize = bpp.VectorActionSize.ToArray(),
vectorActionDescriptions = bpp.VectorActionDescriptions.ToArray(),
vectorActionSpaceType = (SpaceType)bpp.VectorActionSpaceType
};
return bp;
}
/// <summary>
/// Convert a MapField to ResetParameters.
/// </summary>
/// <param name="floatParams">The mapping of strings to floats from a protobuf MapField.</param>
/// <returns></returns>
public static ResetParameters ToResetParameters(this MapField<string, float> floatParams)
{
return new ResetParameters(floatParams);
}
/// <summary>
/// Convert an EnvironmentParametersProto protobuf object to an EnvironmentResetParameters struct.
/// </summary>
/// <param name="epp">The instance of the EnvironmentParametersProto object.</param>
/// <returns>A new EnvironmentResetParameters struct.</returns>
public static EnvironmentResetParameters ToEnvironmentResetParameters(this EnvironmentParametersProto epp)
{
return new EnvironmentResetParameters
{
resetParameters = epp.FloatParameters?.ToResetParameters(),
customResetParameters = epp.CustomResetParameters
};
}
public static UnityRLInitParameters ToUnityRLInitParameters(this UnityRLInitializationInputProto inputProto)
{
return new UnityRLInitParameters
{
seed = inputProto.Seed
};
}
public static AgentAction ToAgentAction(this AgentActionProto aap)
{
return new AgentAction
{
vectorActions = aap.VectorActions.ToArray(),
textActions = aap.TextActions,
memories = aap.Memories.ToList(),
value = aap.Value,
customAction = aap.CustomAction
};
}
public static List<AgentAction> ToAgentActionList(this UnityRLInputProto.Types.ListAgentActionProto proto)
{
var agentActions = new List<AgentAction>(proto.Value.Count);
foreach (var ap in proto.Value)
{
agentActions.Add(ap.ToAgentAction());
}
return agentActions;
}
}
}
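The extension methods above centralize the protobuf boundary in one file. A minimal usage sketch, assuming only the types in this diff (the namespace import, brain name, and observation size are made up for illustration):

using MLAgents;
using MLAgents.CommunicatorObjects;

public static class GrpcExtensionsSketch
{
    // Round-trips a native BrainParameters through its protobuf form.
    public static BrainParameters RoundTrip()
    {
        var bp = new BrainParameters { vectorObservationSize = 8 };
        BrainParametersProto proto = bp.ToProto("MyBrain", isTraining: true);
        return proto.ToBrainParameters();
    }
}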

UnitySDK/Assets/ML-Agents/Scripts/Grpc/RpcCommunicator.cs (305 changes)


#if UNITY_EDITOR
using UnityEditor;
#endif
using System;
using System.Collections.Generic;
using System.Linq;
using UnityEngine;
using MLAgents.CommunicatorObjects;

public class RpcCommunicator : ICommunicator
{
public event QuitCommandHandler QuitCommandReceived;
public event ResetCommandHandler ResetCommandReceived;
public event RLInputReceivedHandler RLInputReceived;
/// The default number of agents in the scene
private const int k_NumAgents = 32;
/// Keeps track of which brains have data to send on the current step
Dictionary<string, bool> m_HasData =
new Dictionary<string, bool>();
/// Keeps track of which brains queried the batcher on the current step
Dictionary<string, bool> m_HasQueried =
new Dictionary<string, bool>();
/// Keeps track of the agents of each brain on the current step
Dictionary<string, List<Agent>> m_CurrentAgents =
new Dictionary<string, List<Agent>>();
/// The current UnityRLOutput to be sent when all the brains queried the batcher
UnityRLOutputProto m_CurrentUnityRlOutput =
new UnityRLOutputProto();
Dictionary<string, Dictionary<Agent, AgentAction>> m_LastActionsReceived =
new Dictionary<string, Dictionary<Agent, AgentAction>>();
CommunicatorParameters m_CommunicatorParameters;
CommunicatorInitParameters m_CommunicatorInitParameters;
/// <param name="communicatorParameters">Communicator parameters.</param>
public RpcCommunicator(CommunicatorParameters communicatorParameters)
/// <param name="communicatorInitParameters">Communicator parameters.</param>
public RpcCommunicator(CommunicatorInitParameters communicatorInitParameters)
m_CommunicatorParameters = communicatorParameters;
m_CommunicatorInitParameters = communicatorInitParameters;
#region Initialization
/// Initialize the communicator by sending the first UnityOutput and receiving the
/// first UnityInput. The second UnityInput is stored in the unityInput argument.
/// Sends the initialization parameters through the Communicator.
/// Is used by the academy to send initialization parameters to the communicator.
/// <returns>The first Unity Input.</returns>
/// <param name="unityOutput">The first Unity Output.</param>
/// <param name="unityInput">The second Unity input.</param>
public UnityInputProto Initialize(UnityOutputProto unityOutput,
/// <returns>The External Initialization Parameters received.</returns>
/// <param name="initParameters">The Unity Initialization Parameters to be sent.</param>
/// <param name="broadcastHub">The BroadcastHub to get the controlled brains.</param>
public UnityRLInitParameters Initialize(CommunicatorInitParameters initParameters,
BroadcastHub broadcastHub)
{
var academyParameters = new UnityRLInitializationOutputProto
{
Name = initParameters.name,
Version = initParameters.version
};
foreach (var brain in initParameters.brains)
{
academyParameters.BrainParameters.Add(brain.brainParameters.ToProto(
brain.name, true));
SubscribeBrain(brain.name);
}
academyParameters.EnvironmentParameters = new EnvironmentParametersProto();
var resetParameters = initParameters.environmentResetParameters.resetParameters;
foreach (var key in resetParameters.Keys)
{
academyParameters.EnvironmentParameters.FloatParameters.Add(key, resetParameters[key]);
}
UnityInputProto input;
UnityInputProto initializationInput;
try
{
initializationInput = Initialize(
new UnityOutputProto
{
RlInitializationOutput = academyParameters
},
out input);
}
catch
{
var exceptionMessage = "The Communicator was unable to connect. Please make sure the External " +
"process is ready to accept communication with Unity.";
// Check for common error condition and add details to the exception message.
var httpProxy = Environment.GetEnvironmentVariable("HTTP_PROXY");
var httpsProxy = Environment.GetEnvironmentVariable("HTTPS_PROXY");
if (httpProxy != null || httpsProxy != null)
{
exceptionMessage += " Try removing HTTP_PROXY and HTTPS_PROXY from the" +
"environment variables and try again.";
}
throw new UnityAgentsException(exceptionMessage);
}
UpdateEnvironmentWithInput(input.RlInput);
return initializationInput.RlInitializationInput.ToUnityRLInitParameters();
}
void UpdateEnvironmentWithInput(UnityRLInputProto rlInput)
{
SendRLInputReceivedEvent(rlInput.IsTraining);
SendCommandEvent(rlInput.Command, rlInput.EnvironmentParameters);
}
private UnityInputProto Initialize(UnityOutputProto unityOutput,
"localhost:" + m_CommunicatorParameters.port,
"localhost:" + m_CommunicatorInitParameters.port,
ChannelCredentials.Insecure);
m_Client = new UnityToExternalProto.UnityToExternalProtoClient(channel);

#endif
}
#endregion
#region Destruction
/// <summary>
/// Ensure that when this object is destructed, the connection is closed.
/// </summary>
~RpcCommunicator()
{
Close();
}
/// <summary>
/// Close the communicator gracefully on both sides of the communication.
/// </summary>

#endif
}
#endregion
#region Sending Events
private void SendCommandEvent(CommandProto command, EnvironmentParametersProto environmentParametersProto)
{
switch (command)
{
case CommandProto.Quit:
{
QuitCommandReceived?.Invoke();
return;
}
case CommandProto.Reset:
{
ResetCommandReceived?.Invoke(environmentParametersProto.ToEnvironmentResetParameters());
return;
}
default:
{
return;
}
}
}
private void SendRLInputReceivedEvent(bool isTraining)
{
RLInputReceived?.Invoke(new UnityRLInputParameters { isTraining = isTraining });
}
#endregion
#region Sending and retrieving data
/// <summary>
/// Adds the brain to the list of brains which will be sending information to External.
/// </summary>
/// <param name="brainKey">Brain key.</param>
private void SubscribeBrain(string brainKey)
{
m_HasQueried[brainKey] = false;
m_HasData[brainKey] = false;
m_CurrentAgents[brainKey] = new List<Agent>(k_NumAgents);
m_CurrentUnityRlOutput.AgentInfos.Add(
brainKey,
new UnityRLOutputProto.Types.ListAgentInfoProto());
}
public void PutObservations(
string brainKey, IEnumerable<Agent> agents)
{
// The brain called PutObservations; update m_HasQueried
m_HasQueried[brainKey] = true;
// Populate the currentAgents dictionary
m_CurrentAgents[brainKey].Clear();
foreach (var agent in agents)
{
m_CurrentAgents[brainKey].Add(agent);
}
// If at least one agent has data to send, then append data to
// the message and update hasSentState
if (m_CurrentAgents[brainKey].Count > 0)
{
foreach (var agent in m_CurrentAgents[brainKey])
{
var agentInfoProto = agent.Info.ToProto();
m_CurrentUnityRlOutput.AgentInfos[brainKey].Value.Add(agentInfoProto);
// Avoid visual obs memory leak. This should be called AFTER we are done with the visual obs.
// e.g. after recording them to demo and using them for inference.
agent.ClearVisualObservations();
}
m_HasData[brainKey] = true;
}
// If any agent needs to send data, then the whole message
// must be sent
if (m_HasQueried.Values.All(x => x))
{
if (m_HasData.Values.Any(x => x))
{
SendBatchedMessageHelper();
}
// The message was just sent so we must reset hasSentState and
// triedSendState
foreach (var k in m_CurrentAgents.Keys)
{
m_HasData[k] = false;
m_HasQueried[k] = false;
}
}
}
/// <summary>
/// Helper method that sends the current UnityRLOutput, receives the next UnityInput and
/// applies the appropriate AgentAction to the agents.
/// </summary>
void SendBatchedMessageHelper()
{
var input = Exchange(
new UnityOutputProto
{
RlOutput = m_CurrentUnityRlOutput
});
foreach (var k in m_CurrentUnityRlOutput.AgentInfos.Keys)
{
m_CurrentUnityRlOutput.AgentInfos[k].Value.Clear();
}
var rlInput = input?.RlInput;
if (rlInput?.AgentActions == null)
{
return;
}
UpdateEnvironmentWithInput(rlInput);
m_LastActionsReceived.Clear();
foreach (var brainName in rlInput.AgentActions.Keys)
{
if (!m_CurrentAgents[brainName].Any())
{
continue;
}
if (!rlInput.AgentActions[brainName].Value.Any())
{
continue;
}
var agentActions = rlInput.AgentActions[brainName].ToAgentActionList();
var numAgents = m_CurrentAgents[brainName].Count;
var agentActionDict = new Dictionary<Agent, AgentAction>(numAgents);
m_LastActionsReceived[brainName] = agentActionDict;
for (var i = 0; i < numAgents; i++)
{
var agent = m_CurrentAgents[brainName][i];
var agentAction = agentActions[i];
agentActionDict[agent] = agentAction;
agent.UpdateAgentAction(agentAction);
}
}
}
public Dictionary<Agent, AgentAction> GetActions(string key)
{
return m_LastActionsReceived[key];
}
public UnityInputProto Exchange(UnityOutputProto unityOutput)
private UnityInputProto Exchange(UnityOutputProto unityOutput)
{
# if UNITY_EDITOR || UNITY_STANDALONE_WIN || UNITY_STANDALONE_OSX || UNITY_STANDALONE_LINUX
if (!m_IsOpen)

{
return message.UnityInput;
}
else
{
m_IsOpen = false;
return null;
}
m_IsOpen = false;
// Not sure if the quit command is actually sent when a
// non 200 message is received. Notify that we are indeed
// quitting.
QuitCommandReceived?.Invoke();
return message.UnityInput;
QuitCommandReceived?.Invoke();
return null;
}
#else

};
}
/// <summary>
/// When the Unity application quits, the communicator must be closed
/// </summary>
private void OnApplicationQuit()
{
Close();
}
#endregion
#if UNITY_EDITOR
#if UNITY_2017_2_OR_NEWER
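The send-when-all-queried bookkeeping in PutObservations is the core of the batching that used to live in Batcher. A stand-alone sketch of just that contract, with invented names and the gRPC exchange replaced by a log line; only the flag logic mirrors the code above:

using System;
using System.Collections.Generic;
using System.Linq;

public class BatchingSketch
{
    readonly Dictionary<string, bool> m_HasQueried = new Dictionary<string, bool>();
    readonly Dictionary<string, int> m_PendingAgents = new Dictionary<string, int>();

    public void SubscribeBrain(string brainKey)
    {
        m_HasQueried[brainKey] = false;
        m_PendingAgents[brainKey] = 0;
    }

    // Each subscribed brain reports exactly once per step. The batched
    // message is flushed only after every brain has reported, and only if
    // at least one of them actually had agents with data.
    public void PutObservations(string brainKey, int agentCount)
    {
        m_HasQueried[brainKey] = true;
        m_PendingAgents[brainKey] = agentCount;

        if (!m_HasQueried.Values.All(queried => queried))
        {
            return; // still waiting on another brain this step
        }
        if (m_PendingAgents.Values.Any(count => count > 0))
        {
            Console.WriteLine("Exchange: send batched UnityRLOutput");
        }
        foreach (var key in m_HasQueried.Keys.ToList())
        {
            m_HasQueried[key] = false; // reset for the next step
            m_PendingAgents[key] = 0;
        }
    }

    public static void Main()
    {
        var batcher = new BatchingSketch();
        batcher.SubscribeBrain("BrainA");
        batcher.SubscribeBrain("BrainB");
        batcher.PutObservations("BrainA", 3); // nothing sent yet
        batcher.PutObservations("BrainB", 0); // all queried -> flush
    }
}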

UnitySDK/Assets/ML-Agents/Scripts/HeuristicBrain.cs (27 changes)


throw new UnityAgentsException(
"The Brain is set to Heuristic, but no decision script attached to it");
}
foreach (var agent in m_AgentInfos.Keys)
foreach (var agent in m_Agents)
var info = agent.Info;
m_AgentInfos[agent].stackedVectorObservation,
m_AgentInfos[agent].visualObservations,
m_AgentInfos[agent].reward,
m_AgentInfos[agent].done,
m_AgentInfos[agent].memories));
info.stackedVectorObservation,
info.visualObservations,
info.reward,
info.done,
info.memories));
foreach (var agent in m_AgentInfos.Keys)
foreach (var agent in m_Agents)
var info = agent.Info;
m_AgentInfos[agent].stackedVectorObservation,
m_AgentInfos[agent].visualObservations,
m_AgentInfos[agent].reward,
m_AgentInfos[agent].done,
m_AgentInfos[agent].memories));
info.stackedVectorObservation,
info.visualObservations,
info.reward,
info.done,
info.memories));
m_AgentInfos.Clear();
}
}
}

UnitySDK/Assets/ML-Agents/Scripts/ICommunicator.cs (123 changes)


using System;
using System.Collections.Generic;
public struct CommunicatorParameters
public struct EnvironmentResetParameters
{
/// <summary>
/// Mapping of string : float which defines which parameters can be
/// reset from python.
/// </summary>
public ResetParameters resetParameters;
/// <summary>
/// The protobuf for custom reset parameters.
/// NOTE: This is the last remaining relic of gRPC protocol
/// that is left in our code. We need to decide how to handle this
/// moving forward.
/// </summary>
public CustomResetParametersProto customResetParameters;
}
public struct CommunicatorInitParameters
/// <summary>
/// Port to listen for connections on.
/// </summary>
/// <summary>
/// The name of the environment.
/// </summary>
public string name;
/// <summary>
/// The version of the Unity SDK.
/// </summary>
public string version;
/// <summary>
/// The list of brains parameters used for training.
/// </summary>
public IEnumerable<Brain> brains;
/// <summary>
/// The set of environment parameters defined by the user that will be sent to the communicator.
/// </summary>
public EnvironmentResetParameters environmentResetParameters;
}
public struct UnityRLInitParameters
{
/// <summary>
/// An RNG seed sent from the python process to Unity.
/// </summary>
public int seed;
}
public struct UnityRLInputParameters
{
/// <summary>
/// Boolean sent back from python to indicate whether or not training is happening.
/// </summary>
public bool isTraining;
/// <summary>
/// Delegate for handling quit events sent back from the communicator.
/// </summary>
public delegate void QuitCommandHandler();
/// <summary>
/// Delegate for handling reset parameter updates sent from the communicator.
/// </summary>
/// <param name="resetParams"></param>
public delegate void ResetCommandHandler(EnvironmentResetParameters resetParams);
/// <summary>
/// Delegate to handle UnityRLInputParameters updates from the communicator.
/// </summary>
/// <param name="inputParams"></param>
public delegate void RLInputReceivedHandler(UnityRLInputParameters inputParams);
/**
This is the interface of the Communicators.
This does not need to be modified nor implemented to create a Unity environment.

......UnityRLOutput
......UnityRLInitializationOutput
...UnityInput
......UnityRLIntput
......UnityRLInitializationIntput
......UnityRLInput
......UnityRLInitializationInput
UnityOutput and UnityInput can be extended to provide functionalities beyond RL
UnityRLOutput and UnityRLInput can be extended to provide new RL functionalities

/// <summary>
/// Initialize the communicator by sending the first UnityOutput and receiving the
/// first UnityInput. The second UnityInput is stored in the unityInput argument.
/// Quit was received by the communicator.
/// </summary>
event QuitCommandHandler QuitCommandReceived;
/// <summary>
/// Reset command sent back from the communicator.
/// </summary>
event ResetCommandHandler ResetCommandReceived;
/// <summary>
/// Unity RL Input was received by the communicator.
/// </summary>
event RLInputReceivedHandler RLInputReceived;
/// <summary>
/// Sends the academy parameters through the Communicator.
/// Is used by the academy to send the AcademyParameters to the communicator.
/// </summary>
/// <returns>The External Initialization Parameters received.</returns>
/// <param name="initParameters">The Unity Initialization Parameters to be sent.</param>
/// <param name="broadcastHub">The BroadcastHub to get the controlled brains.</param>
UnityRLInitParameters Initialize(CommunicatorInitParameters initParameters,
BroadcastHub broadcastHub);
/// <summary>
/// Sends the observations. If at least one brain has an agent in need of
/// a decision or if the academy is done, the data is sent via
/// Communicator. Else, a new step is realized. The data can only be
/// sent once all the brains that subscribed to the batcher have tried
/// to send information.
/// <returns>The first Unity Input.</returns>
/// <param name="unityOutput">The first Unity Output.</param>
/// <param name="unityInput">The second Unity input.</param>
UnityInputProto Initialize(UnityOutputProto unityOutput,
out UnityInputProto unityInput);
/// <param name="key">Batch Key.</param>
/// <param name="agents">Agent info.</param>
void PutObservations(string key, IEnumerable<Agent> agents);
/// Send a UnityOutput and receives a UnityInput.
/// Gets the AgentActions based on the batching key.
/// <returns>The next UnityInput.</returns>
/// <param name="unityOutput">The UnityOutput to be sent.</param>
UnityInputProto Exchange(UnityOutputProto unityOutput);
/// <param name="key">A key to identify which actions to get</param>
/// <returns></returns>
Dictionary<Agent, AgentAction> GetActions(string key);
/// <summary>
/// Close the communicator gracefully on both sides of the communication.
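Together, these events replace the Academy's old polling of the Batcher with push-based callbacks. A sketch of the consumer-side wiring, mirroring what Academy.InitializeEnvironment does earlier in this diff; the handler bodies are placeholders and the MLAgents assembly is assumed:

using UnityEngine;

namespace MLAgents
{
    public static class CommunicatorWiringSketch
    {
        public static void Wire(ICommunicator communicator)
        {
            // Python asked the environment to shut down.
            communicator.QuitCommandReceived += () => Application.Quit();

            // Python requested a reset, possibly with new reset parameters.
            communicator.ResetCommandReceived += newResetParameters =>
                Debug.Log("Reset command received");

            // Python toggled between training and inference.
            communicator.RLInputReceived += inputParams =>
                Debug.Log("isTraining = " + inputParams.isTraining);
        }
    }
}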

UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ApplierImpl.cs (23 changes)


/// </summary>
public class ContinuousActionOutputApplier : TensorApplier.IApplier
{
public void Apply(TensorProxy tensorProxy, Dictionary<Agent, AgentInfo> agentInfo)
public void Apply(TensorProxy tensorProxy, IEnumerable<Agent> agents)
foreach (var agent in agentInfo.Keys)
foreach (var agent in agents)
{
var action = new float[actionSize];
for (var j = 0; j < actionSize; j++)

m_Allocator = allocator;
}
public void Apply(TensorProxy tensorProxy, Dictionary<Agent, AgentInfo> agentInfo)
public void Apply(TensorProxy tensorProxy, IEnumerable<Agent> agents)
var batchSize = agentInfo.Keys.Count;
var agentsArray = agents as List<Agent> ?? agents.ToList();
var batchSize = agentsArray.Count;
var actions = new float[batchSize, m_ActionSize.Length];
var startActionIndices = Utilities.CumSum(m_ActionSize);
for (var actionIndex = 0; actionIndex < m_ActionSize.Length; actionIndex++)

outputTensor.data.Dispose();
}
var agentIndex = 0;
foreach (var agent in agentInfo.Keys)
foreach (var agent in agentsArray)
{
var action = new float[m_ActionSize.Length];
for (var j = 0; j < m_ActionSize.Length; j++)

m_MemoryIndex = memoryIndex;
}
public void Apply(TensorProxy tensorProxy, Dictionary<Agent, AgentInfo> agentInfo)
public void Apply(TensorProxy tensorProxy, IEnumerable<Agent> agents)
foreach (var agent in agentInfo.Keys)
foreach (var agent in agents)
{
var memory = agent.GetMemoriesAction();

/// </summary>
public class MemoryOutputApplier : TensorApplier.IApplier
{
public void Apply(TensorProxy tensorProxy, Dictionary<Agent, AgentInfo> agentInfo)
public void Apply(TensorProxy tensorProxy, IEnumerable<Agent> agents)
foreach (var agent in agentInfo.Keys)
foreach (var agent in agents)
{
var memory = new List<float>();
for (var j = 0; j < memorySize; j++)

/// </summary>
public class ValueEstimateApplier : TensorApplier.IApplier
{
public void Apply(TensorProxy tensorProxy, Dictionary<Agent, AgentInfo> agentInfo)
public void Apply(TensorProxy tensorProxy, IEnumerable<Agent> agents)
foreach (var agent in agentInfo.Keys)
foreach (var agent in agents)
{
agent.UpdateValueAction(tensorProxy.data[agentIndex, 0]);
agentIndex++;

UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/GeneratorImpl.cs (49 changes)


m_Allocator = allocator;
}
public void Generate(TensorProxy tensorProxy, int batchSize, Dictionary<Agent, AgentInfo> agentInfo)
public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
}

m_Allocator = allocator;
}
public void Generate(TensorProxy tensorProxy, int batchSize, Dictionary<Agent, AgentInfo> agentInfo)
public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
{
tensorProxy.data?.Dispose();
tensorProxy.data = m_Allocator.Alloc(new TensorShape(1, 1));

m_Allocator = allocator;
}
public void Generate(TensorProxy tensorProxy, int batchSize, Dictionary<Agent, AgentInfo> agentInfo)
public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
{
tensorProxy.shape = new long[0];
tensorProxy.data?.Dispose();

}
public void Generate(
TensorProxy tensorProxy, int batchSize, Dictionary<Agent, AgentInfo> agentInfo)
TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
foreach (var agent in agentInfo.Keys)
foreach (var agent in agents)
var vectorObs = agentInfo[agent].stackedVectorObservation;
var info = agent.Info;
var vectorObs = info.stackedVectorObservation;
for (var j = 0; j < vecObsSizeT; j++)
{
tensorProxy.data[agentIndex, j] = vectorObs[j];

}
public void Generate(
TensorProxy tensorProxy, int batchSize, Dictionary<Agent, AgentInfo> agentInfo)
TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
foreach (var agent in agentInfo.Keys)
foreach (var agent in agents)
var memory = agentInfo[agent].memories;
var info = agent.Info;
var memory = info.memories;
if (memory == null)
{
agentIndex++;

}
public void Generate(
TensorProxy tensorProxy, int batchSize, Dictionary<Agent, AgentInfo> agentInfo)
TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
foreach (var agent in agentInfo.Keys)
foreach (var agent in agents)
var memory = agentInfo[agent].memories;
var agentInfo = agent.Info;
var memory = agentInfo.memories;
var offset = memorySize * m_MemoryIndex;

}
public void Generate(
TensorProxy tensorProxy, int batchSize, Dictionary<Agent, AgentInfo> agentInfo)
TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
foreach (var agent in agentInfo.Keys)
foreach (var agent in agents)
var pastAction = agentInfo[agent].storedVectorActions;
var info = agent.Info;
var pastAction = info.storedVectorActions;
for (var j = 0; j < actionSize; j++)
{
tensorProxy.data[agentIndex, j] = pastAction[j];

}
public void Generate(
TensorProxy tensorProxy, int batchSize, Dictionary<Agent, AgentInfo> agentInfo)
TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
foreach (var agent in agentInfo.Keys)
foreach (var agent in agents)
var maskList = agentInfo[agent].actionMasks;
var agentInfo = agent.Info;
var maskList = agentInfo.actionMasks;
for (var j = 0; j < maskSize; j++)
{
var isUnmasked = (maskList != null && maskList[j]) ? 0.0f : 1.0f;

}
public void Generate(
TensorProxy tensorProxy, int batchSize, Dictionary<Agent, AgentInfo> agentInfo)
TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
TensorUtils.FillTensorWithRandomNormal(tensorProxy, m_RandomNormal);

}
public void Generate(
TensorProxy tensorProxy, int batchSize, Dictionary<Agent, AgentInfo> agentInfo)
TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
var textures = agentInfo.Keys.Select(
agent => agentInfo[agent].visualObservations[m_Index]).ToList();
var textures = agents.Select(
agent => agent.Info.visualObservations[m_Index]).ToList();
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
Utilities.TextureToTensorProxy(textures, tensorProxy, m_GrayScale);

UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorApplier.cs (14 changes)


/// <param name="tensorProxy">
/// The Tensor containing the data to be applied to the Agents
/// </param>
/// <param name="agentInfo">
/// Dictionary of Agents to AgentInfo that will receive
/// the values of the Tensor.
/// <param name="agents">
/// List of Agents that will receive the values of the Tensor.
void Apply(TensorProxy tensorProxy, Dictionary<Agent, AgentInfo> agentInfo);
void Apply(TensorProxy tensorProxy, IEnumerable<Agent> agents);
}
private readonly Dictionary<string, IApplier> m_Dict = new Dictionary<string, IApplier>();

/// Updates the state of the agents based on the data present in the tensor.
/// </summary>
/// <param name="tensors"> Enumerable of tensors containing the data.</param>
/// <param name="agentInfos"> Dictionary of Agent to AgentInfo that contains the
/// Agents that will be updated using the tensor's data</param>
/// <param name="agents"> List of Agents that will be updated using the tensor's data</param>
IEnumerable<TensorProxy> tensors, Dictionary<Agent, AgentInfo> agentInfos)
IEnumerable<TensorProxy> tensors, IEnumerable<Agent> agents)
{
foreach (var tensor in tensors)
{

$"Unknown tensorProxy expected as output : {tensor.name}");
}
m_Dict[tensor.name].Apply(tensor, agentInfos);
m_Dict[tensor.name].Apply(tensor, agents);
}
}
}

UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorGenerator.cs (10 changes)


/// </summary>
/// <param name="tensorProxy"> The tensor the data and shape will be modified</param>
/// <param name="batchSize"> The number of agents present in the current batch</param>
/// <param name="agentInfo"> Dictionary of Agent to AgentInfo containing the
/// <param name="agents"> List of Agents containing the
TensorProxy tensorProxy, int batchSize, Dictionary<Agent, AgentInfo> agentInfo);
TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents);
}
private readonly Dictionary<string, IGenerator> m_Dict = new Dictionary<string, IGenerator>();

/// <param name="tensors"> Enumerable of tensors that will be modified.</param>
/// <param name="currentBatchSize"> The number of agents present in the current batch
/// </param>
/// <param name="agentInfos"> Dictionary of Agent to AgentInfo that contains the
/// <param name="agents"> List of Agents that contains the
/// data that will be used to modify the tensors</param>
/// <exception cref="UnityAgentsException"> One of the tensor does not have an
/// associated generator.</exception>

Dictionary<Agent, AgentInfo> agentInfos)
IEnumerable<Agent> agents)
{
foreach (var tensor in tensors)
{

$"Unknown tensorProxy expected as input : {tensor.name}");
}
m_Dict[tensor.name].Generate(tensor, currentBatchSize, agentInfos);
m_Dict[tensor.name].Generate(tensor, currentBatchSize, agents);
}
}
}

UnitySDK/Assets/ML-Agents/Scripts/LearningBrain.cs (26 changes)


[CreateAssetMenu(fileName = "NewLearningBrain", menuName = "ML-Agents/Learning Brain")]
public class LearningBrain : Brain
{
private Batcher m_Batcher;
private ITensorAllocator m_TensorAllocator;
private TensorGenerator m_TensorGenerator;
private TensorApplier m_TensorApplier;

private IReadOnlyList<TensorProxy> m_InferenceInputs;
private IReadOnlyList<TensorProxy> m_InferenceOutputs;
protected ICommunicator m_Communicator;
/// When Called, the brain will be controlled externally. It will not use the
/// model to decide on actions.
/// Sets the Batcher of the Brain. The brain will call the communicator at every step and give
/// it the agent's data using PutObservations at each DecideAction call.
public void SetBatcher(Batcher batcher)
/// <param name="communicator"> The Batcher the brain will use for the current session</param>
public void SetCommunicator(ICommunicator communicator)
m_Batcher = batcher;
m_Batcher?.SubscribeBrain(name);
m_Communicator = communicator;
LazyInitialize();
}
/// <inheritdoc />

/// <inheritdoc />
protected override void DecideAction()
{
m_Batcher?.SendBrainInfo(name, m_AgentInfos);
if (m_Batcher != null)
if (m_Communicator != null)
m_AgentInfos.Clear();
m_Communicator?.PutObservations(name, m_Agents);
var currentBatchSize = m_AgentInfos.Count();
var currentBatchSize = m_Agents.Count;
if (currentBatchSize == 0)
{
return;

Profiler.BeginSample($"MLAgents.{name}.GenerateTensors");
// Prepare the input tensors to be feed into the engine
m_TensorGenerator.GenerateTensors(m_InferenceInputs, currentBatchSize, m_AgentInfos);
m_TensorGenerator.GenerateTensors(m_InferenceInputs, currentBatchSize, m_Agents);
Profiler.EndSample();
Profiler.BeginSample($"MLAgents.{name}.PrepareBarracudaInputs");

Profiler.BeginSample($"MLAgents.{name}.ApplyTensors");
// Update the outputs
m_TensorApplier.ApplyTensors(m_InferenceOutputs, m_AgentInfos);
m_TensorApplier.ApplyTensors(m_InferenceOutputs, m_Agents);
m_AgentInfos.Clear();
Profiler.EndSample();
}

UnitySDK/Assets/ML-Agents/Scripts/PlayerBrain.cs (5 changes)


{
if (brainParameters.vectorActionSpaceType == SpaceType.Continuous)
{
foreach (var agent in m_AgentInfos.Keys)
foreach (var agent in m_Agents)
{
var action = new float[brainParameters.vectorActionSize[0]];
foreach (var cha in keyContinuousPlayerActions)

}
else
{
foreach (var agent in m_AgentInfos.Keys)
foreach (var agent in m_Agents)
{
var action = new float[brainParameters.vectorActionSize.Length];
foreach (var dha in discretePlayerActions)

agent.UpdateVectorAction(action);
}
}
m_AgentInfos.Clear();
}
}
}

UnitySDK/Assets/ML-Agents/Scripts/ResetParameters.cs (25 changes)


public float value;
}
[FormerlySerializedAs("resetParameters")]
[SerializeField] private List<ResetParameter> m_ResetParameters = new List<ResetParameter>();
public ResetParameters() {}
public ResetParameters(IDictionary<string, float> dict) : base(dict)
{
UpdateResetParameters();
}
public void OnBeforeSerialize()
private void UpdateResetParameters()
var rp = new ResetParameter();
rp.key = pair.Key;
m_ResetParameters.Add(new ResetParameter { key = pair.Key, value = pair.Value });
}
}
rp.value = pair.Value;
m_ResetParameters.Add(rp);
}
[FormerlySerializedAs("resetParameters")]
[SerializeField] private List<ResetParameter> m_ResetParameters = new List<ResetParameter>();
public void OnBeforeSerialize()
{
UpdateResetParameters();
}
public void OnAfterDeserialize()

UnitySDK/Assets/ML-Agents/Scripts/Timer.cs (21 changes)


namespace MLAgents
{
    [DataContract]
    public class TimerNode
    {

        /// <summary>
        /// Child nodes, indexed by name.
        /// </summary>
-       [DataMember(Name="children", Order=999)]
+       [DataMember(Name = "children", Order = 999)]
        Dictionary<string, TimerNode> m_Children;

        /// <summary>
        /// Number of times the corresponding code block has been called.
        /// </summary>
-       [DataMember(Name="count")]
+       [DataMember(Name = "count")]
        int m_NumCalls = 0;

        /// <summary>
        /// Total elapsed seconds.
        /// </summary>
-       [DataMember(Name="total")]
+       [DataMember(Name = "total")]
-           set { } // Serialization needs this, but unused.
+           set {} // Serialization needs this, but unused.
-       [DataMember(Name="self")]
+       [DataMember(Name = "self")]
        public double SelfSeconds
        {
            get

            {
-               foreach(var child in m_Children.Values)
+               foreach (var child in m_Children.Values)
                {
                    totalChildTicks += child.m_TotalTicks;
                }

                return selfTicks * s_TicksToSeconds;
            }
-           set { } // Serialization needs this, but unused.
+           set {} // Serialization needs this, but unused.
        }

        public IReadOnlyDictionary<string, TimerNode> Children

            get { return m_NumCalls; }
        }

-       public TimerNode(string name, bool isRoot=false)
+       public TimerNode(string name, bool isRoot = false)
        {
            m_FullName = name;
            if (isRoot)

            Reset();
        }

-       public void Reset(string name="root")
+       public void Reset(string name = "root")
        {
            m_Stack = new Stack<TimerNode>();
            m_RootNode = new TimerNode(name, true);

        /// If the filename is null, a default one will be used.
        /// </summary>
        /// <param name="filename"></param>
-       public void SaveJsonTimers(string filename=null)
+       public void SaveJsonTimers(string filename = null)
        {
            if (filename == null)
            {
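The Timer hunks are pure style cleanups, but they touch the [DataMember] attributes that drive the JSON output of SaveJsonTimers: DataContractJsonSerializer emits only members marked [DataMember], honors the Name/Order overrides, and the empty setters exist solely because the serializer requires writable properties. A self-contained illustration of that mechanism, using a toy stand-in class of my own rather than the real TimerNode:

    using System.IO;
    using System.Runtime.Serialization;
    using System.Runtime.Serialization.Json;

    [DataContract]
    class TimerNodeSketch
    {
        [DataMember(Name = "count")]
        int m_NumCalls = 3;

        [DataMember(Name = "total")]
        public double TotalSeconds
        {
            get { return 1.25; }
            set {} // Serialization needs this, but unused.
        }
    }

    class Program
    {
        static void Main()
        {
            var serializer = new DataContractJsonSerializer(typeof(TimerNodeSketch));
            using (var stream = new MemoryStream())
            {
                serializer.WriteObject(stream, new TimerNodeSketch());
                // Prints: {"count":3,"total":1.25}
                System.Console.WriteLine(System.Text.Encoding.UTF8.GetString(stream.ToArray()));
            }
        }
    }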

1
UnitySDK/UnitySDK.sln.DotSettings


<s:String x:Key="/Default/CodeStyle/Naming/CSharpNaming/Abbreviations/=CPU/@EntryIndexedValue">CPU</s:String>
<s:String x:Key="/Default/CodeStyle/Naming/CSharpNaming/Abbreviations/=GPU/@EntryIndexedValue">GPU</s:String>
<s:String x:Key="/Default/CodeStyle/Naming/CSharpNaming/Abbreviations/=NN/@EntryIndexedValue">NN</s:String>
<s:String x:Key="/Default/CodeStyle/Naming/CSharpNaming/Abbreviations/=RL/@EntryIndexedValue">RL</s:String>
<s:Boolean x:Key="/Default/UserDictionary/Words/=BLAS/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=Logits/@EntryIndexedValue">True</s:Boolean>

13
UnitySDK/Assets/ML-Agents/Scripts/Batcher.cs.meta


fileFormatVersion: 2
guid: 4243d5dc0ad5746cba578575182f8c17
timeCreated: 1523045876
licenseType: Free
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

181
UnitySDK/Assets/ML-Agents/Scripts/SocketCommunicator.cs


using Google.Protobuf;
using System.Net.Sockets;
using UnityEngine;
using MLAgents.CommunicatorObjects;
using System.Threading.Tasks;
#if UNITY_EDITOR
using UnityEditor;
#endif
namespace MLAgents
{
public class SocketCommunicator : ICommunicator
{
private const float k_TimeOut = 10f;
private const int k_MessageLength = 12000;
byte[] m_MessageHolder = new byte[k_MessageLength];
int m_ComPort;
Socket m_Sender;
byte[] m_LengthHolder = new byte[4];
CommunicatorParameters m_CommunicatorParameters;
public SocketCommunicator(CommunicatorParameters communicatorParameters)
{
m_CommunicatorParameters = communicatorParameters;
}
/// <summary>
/// Initialize the communicator by sending the first UnityOutput and receiving the
/// first UnityInput. The second UnityInput is stored in the unityInput argument.
/// </summary>
/// <returns>The first Unity Input.</returns>
/// <param name="unityOutput">The first Unity Output.</param>
/// <param name="unityInput">The second Unity input.</param>
public UnityInputProto Initialize(UnityOutputProto unityOutput,
out UnityInputProto unityInput)
{
m_Sender = new Socket(
AddressFamily.InterNetwork,
SocketType.Stream,
ProtocolType.Tcp);
m_Sender.Connect("localhost", m_CommunicatorParameters.port);
var initializationInput =
UnityMessageProto.Parser.ParseFrom(Receive());
Send(WrapMessage(unityOutput, 200).ToByteArray());
unityInput = UnityMessageProto.Parser.ParseFrom(Receive()).UnityInput;
#if UNITY_EDITOR
#if UNITY_2017_2_OR_NEWER
EditorApplication.playModeStateChanged += HandleOnPlayModeChanged;
#else
EditorApplication.playmodeStateChanged += HandleOnPlayModeChanged;
#endif
#endif
return initializationInput.UnityInput;
}
/// <summary>
/// Uses the socket to receive a byte[] from External. Reassembles a message that was split
/// by External if it was too long.
/// </summary>
/// <returns>The byte[] sent by External.</returns>
byte[] Receive()
{
m_Sender.Receive(m_LengthHolder);
var totalLength = System.BitConverter.ToInt32(m_LengthHolder, 0);
var location = 0;
var result = new byte[totalLength];
while (location != totalLength)
{
var fragment = m_Sender.Receive(m_MessageHolder);
System.Buffer.BlockCopy(
m_MessageHolder, 0, result, location, fragment);
location += fragment;
}
return result;
}
/// <summary>
/// Send the specified input via socket to External. Split the message into smaller
/// parts if it is too long.
/// </summary>
/// <param name="input">The byte[] to be sent.</param>
void Send(byte[] input)
{
var newArray = new byte[input.Length + 4];
input.CopyTo(newArray, 4);
System.BitConverter.GetBytes(input.Length).CopyTo(newArray, 0);
m_Sender.Send(newArray);
}
/// <summary>
/// Close the communicator gracefully on both sides of the communication.
/// </summary>
public void Close()
{
Send(WrapMessage(null, 400).ToByteArray());
}
/// <summary>
/// Send a UnityOutput and receives a UnityInput.
/// </summary>
/// <returns>The next UnityInput.</returns>
/// <param name="unityOutput">The UnityOutput to be sent.</param>
public UnityInputProto Exchange(UnityOutputProto unityOutput)
{
Send(WrapMessage(unityOutput, 200).ToByteArray());
byte[] received = null;
var task = Task.Run(() => received = Receive());
if (!task.Wait(System.TimeSpan.FromSeconds(k_TimeOut)))
{
throw new UnityAgentsException(
"The communicator took too long to respond.");
}
var message = UnityMessageProto.Parser.ParseFrom(received);
if (message.Header.Status != 200)
{
return null;
}
return message.UnityInput;
}
/// <summary>
/// Wraps the UnityOutput into a message with the appropriate status.
/// </summary>
/// <returns>The corresponding UnityMessage.</returns>
/// <param name="content">The UnityOutput to be wrapped.</param>
/// <param name="status">The status of the message.</param>
private static UnityMessageProto WrapMessage(UnityOutputProto content, int status)
{
return new UnityMessageProto
{
Header = new HeaderProto { Status = status },
UnityOutput = content
};
}
/// <summary>
/// When the Unity application quits, the communicator must be closed
/// </summary>
private void OnApplicationQuit()
{
Close();
}
#if UNITY_EDITOR
#if UNITY_2017_2_OR_NEWER
/// <summary>
/// When the editor exits, the communicator must be closed
/// </summary>
/// <param name="state">State.</param>
private void HandleOnPlayModeChanged(PlayModeStateChange state)
{
// This method is run whenever the playmode state is changed.
if (state == PlayModeStateChange.ExitingPlayMode)
{
Close();
}
}
#else
/// <summary>
/// When the editor exits, the communicator must be closed
/// </summary>
private void HandleOnPlayModeChanged()
{
// This method is run whenever the playmode state is changed.
if (!EditorApplication.isPlayingOrWillChangePlaymode)
{
Close();
}
}
#endif
#endif
}
}
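The deleted SocketCommunicator above used a simple framing protocol: every payload is prefixed with its length as a 4-byte int (BitConverter byte order), and the receiver loops until it has read exactly that many bytes. A standalone sketch of that framing, extracted from the class for reference; the method and class names here are mine, since the original is gone after this PR:

    using System;
    using System.Net.Sockets;

    static class LengthPrefixedFraming
    {
        // Prefix the payload with its 4-byte length header, then send it.
        public static void SendFramed(Socket socket, byte[] payload)
        {
            var framed = new byte[payload.Length + 4];
            BitConverter.GetBytes(payload.Length).CopyTo(framed, 0);
            payload.CopyTo(framed, 4);
            socket.Send(framed);
        }

        // Read the 4-byte length header, then keep receiving until the full payload has arrived.
        public static byte[] ReceiveFramed(Socket socket)
        {
            var header = new byte[4];
            socket.Receive(header);
            var totalLength = BitConverter.ToInt32(header, 0);
            var result = new byte[totalLength];
            var buffer = new byte[12000];   // k_MessageLength in the deleted class
            var location = 0;
            while (location != totalLength)
            {
                var fragment = socket.Receive(buffer);
                Buffer.BlockCopy(buffer, 0, result, location, fragment);
                location += fragment;
            }
            return result;
        }
    }

Note that, like the original, this sketch assumes all four header bytes arrive in a single Receive call; a production implementation would loop on the header read as well.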

13
UnitySDK/Assets/ML-Agents/Scripts/SocketCommunicator.cs.meta


fileFormatVersion: 2
guid: f0901c57c84a54f25aa5955165072493
timeCreated: 1523046536
licenseType: Free
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

291
UnitySDK/Assets/ML-Agents/Scripts/Batcher.cs


using System.Collections.Generic;
using System.Linq;
using System;
using UnityEngine;
namespace MLAgents
{
/// <summary>
/// The batcher is an RL-specific class that makes sure that the information each object in
/// Unity (Academy and Brains) wants to send to External is appropriately batched together
/// and sent only when necessary.
///
/// The Batcher will only send a message to the Communicator when either:
/// 1 - The academy is done
/// 2 - At least one brain has data to send
///
/// At each step, the batcher will keep track of the brains that queried the batcher for that
/// step. The batcher can only send the batched data when all the Brains have queried the
/// Batcher.
/// </summary>
public class Batcher
{
/// The default number of agents in the scene
private const int k_NumAgents = 32;
/// Keeps track of which brains have data to send on the current step
Dictionary<string, bool> m_HasData =
new Dictionary<string, bool>();
/// Keeps track of which brains queried the batcher on the current step
Dictionary<string, bool> m_HasQueried =
new Dictionary<string, bool>();
/// Keeps track of the agents of each brain on the current step
Dictionary<string, List<Agent>> m_CurrentAgents =
new Dictionary<string, List<Agent>>();
/// The Communicator of the batcher, sends a message at most once per step
ICommunicator m_Communicator;
/// The current UnityRLOutput to be sent when all the brains queried the batcher
CommunicatorObjects.UnityRLOutputProto m_CurrentUnityRlOutput =
new CommunicatorObjects.UnityRLOutputProto();
/// Keeps track of last CommandProto sent by External
CommunicatorObjects.CommandProto m_Command;
/// Keeps track of last EnvironmentParametersProto sent by External
CommunicatorObjects.EnvironmentParametersProto m_EnvironmentParameters;
/// Keeps track of last training mode sent by External
bool m_IsTraining;
/// Keeps track of the number of messages received
private ulong m_MessagesReceived;
/// <summary>
/// Initializes a new instance of the Batcher class.
/// </summary>
/// <param name="communicator">The communicator to be used by the batcher.</param>
public Batcher(ICommunicator communicator)
{
m_Communicator = communicator;
}
/// <summary>
/// Sends the academy parameters through the Communicator.
/// Is used by the academy to send the AcademyParameters to the communicator.
/// </summary>
/// <returns>The External Initialization Parameters received.</returns>
/// <param name="academyParameters">The Unity Initialization Parameters to be sent.</param>
public CommunicatorObjects.UnityRLInitializationInputProto SendAcademyParameters(
CommunicatorObjects.UnityRLInitializationOutputProto academyParameters)
{
CommunicatorObjects.UnityInputProto input;
var initializationInput = new CommunicatorObjects.UnityInputProto();
try
{
initializationInput = m_Communicator.Initialize(
new CommunicatorObjects.UnityOutputProto
{
RlInitializationOutput = academyParameters
},
out input);
}
catch
{
var exceptionMessage = "The Communicator was unable to connect. Please make sure the External " +
"process is ready to accept communication with Unity.";
// Check for common error condition and add details to the exception message.
var httpProxy = Environment.GetEnvironmentVariable("HTTP_PROXY");
var httpsProxy = Environment.GetEnvironmentVariable("HTTPS_PROXY");
if (httpProxy != null || httpsProxy != null)
{
exceptionMessage += " Try removing HTTP_PROXY and HTTPS_PROXY from the " +
"environment variables and try again.";
}
throw new UnityAgentsException(exceptionMessage);
}
var firstRlInput = input.RlInput;
m_Command = firstRlInput.Command;
m_EnvironmentParameters = firstRlInput.EnvironmentParameters;
m_IsTraining = firstRlInput.IsTraining;
return initializationInput.RlInitializationInput;
}
/// <summary>
/// Gets the command. Is used by the academy to get reset or quit signals.
/// </summary>
/// <returns>The current command.</returns>
public CommunicatorObjects.CommandProto GetCommand()
{
return m_Command;
}
/// <summary>
/// Gets the number of messages received so far. Can be used to check for new messages.
/// </summary>
/// <returns>The number of messages received since start of the simulation</returns>
public ulong GetNumberMessageReceived()
{
return m_MessagesReceived;
}
/// <summary>
/// Gets the environment parameters. Is used by the academy to update
/// the environment parameters.
/// </summary>
/// <returns>The environment parameters.</returns>
public CommunicatorObjects.EnvironmentParametersProto GetEnvironmentParameters()
{
return m_EnvironmentParameters;
}
/// <summary>
/// Gets the last training_mode flag External sent
/// </summary>
/// <returns><c>true</c>, if training mode is requested, <c>false</c> otherwise.</returns>
public bool GetIsTraining()
{
return m_IsTraining;
}
/// <summary>
/// Adds the brain to the list of brains which will be sending information to External.
/// </summary>
/// <param name="brainKey">Brain key.</param>
public void SubscribeBrain(string brainKey)
{
m_HasQueried[brainKey] = false;
m_HasData[brainKey] = false;
m_CurrentAgents[brainKey] = new List<Agent>(k_NumAgents);
m_CurrentUnityRlOutput.AgentInfos.Add(
brainKey,
new CommunicatorObjects.UnityRLOutputProto.Types.ListAgentInfoProto());
}
/// <summary>
/// Sends the brain info. If at least one brain has an agent in need of
/// a decision or if the academy is done, the data is sent via the
/// Communicator. Otherwise, a new step is taken. The data can only be
/// sent once all the brains that subscribed to the batcher have tried
/// to send information.
/// </summary>
/// <param name="brainKey">Brain key.</param>
/// <param name="agentInfo">Agent info.</param>
public void SendBrainInfo(
string brainKey, Dictionary<Agent, AgentInfo> agentInfo)
{
// If no communicator is initialized, the Batcher will not transmit
// BrainInfo
if (m_Communicator == null)
{
return;
}
// The brain called SendBrainInfo, so update m_HasQueried
m_HasQueried[brainKey] = true;
// Populate the currentAgents dictionary
m_CurrentAgents[brainKey].Clear();
foreach (var agent in agentInfo.Keys)
{
m_CurrentAgents[brainKey].Add(agent);
}
// If at least one agent has data to send, then append data to
// the message and update m_HasData
if (m_CurrentAgents[brainKey].Count > 0)
{
foreach (var agent in m_CurrentAgents[brainKey])
{
var agentInfoProto = agentInfo[agent].ToProto();
m_CurrentUnityRlOutput.AgentInfos[brainKey].Value.Add(agentInfoProto);
// Avoid visual obs memory leak. This should be called AFTER we are done with the visual obs.
// e.g. after recording them to demo and using them for inference.
agentInfo[agent].ClearVisualObs();
}
m_HasData[brainKey] = true;
}
// Once every subscribed brain has queried, send the whole message
// if any brain has data
if (m_HasQueried.Values.All(x => x))
{
if (m_HasData.Values.Any(x => x))
{
SendBatchedMessageHelper();
}
// The message was just sent, so reset m_HasData and
// m_HasQueried
foreach (var k in m_CurrentAgents.Keys)
{
m_HasData[k] = false;
m_HasQueried[k] = false;
}
}
}
/// <summary>
/// Helper method that sends the current UnityRLOutput, receives the next UnityInput and
/// Applies the appropriate AgentAction to the agents.
/// </summary>
void SendBatchedMessageHelper()
{
var input = m_Communicator.Exchange(
new CommunicatorObjects.UnityOutputProto
{
RlOutput = m_CurrentUnityRlOutput
});
m_MessagesReceived += 1;
foreach (var k in m_CurrentUnityRlOutput.AgentInfos.Keys)
{
m_CurrentUnityRlOutput.AgentInfos[k].Value.Clear();
}
if (input == null)
{
m_Command = CommunicatorObjects.CommandProto.Quit;
return;
}
var rlInput = input.RlInput;
if (rlInput == null)
{
m_Command = CommunicatorObjects.CommandProto.Quit;
return;
}
m_Command = rlInput.Command;
m_EnvironmentParameters = rlInput.EnvironmentParameters;
m_IsTraining = rlInput.IsTraining;
if (rlInput.AgentActions == null)
{
return;
}
foreach (var brainName in rlInput.AgentActions.Keys)
{
if (!m_CurrentAgents[brainName].Any())
{
continue;
}
if (!rlInput.AgentActions[brainName].Value.Any())
{
continue;
}
for (var i = 0; i < m_CurrentAgents[brainName].Count; i++)
{
var agent = m_CurrentAgents[brainName][i];
var action = rlInput.AgentActions[brainName].Value[i];
agent.UpdateVectorAction(action.VectorActions.ToArray());
agent.UpdateMemoriesAction(action.Memories.ToList());
agent.UpdateTextAction(action.TextActions);
agent.UpdateValueAction(action.Value);
agent.UpdateCustomAction(action.CustomAction);
}
}
}
}
}
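Everything the deleted Batcher did is now folded into the RpcCommunicator, but its core invariant is worth keeping in mind when reading the new code: a message goes out only after every subscribed brain has queried for the current step, and only if at least one of them has data. A minimal sketch of that gating logic in isolation (a stand-alone illustration with names of my choosing, not the new RpcCommunicator API):

    using System.Collections.Generic;
    using System.Linq;

    class BatchGate
    {
        readonly Dictionary<string, bool> m_HasQueried = new Dictionary<string, bool>();
        readonly Dictionary<string, bool> m_HasData = new Dictionary<string, bool>();

        public void SubscribeBrain(string brainKey)
        {
            m_HasQueried[brainKey] = false;
            m_HasData[brainKey] = false;
        }

        // Returns true exactly when the batched message should be flushed.
        public bool RecordQuery(string brainKey, int agentCount)
        {
            m_HasQueried[brainKey] = true;
            if (agentCount > 0)
            {
                m_HasData[brainKey] = true;
            }
            if (!m_HasQueried.Values.All(q => q))
            {
                return false;   // some subscribed brain has not reported this step yet
            }
            var shouldSend = m_HasData.Values.Any(d => d);
            foreach (var key in m_HasQueried.Keys.ToList())
            {
                m_HasQueried[key] = false;   // reset both flags for the next step
                m_HasData[key] = false;
            }
            return shouldSend;
        }
    }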