using System;
using System.Collections.Generic;
using System.Linq;
using UnityEngine;

namespace MLAgents
{
    /// <summary>
    /// The batcher is an RL specific class that makes sure that the information each object in
    /// Unity (Academy and Brains) wants to send to External is appropriately batched together
    /// and sent only when necessary.
    ///
    /// Brains register with <see cref="SubscribeBrain"/> and report their agents once per step
    /// through <see cref="SendBrainInfo"/>. The batched message is handed to the Communicator
    /// only once every subscribed brain has reported for the current step, and only if at
    /// least one of them reported agents. The per-step bookkeeping is then reset.
    /// </summary>
    public class Batcher
    {
        /// <summary>Initial capacity of each brain's agent list (not an upper bound).</summary>
        private const int k_NumAgents = 32;

        /// <summary>Keeps track of which brains have data to send on the current step.</summary>
        private readonly Dictionary<string, bool> m_HasData = new Dictionary<string, bool>();

        /// <summary>Keeps track of which brains queried the batcher on the current step.</summary>
        private readonly Dictionary<string, bool> m_HasQueried = new Dictionary<string, bool>();

        /// <summary>
        /// Keeps track of the agents of each brain on the current step, in the same order as
        /// their infos were appended to the outgoing message.
        /// </summary>
        private readonly Dictionary<string, List<Agent>> m_CurrentAgents =
            new Dictionary<string, List<Agent>>();

        /// <summary>The Communicator of the batcher, sends a message at most once per step.</summary>
        private readonly ICommunicator m_Communicator;

        /// <summary>The current UnityRLOutput to be sent when all the brains queried the batcher.</summary>
        private readonly CommunicatorObjects.UnityRLOutputProto m_CurrentUnityRlOutput =
            new CommunicatorObjects.UnityRLOutputProto();

        /// <summary>Keeps track of last CommandProto sent by External.</summary>
        private CommunicatorObjects.CommandProto m_Command;

        /// <summary>Keeps track of last EnvironmentParametersProto sent by External.</summary>
        private CommunicatorObjects.EnvironmentParametersProto m_EnvironmentParameters;

        /// <summary>Keeps track of last training mode sent by External.</summary>
        private bool m_IsTraining;

        /// <summary>Keeps track of the number of messages received.</summary>
        private ulong m_MessagesReceived;

        /// <summary>
        /// Initializes a new instance of the Batcher class.
        /// </summary>
        /// <param name="communicator">
        /// The communicator to be used by the batcher. When it is null,
        /// <see cref="SendBrainInfo"/> returns without transmitting anything.
        /// </param>
        public Batcher(ICommunicator communicator)
        {
            m_Communicator = communicator;
        }

        /// <summary>
        /// Sends the academy parameters through the Communicator.
        /// Is used by the academy to send the AcademyParameters to the communicator.
        /// The command, environment parameters and training flag carried by the first RL input
        /// returned by External are cached and exposed by the corresponding getters.
        /// </summary>
        /// <param name="academyParameters">The Unity Initialization Parameters to be sent.</param>
        /// <returns>The External Initialization Parameters received.</returns>
        /// <exception cref="UnityAgentsException">
        /// Thrown when the communicator fails to initialize the connection.
        /// </exception>
        public CommunicatorObjects.UnityRLInitializationInputProto SendAcademyParameters(
            CommunicatorObjects.UnityRLInitializationOutputProto academyParameters)
        {
            CommunicatorObjects.UnityInputProto input;
            CommunicatorObjects.UnityInputProto initializationInput;
            try
            {
                initializationInput = m_Communicator.Initialize(
                    new CommunicatorObjects.UnityOutputProto
                    {
                        RlInitializationOutput = academyParameters
                    },
                    out input);
            }
            catch
            {
                // NOTE(review): the underlying exception is discarded here; consider surfacing
                // it so the root cause of a connection failure is not lost.
                var exceptionMessage = "The Communicator was unable to connect. Please make sure the External " +
                    "process is ready to accept communication with Unity.";

                // Check for common error condition and add details to the exception message.
                var httpProxy = Environment.GetEnvironmentVariable("HTTP_PROXY");
                var httpsProxy = Environment.GetEnvironmentVariable("HTTPS_PROXY");
                if (httpProxy != null || httpsProxy != null)
                {
                    // The trailing space after "the" keeps the two literals from fusing
                    // into "theenvironment".
                    exceptionMessage += " Try removing HTTP_PROXY and HTTPS_PROXY from the " +
                        "environment variables and try again.";
                }

                throw new UnityAgentsException(exceptionMessage);
            }

            // Both locals are definitely assigned here: the catch block always throws.
            var firstRlInput = input.RlInput;
            m_Command = firstRlInput.Command;
            m_EnvironmentParameters = firstRlInput.EnvironmentParameters;
            m_IsTraining = firstRlInput.IsTraining;
            return initializationInput.RlInitializationInput;
        }

        /// <summary>
        /// Gets the command. Is used by the academy to get reset or quit signals.
        /// </summary>
        /// <returns>The current command.</returns>
        public CommunicatorObjects.CommandProto GetCommand()
        {
            return m_Command;
        }

        /// <summary>
        /// Gets the number of messages received so far. Can be used to check for new messages.
        /// </summary>
        /// <returns>The number of messages received since start of the simulation.</returns>
        public ulong GetNumberMessageReceived()
        {
            return m_MessagesReceived;
        }

        /// <summary>
        /// Gets the environment parameters. Is used by the academy to update
        /// the environment parameters.
        /// </summary>
        /// <returns>The environment parameters.</returns>
        public CommunicatorObjects.EnvironmentParametersProto GetEnvironmentParameters()
        {
            return m_EnvironmentParameters;
        }

        /// <summary>
        /// Gets the last training_mode flag External sent.
        /// </summary>
        /// <returns><c>true</c>, if training mode is requested, <c>false</c> otherwise.</returns>
        public bool GetIsTraining()
        {
            return m_IsTraining;
        }

        /// <summary>
        /// Adds the brain to the list of brains which will be sending information to External.
        /// </summary>
        /// <param name="brainKey">Brain key.</param>
        public void SubscribeBrain(string brainKey)
        {
            m_HasQueried[brainKey] = false;
            m_HasData[brainKey] = false;
            m_CurrentAgents[brainKey] = new List<Agent>(k_NumAgents);
            m_CurrentUnityRlOutput.AgentInfos.Add(
                brainKey,
                new CommunicatorObjects.UnityRLOutputProto.Types.ListAgentInfoProto());
        }

        /// <summary>
        /// Sends the brain info. The infos of the brain's agents are appended to the outgoing
        /// message. Once every brain that subscribed to the batcher has called this method for
        /// the current step, the message is sent via the Communicator if at least one brain
        /// reported agents, and a new step begins.
        /// </summary>
        /// <param name="brainKey">Brain key.</param>
        /// <param name="agentInfo">The infos reported by this brain for the current step, keyed by agent.</param>
        /// <exception cref="KeyNotFoundException">
        /// Thrown when a communicator is present and <paramref name="brainKey"/> was never passed
        /// to <see cref="SubscribeBrain"/>.
        /// </exception>
        public void SendBrainInfo(
            string brainKey, Dictionary<Agent, AgentInfo> agentInfo)
        {
            // If no communicator is initialized, the Batcher will not transmit BrainInfo.
            if (m_Communicator == null)
            {
                return;
            }

            // Look the brain up before touching any per-step state so an unsubscribed key
            // fails without leaving a stray "queried" flag behind.
            var currentAgents = m_CurrentAgents[brainKey];

            // The brain called SendBrainInfo for this step; record that it has queried.
            m_HasQueried[brainKey] = true;

            // Remember which agents this brain reported. SendBatchedMessageHelper pairs the
            // returned actions with agents by position in this list.
            currentAgents.Clear();
            currentAgents.AddRange(agentInfo.Keys);

            // If at least one agent has data to send, append it to the outgoing message and
            // flag this brain as having data.
            if (currentAgents.Count > 0)
            {
                var brainInfos = m_CurrentUnityRlOutput.AgentInfos[brainKey].Value;
                foreach (var agent in currentAgents)
                {
                    var info = agentInfo[agent];
                    brainInfos.Add(info.ToProto());
                    // Avoid visual obs memory leak. This should be called AFTER we are done with the visual obs.
                    // e.g. after recording them to demo and using them for inference.
                    info.ClearVisualObs();
                }

                m_HasData[brainKey] = true;
            }

            // The batched message can only go out once every subscribed brain has queried for
            // this step, and is only sent if at least one brain had data.
            if (m_HasQueried.Values.All(x => x))
            {
                if (m_HasData.Values.Any(x => x))
                {
                    SendBatchedMessageHelper();
                }

                // Start a new step: reset the per-brain data and query flags. Enumerating
                // m_CurrentAgents (not the dictionaries being written) keeps the enumerator valid.
                foreach (var k in m_CurrentAgents.Keys)
                {
                    m_HasData[k] = false;
                    m_HasQueried[k] = false;
                }
            }
        }

        /// <summary>
        /// Helper method that sends the current UnityRLOutput, receives the next UnityInput and
        /// applies the appropriate AgentAction to the agents. A null input or RL payload is
        /// recorded as a Quit command.
        /// </summary>
        void SendBatchedMessageHelper()
        {
            var input = m_Communicator.Exchange(
                new CommunicatorObjects.UnityOutputProto
                {
                    RlOutput = m_CurrentUnityRlOutput
                });
            m_MessagesReceived += 1;

            // The outgoing infos have been exchanged; empty them for the next step.
            foreach (var agentInfos in m_CurrentUnityRlOutput.AgentInfos.Values)
            {
                agentInfos.Value.Clear();
            }

            if (input == null || input.RlInput == null)
            {
                m_Command = CommunicatorObjects.CommandProto.Quit;
                return;
            }

            var rlInput = input.RlInput;
            m_Command = rlInput.Command;
            m_EnvironmentParameters = rlInput.EnvironmentParameters;
            m_IsTraining = rlInput.IsTraining;

            if (rlInput.AgentActions == null)
            {
                return;
            }

            foreach (var brainName in rlInput.AgentActions.Keys)
            {
                // Skip brains this batcher does not know about, and brains with no agents.
                List<Agent> agents;
                if (!m_CurrentAgents.TryGetValue(brainName, out agents) || agents.Count == 0)
                {
                    continue;
                }

                var actions = rlInput.AgentActions[brainName].Value;
                if (actions.Count == 0)
                {
                    continue;
                }

                // Actions are matched to agents by position: the i-th action goes to the i-th
                // agent reported through SendBrainInfo for this brain.
                // NOTE(review): assumes External returns at least one action per agent; a
                // shorter action list makes the indexer below throw.
                for (var i = 0; i < agents.Count; i++)
                {
                    var agent = agents[i];
                    var action = actions[i];
                    agent.UpdateVectorAction(action.VectorActions.ToArray());
                    agent.UpdateMemoriesAction(action.Memories.ToList());
                    agent.UpdateTextAction(action.TextActions);
                    agent.UpdateValueAction(action.Value);
                    agent.UpdateCustomAction(action.CustomAction);
                }
            }
        }
    }
}