using System.Collections.Generic;
using System.Linq;
using UnityEngine;
using Google.Protobuf;

namespace MLAgents
{
    /// <summary>
    /// The batcher is an RL specific class that makes sure that the information each object in
    /// Unity (Academy and Brains) wants to send to External is appropriately batched together
    /// and sent only when necessary.
    ///
    /// The Batcher will only send a Message to the Communicator when either:
    /// 1 - The academy is done
    /// 2 - At least one brain has data to send
    ///
    /// At each step, the batcher will keep track of the brains that queried the batcher for that
    /// step. The batcher can only send the batched data when all the Brains have queried the
    /// Batcher.
    /// </summary>
    public class Batcher
    {
        /// <summary>Initial capacity of each subscribed brain's agent list.</summary>
        private const int NumAgents = 32;

        /// <summary>Keeps track of which brains have data to send on the current step.</summary>
        private readonly Dictionary<string, bool> m_hasData = new Dictionary<string, bool>();

        /// <summary>Keeps track of which brains queried the batcher on the current step.</summary>
        private readonly Dictionary<string, bool> m_hasQueried = new Dictionary<string, bool>();

        /// <summary>Keeps track of the agents of each brain on the current step.</summary>
        private readonly Dictionary<string, List<Agent>> m_currentAgents =
            new Dictionary<string, List<Agent>>();

        /// <summary>
        /// The Communicator of the batcher, sends a message at most once per step.
        /// When null, SendBrainInfo transmits nothing.
        /// </summary>
        private readonly Communicator m_communicator;

        /// <summary>The current UnityRLOutput to be sent when all the brains queried the batcher.</summary>
        private readonly CommunicatorObjects.UnityRLOutput m_currentUnityRLOutput =
            new CommunicatorObjects.UnityRLOutput();

        /// <summary>Keeps track of the done flag of the Academy.</summary>
        private bool m_academyDone;

        /// <summary>Keeps track of the last CommandProto sent by External.</summary>
        private CommunicatorObjects.CommandProto m_command;

        /// <summary>Keeps track of the last EnvironmentParametersProto sent by External.</summary>
        private CommunicatorObjects.EnvironmentParametersProto m_environmentParameters;

        /// <summary>Keeps track of the last training mode sent by External.</summary>
        private bool m_isTraining;

        /// <summary>Keeps track of the number of exchanges performed with External.</summary>
        private ulong m_messagesReceived;

        /// <summary>
        /// Initializes a new instance of the <see cref="Batcher"/> class.
        /// </summary>
        /// <param name="communicator">The communicator to be used by the batcher (may be null).</param>
        public Batcher(Communicator communicator)
        {
            m_communicator = communicator;
        }

        /// <summary>
        /// Sends the academy parameters through the Communicator.
        /// Is used by the academy to send the AcademyParameters to the communicator.
        /// Also records the command, environment parameters and training flag carried by
        /// the first RL input External sends back.
        /// </summary>
        /// <returns>The External Initialization Parameters received.</returns>
        /// <param name="academyParameters">The Unity Initialization Parameters to be sent.</param>
        /// <exception cref="UnityAgentsException">
        /// The communicator failed to initialize, or External replied without an RL input.
        /// </exception>
        public CommunicatorObjects.UnityRLInitializationInput SendAcademyParameters(
            CommunicatorObjects.UnityRLInitializationOutput academyParameters)
        {
            CommunicatorObjects.UnityInput initializationInput;
            CommunicatorObjects.UnityInput input;
            try
            {
                initializationInput = m_communicator.Initialize(
                    new CommunicatorObjects.UnityOutput
                    {
                        RlInitializationOutput = academyParameters
                    },
                    out input);
            }
            catch (System.Exception e)
            {
                // Log the underlying failure so it is not lost when it is replaced by the
                // user-facing connection error below.
                Debug.LogException(e);
                throw new UnityAgentsException(
                    "The Communicator was unable to connect. Please make sure the External " +
                    "process is ready to accept communication with Unity.");
            }

            if (input == null || input.RlInput == null)
            {
                throw new UnityAgentsException(
                    "The External process did not send an RL input during initialization.");
            }

            var firstRlInput = input.RlInput;
            m_command = firstRlInput.Command;
            m_environmentParameters = firstRlInput.EnvironmentParameters;
            m_isTraining = firstRlInput.IsTraining;
            return initializationInput.RlInitializationInput;
        }

        /// <summary>
        /// Registers the done flag of the academy to the next output to be sent
        /// to the communicator.
        /// </summary>
        /// <param name="done">
        /// If set to <c>true</c>, the academy done state forces a send and is reported to
        /// External at the next Exchange.
        /// </param>
        public void RegisterAcademyDoneFlag(bool done)
        {
            m_academyDone = done;
        }

        /// <summary>
        /// Gets the command. Is used by the academy to get reset or quit signals.
        /// </summary>
        /// <returns>The current command.</returns>
        public CommunicatorObjects.CommandProto GetCommand()
        {
            return m_command;
        }

        /// <summary>
        /// Gets the number of messages received so far. Can be used to check for new messages.
        /// </summary>
        /// <returns>The number of messages received since start of the simulation.</returns>
        public ulong GetNumberMessageReceived()
        {
            return m_messagesReceived;
        }

        /// <summary>
        /// Gets the environment parameters. Is used by the academy to update
        /// the environment parameters.
        /// </summary>
        /// <returns>The environment parameters.</returns>
        public CommunicatorObjects.EnvironmentParametersProto GetEnvironmentParameters()
        {
            return m_environmentParameters;
        }

        /// <summary>
        /// Gets the last training_mode flag External sent.
        /// </summary>
        /// <returns><c>true</c>, if training mode is requested, <c>false</c> otherwise.</returns>
        public bool GetIsTraining()
        {
            return m_isTraining;
        }

        /// <summary>
        /// Adds the brain to the list of brains which will be sending information to External.
        /// </summary>
        /// <param name="brainKey">Brain key.</param>
        public void SubscribeBrain(string brainKey)
        {
            m_hasQueried[brainKey] = false;
            m_hasData[brainKey] = false;
            m_currentAgents[brainKey] = new List<Agent>(NumAgents);
            m_currentUnityRLOutput.AgentInfos.Add(
                brainKey,
                new CommunicatorObjects.UnityRLOutput.Types.ListAgentInfoProto());
        }

        /// <summary>
        /// Converts an AgentInfo to a protobuf generated AgentInfoProto.
        /// Visual observations are PNG-encoded.
        /// </summary>
        /// <returns>The protobuf version of the AgentInfo.</returns>
        /// <param name="info">The AgentInfo to convert.</param>
        public static CommunicatorObjects.AgentInfoProto AgentInfoConvertor(AgentInfo info)
        {
            var agentInfoProto = new CommunicatorObjects.AgentInfoProto
            {
                StackedVectorObservation = { info.stackedVectorObservation },
                StoredVectorActions = { info.storedVectorActions },
                // Generated protobuf string setters reject null; send "" instead.
                StoredTextActions = info.storedTextActions ?? string.Empty,
                TextObservation = info.textObservation ?? string.Empty,
                Reward = info.reward,
                MaxStepReached = info.maxStepReached,
                Done = info.done,
                Id = info.id,
            };
            if (info.memories != null)
            {
                agentInfoProto.Memories.Add(info.memories);
            }

            if (info.actionMasks != null)
            {
                agentInfoProto.ActionMask.AddRange(info.actionMasks);
            }

            if (info.visualObservations != null)
            {
                foreach (Texture2D obs in info.visualObservations)
                {
                    agentInfoProto.VisualObservations.Add(
                        ByteString.CopyFrom(obs.EncodeToPNG()));
                }
            }

            return agentInfoProto;
        }

        /// <summary>
        /// Converts a Brain into a protobuf BrainParametersProto so it can be sent.
        /// </summary>
        /// <returns>The BrainParametersProto generated.</returns>
        /// <param name="brainParameters">The BrainParameters.</param>
        /// <param name="name">The name of the brain.</param>
        /// <param name="type">The type of brain.</param>
        public static CommunicatorObjects.BrainParametersProto BrainParametersConvertor(
            BrainParameters brainParameters, string name, CommunicatorObjects.BrainTypeProto type)
        {
            var brainParametersProto = new CommunicatorObjects.BrainParametersProto
            {
                VectorObservationSize = brainParameters.vectorObservationSize,
                NumStackedVectorObservations = brainParameters.numStackedVectorObservations,
                VectorActionSize = { brainParameters.vectorActionSize },
                VectorActionSpaceType =
                    (CommunicatorObjects.SpaceTypeProto)brainParameters.vectorActionSpaceType,
                BrainName = name,
                BrainType = type
            };
            brainParametersProto.VectorActionDescriptions.AddRange(
                brainParameters.vectorActionDescriptions);
            foreach (resolution res in brainParameters.cameraResolutions)
            {
                brainParametersProto.CameraResolutions.Add(
                    new CommunicatorObjects.ResolutionProto
                    {
                        Width = res.width,
                        Height = res.height,
                        GrayScale = res.blackAndWhite
                    });
            }

            return brainParametersProto;
        }

        /// <summary>
        /// Sends the brain info. If at least one brain has an agent in need of
        /// a decision or if the academy is done, the data is sent via
        /// Communicator. Else, a new step is realized. The data can only be
        /// sent once all the brains that subscribed to the batcher have tried
        /// to send information.
        /// </summary>
        /// <param name="brainKey">Brain key; must have been passed to SubscribeBrain.</param>
        /// <param name="agentInfo">Agent info, keyed by agent.</param>
        /// <exception cref="KeyNotFoundException">The brain was never subscribed.</exception>
        public void SendBrainInfo(
            string brainKey, Dictionary<Agent, AgentInfo> agentInfo)
        {
            // If no communicator is initialized, the Batcher will not transmit
            // BrainInfo.
            if (m_communicator == null)
            {
                return;
            }

            // Resolve the brain first so that an unsubscribed key fails before any
            // per-step bookkeeping (m_hasQueried) is modified.
            var currentAgents = m_currentAgents[brainKey];

            // This brain has now queried the batcher for the current step.
            m_hasQueried[brainKey] = true;

            // Record which agents this brain reports this step, in dictionary order.
            currentAgents.Clear();
            currentAgents.AddRange(agentInfo.Keys);

            // If at least one agent has data to send, append it to the message
            // and flag this brain as having data.
            if (currentAgents.Count > 0)
            {
                var agentInfoProtos = m_currentUnityRLOutput.AgentInfos[brainKey].Value;
                foreach (var agent in currentAgents)
                {
                    agentInfoProtos.Add(AgentInfoConvertor(agentInfo[agent]));
                }

                m_hasData[brainKey] = true;
            }

            // The batch can only go out once every subscribed brain has queried.
            if (m_hasQueried.Values.All(x => x))
            {
                // Any brain with data, or a done academy, forces the whole message out.
                if (m_hasData.Values.Any(x => x) || m_academyDone)
                {
                    m_currentUnityRLOutput.GlobalDone = m_academyDone;
                    SendBatchedMessageHelper();
                }

                // The step is complete: reset the per-brain flags for the next one.
                foreach (var k in m_currentAgents.Keys)
                {
                    m_hasData[k] = false;
                    m_hasQueried[k] = false;
                }
            }
        }

        /// <summary>
        /// Helper method that sends the current UnityRLOutput, receives the next UnityInput and
        /// applies the appropriate AgentAction to the agents.
        /// A null input (or one without an RL payload) sets the command to Quit.
        /// </summary>
        /// <exception cref="UnityAgentsException">
        /// External sent fewer actions for a brain than that brain has current agents.
        /// </exception>
        private void SendBatchedMessageHelper()
        {
            var input = m_communicator.Exchange(
                new CommunicatorObjects.UnityOutput
                {
                    RlOutput = m_currentUnityRLOutput
                });
            m_messagesReceived += 1;

            // The batched agent infos have been sent; empty them for the next step.
            foreach (var k in m_currentUnityRLOutput.AgentInfos.Keys)
            {
                m_currentUnityRLOutput.AgentInfos[k].Value.Clear();
            }

            if (input == null)
            {
                m_command = CommunicatorObjects.CommandProto.Quit;
                return;
            }

            var rlInput = input.RlInput;
            if (rlInput == null)
            {
                m_command = CommunicatorObjects.CommandProto.Quit;
                return;
            }

            m_command = rlInput.Command;
            m_environmentParameters = rlInput.EnvironmentParameters;
            m_isTraining = rlInput.IsTraining;

            if (rlInput.AgentActions == null)
            {
                return;
            }

            foreach (var brainName in rlInput.AgentActions.Keys)
            {
                // Ignore brains this batcher does not track, and brains with no agents this step.
                List<Agent> agents;
                if (!m_currentAgents.TryGetValue(brainName, out agents) || agents.Count == 0)
                {
                    continue;
                }

                var actions = rlInput.AgentActions[brainName].Value;
                if (actions.Count == 0)
                {
                    continue;
                }

                // Actions are matched to agents by position, so there must be one per agent.
                // Validate before applying any, rather than failing halfway through.
                if (actions.Count < agents.Count)
                {
                    throw new UnityAgentsException(
                        "External sent " + actions.Count + " actions for brain " + brainName +
                        " but " + agents.Count + " agents requested a decision.");
                }

                for (var i = 0; i < agents.Count; i++)
                {
                    var agent = agents[i];
                    var action = actions[i];
                    agent.UpdateVectorAction(action.VectorActions.ToArray());
                    agent.UpdateMemoriesAction(action.Memories.ToList());
                    agent.UpdateTextAction(action.TextActions);
                    agent.UpdateValueAction(action.Value);
                }
            }
        }
    }
}