using System.Collections.Generic;
using System.Linq;
using UnityEngine;
using Google.Protobuf;
namespace MLAgents
{
/// <summary>
/// The batcher is an RL-specific class that makes sure that the information each object in
/// Unity (Academy and Brains) wants to send to External is appropriately batched together
/// and sent only when necessary.
///
/// The Batcher will only send a message to the Communicator when either:
/// 1 - The academy is done
/// 2 - At least one brain has data to send
///
/// At each step, the batcher keeps track of the brains that queried the batcher for that
/// step. The batcher can only send the batched data once all the Brains have queried the
/// Batcher.
/// </summary>
public class Batcher
{
/// The default number of agents in the scene; used as the initial capacity of each
/// brain's agent list
private const int NumAgents = 32;

/// Keeps track of which brains have data to send on the current step (keyed by brain key)
readonly Dictionary<string, bool> m_hasData =
    new Dictionary<string, bool>();

/// Keeps track of which brains queried the batcher on the current step (keyed by brain key)
readonly Dictionary<string, bool> m_hasQueried =
    new Dictionary<string, bool>();

/// Keeps track of the agents of each brain on the current step (keyed by brain key)
readonly Dictionary<string, List<Agent>> m_currentAgents =
    new Dictionary<string, List<Agent>>();

/// The Communicator of the batcher, sends a message at most once per step
readonly Communicator m_communicator;

/// The current UnityRLOutput to be sent when all the brains queried the batcher
readonly CommunicatorObjects.UnityRLOutput m_currentUnityRLOutput =
    new CommunicatorObjects.UnityRLOutput();

/// Keeps track of the done flag of the Academy
bool m_academyDone;

/// Keeps track of last CommandProto sent by External
CommunicatorObjects.CommandProto m_command;

/// Keeps track of last EnvironmentParametersProto sent by External
CommunicatorObjects.EnvironmentParametersProto m_environmentParameters;

/// Keeps track of last training mode sent by External
bool m_isTraining;

/// Keeps track of the number of messages received
private ulong m_messagesReceived;
/// <summary>
/// Creates a Batcher that exchanges its messages through the given communicator.
/// </summary>
/// <param name="communicator">The communicator stored and used by the batcher.
/// SendBrainInfo returns immediately when it is null.</param>
public Batcher(Communicator communicator)
{
    m_communicator = communicator;
}
/// <summary>
/// Sends the academy parameters through the Communicator and records the first command,
/// environment parameters and training flag received in return.
/// Is used by the academy to send the AcademyParameters to the communicator.
/// </summary>
/// <returns>The External Initialization Parameters received.</returns>
/// <param name="academyParameters">The Unity Initialization Parameters to be sent.</param>
/// <exception cref="UnityAgentsException">Thrown when the Communicator fails to
/// initialize.</exception>
public CommunicatorObjects.UnityRLInitializationInput SendAcademyParameters(
    CommunicatorObjects.UnityRLInitializationOutput academyParameters)
{
    CommunicatorObjects.UnityInput input;
    // No placeholder instance is needed: the catch block always throws, so this
    // variable is definitely assigned whenever execution continues past the try.
    CommunicatorObjects.UnityInput initializationInput;
    try
    {
        initializationInput = m_communicator.Initialize(
            new CommunicatorObjects.UnityOutput
            {
                RlInitializationOutput = academyParameters
            },
            out input);
    }
    catch
    {
        throw new UnityAgentsException(
            "The Communicator was unable to connect. Please make sure the External " +
            "process is ready to accept communication with Unity.");
    }

    CommunicatorObjects.UnityRLInput firstRlInput = null;
    if (input != null)
    {
        firstRlInput = input.RlInput;
    }

    if (firstRlInput == null)
    {
        // Same convention as SendBatchedMessageHelper: a missing input is treated
        // as a request to quit instead of failing with a NullReferenceException.
        m_command = CommunicatorObjects.CommandProto.Quit;
    }
    else
    {
        m_command = firstRlInput.Command;
        m_environmentParameters = firstRlInput.EnvironmentParameters;
        m_isTraining = firstRlInput.IsTraining;
    }

    return initializationInput.RlInitializationInput;
}
/// <summary>
/// Records the academy's done flag so that it is attached to the next output handed
/// to the communicator.
/// </summary>
/// <param name="done">If set to <c>true</c>, the academy done state will be sent to
/// External at the next Exchange.</param>
public void RegisterAcademyDoneFlag(bool done)
{
    this.m_academyDone = done;
}
/// <summary>
/// Returns the last command received from External. The academy uses it to pick up
/// reset or quit signals.
/// </summary>
/// <returns>The current command.</returns>
public CommunicatorObjects.CommandProto GetCommand()
{
    return this.m_command;
}
/// <summary>
/// Returns how many messages have been received so far. The counter grows by one for
/// every exchange with the communicator, so a changed value signals a new message.
/// </summary>
/// <returns>The number of messages received since the start of the simulation.</returns>
public ulong GetNumberMessageReceived()
{
    return this.m_messagesReceived;
}
/// <summary>
/// Returns the last environment parameters received from External. The academy uses
/// them to update the environment parameters.
/// </summary>
/// <returns>The environment parameters.</returns>
public CommunicatorObjects.EnvironmentParametersProto GetEnvironmentParameters()
{
    return this.m_environmentParameters;
}
/// <summary>
/// Returns the last training mode flag that External sent.
/// </summary>
/// <returns><c>true</c> if training mode is requested, <c>false</c> otherwise.</returns>
public bool GetIsTraining()
{
    return this.m_isTraining;
}
/// <summary>
/// Adds the brain to the list of brains which will be sending information to External.
/// Subscribing a key that is already known resets that brain's state.
/// </summary>
/// <param name="brainKey">Brain key.</param>
public void SubscribeBrain(string brainKey)
{
    m_hasQueried[brainKey] = false;
    m_hasData[brainKey] = false;
    m_currentAgents[brainKey] = new List<Agent>(NumAgents);

    // Indexer assignment rather than Add: Add throws when the key is already present,
    // which would leave the three dictionaries above reset while this call fails.
    m_currentUnityRLOutput.AgentInfos[brainKey] =
        new CommunicatorObjects.UnityRLOutput.Types.ListAgentInfoProto();
}
/// <summary>
/// Converts an AgentInfo to a protobuf generated AgentInfoProto.
/// Memories, action masks and visual observations are optional: each is copied only
/// when it is not null. Every visual observation is encoded as PNG bytes.
/// </summary>
/// <returns>The protobuf version of the AgentInfo.</returns>
/// <param name="info">The AgentInfo to convert.</param>
public static CommunicatorObjects.AgentInfoProto
    AgentInfoConvertor(AgentInfo info)
{
    var agentInfoProto = new CommunicatorObjects.AgentInfoProto
    {
        StackedVectorObservation = { info.stackedVectorObservation },
        StoredVectorActions = { info.storedVectorActions },
        StoredTextActions = info.storedTextActions,
        TextObservation = info.textObservation,
        Reward = info.reward,
        MaxStepReached = info.maxStepReached,
        Done = info.done,
        Id = info.id,
    };

    if (info.memories != null)
    {
        agentInfoProto.Memories.Add(info.memories);
    }

    if (info.actionMasks != null)
    {
        agentInfoProto.ActionMask.AddRange(info.actionMasks);
    }

    // Guarded like the two optional collections above, so an AgentInfo without visual
    // observations does not fail with a NullReferenceException.
    if (info.visualObservations != null)
    {
        foreach (Texture2D obs in info.visualObservations)
        {
            agentInfoProto.VisualObservations.Add(
                ByteString.CopyFrom(obs.EncodeToPNG()));
        }
    }

    return agentInfoProto;
}
/// <summary>
/// Builds the protobuf BrainParametersProto describing a brain so that it can be sent.
/// </summary>
/// <returns>The BrainParametersProto generated.</returns>
/// <param name="brainParameters">The BrainParameters to copy from.</param>
/// <param name="name">The name of the brain.</param>
/// <param name="type">The type of brain.</param>
public static CommunicatorObjects.BrainParametersProto BrainParametersConvertor(
    BrainParameters brainParameters, string name, CommunicatorObjects.BrainTypeProto type)
{
    var proto = new CommunicatorObjects.BrainParametersProto();

    // Scalar settings and the action size, in the same order as they are declared
    proto.VectorObservationSize = brainParameters.vectorObservationSize;
    proto.NumStackedVectorObservations = brainParameters.numStackedVectorObservations;
    proto.VectorActionSize.Add(brainParameters.vectorActionSize);
    proto.VectorActionSpaceType =
        (CommunicatorObjects.SpaceTypeProto)brainParameters.vectorActionSpaceType;
    proto.BrainName = name;
    proto.BrainType = type;

    proto.VectorActionDescriptions.AddRange(brainParameters.vectorActionDescriptions);

    // One ResolutionProto per camera resolution
    foreach (resolution camera in brainParameters.cameraResolutions)
    {
        var resolutionProto = new CommunicatorObjects.ResolutionProto();
        resolutionProto.Width = camera.width;
        resolutionProto.Height = camera.height;
        resolutionProto.GrayScale = camera.blackAndWhite;
        proto.CameraResolutions.Add(resolutionProto);
    }

    return proto;
}
/// <summary>
/// Sends the brain info. If at least one brain has an agent in need of
/// a decision or if the academy is done, the data is sent via
/// Communicator. Else, a new step is realized. The data can only be
/// sent once all the brains that subscribed to the batcher have tried
/// to send information.
/// </summary>
/// <param name="brainKey">Brain key. Must have been registered with SubscribeBrain.</param>
/// <param name="agentInfo">The AgentInfo of every agent of this brain that has data to
/// send on the current step.</param>
public void SendBrainInfo(
    string brainKey, Dictionary<Agent, AgentInfo> agentInfo)
{
    // If no communicator is initialized, the Batcher will not transmit
    // BrainInfo
    if (m_communicator == null)
    {
        return;
    }

    // Resolve the brain's agent list before touching any state, so that an unknown
    // key fails without leaving a stale entry behind in m_hasQueried.
    var agents = m_currentAgents[brainKey];

    // The brain called SendBrainInfo for this step
    m_hasQueried[brainKey] = true;

    // Rebuild the agent list and append one AgentInfoProto per agent in a single
    // pass; both end up in the dictionary's enumeration order.
    agents.Clear();
    if (agentInfo.Count > 0)
    {
        var agentInfoProtos = m_currentUnityRLOutput.AgentInfos[brainKey].Value;
        foreach (var pair in agentInfo)
        {
            agents.Add(pair.Key);
            agentInfoProtos.Add(AgentInfoConvertor(pair.Value));
        }

        m_hasData[brainKey] = true;
    }

    // Nothing can be sent until every subscribed brain has queried
    if (!m_hasQueried.Values.All(x => x))
    {
        return;
    }

    // If any brain has data to send, or the academy is done, the whole message
    // must be sent
    if (m_hasData.Values.Any(x => x) || m_academyDone)
    {
        m_currentUnityRLOutput.GlobalDone = m_academyDone;
        SendBatchedMessageHelper();
    }

    // The step is complete for every brain: reset the per-step flags
    foreach (string k in m_currentAgents.Keys)
    {
        m_hasData[k] = false;
        m_hasQueried[k] = false;
    }
}
/// <summary>
/// Helper method that sends the current UnityRLOutput, receives the next UnityInput and
/// applies the appropriate AgentAction to the agents.
/// A null input, or an input without RlInput, is recorded as a Quit command.
/// </summary>
void SendBatchedMessageHelper()
{
    var input = m_communicator.Exchange(
        new CommunicatorObjects.UnityOutput
        {
            RlOutput = m_currentUnityRLOutput
        });
    m_messagesReceived += 1;

    // The batched agent infos were just sent; empty every brain's list for the next step
    foreach (string k in m_currentUnityRLOutput.AgentInfos.Keys)
    {
        m_currentUnityRLOutput.AgentInfos[k].Value.Clear();
    }

    CommunicatorObjects.UnityRLInput rlInput = null;
    if (input != null)
    {
        rlInput = input.RlInput;
    }

    if (rlInput == null)
    {
        m_command = CommunicatorObjects.CommandProto.Quit;
        return;
    }

    m_command = rlInput.Command;
    m_environmentParameters = rlInput.EnvironmentParameters;
    m_isTraining = rlInput.IsTraining;

    if (rlInput.AgentActions == null)
    {
        return;
    }

    foreach (var entry in rlInput.AgentActions)
    {
        // Ignore actions addressed to a brain that never subscribed, or that has no
        // agents on this step, instead of failing on the dictionary lookup.
        List<Agent> agents;
        if (!m_currentAgents.TryGetValue(entry.Key, out agents) || agents.Count == 0)
        {
            continue;
        }

        var actions = entry.Value.Value;
        if (actions.Count == 0)
        {
            continue;
        }

        // Actions are matched to agents by position. Stop at the shorter of the two
        // lists so a short reply cannot index past the end of the actions.
        var pairCount = agents.Count < actions.Count ? agents.Count : actions.Count;
        for (var i = 0; i < pairCount; i++)
        {
            var agent = agents[i];
            var action = actions[i];
            agent.UpdateVectorAction(action.VectorActions.ToArray());
            agent.UpdateMemoriesAction(action.Memories.ToList());
            agent.UpdateTextAction(action.TextActions);
            agent.UpdateValueAction(action.Value);
        }
    }
}
}
}