using UnityEngine;
using System.Collections.Generic;

namespace MLAgents
{
    /// <summary>
    /// The Remote Policy only works when training.
    /// When training your Agents, the RemotePolicy will be controlled by Python.
    /// </summary>
    public class RemotePolicy : IPolicy
    {
        // Assigned once in the constructor; used as the key for every communicator call.
        readonly string m_BehaviorName;
        protected ICommunicator m_Communicator;

        /// <summary>
        /// Sensor shapes for the associated Agents. All Agents must have the same shapes for their Sensors.
        /// Populated from the first Agent passed to <see cref="ValidateAgentSensorShapes"/>, which is only
        /// invoked in DEBUG builds.
        /// </summary>
        List<int[]> m_SensorShapes;

        /// <summary>
        /// Creates a policy whose decisions are supplied through the Academy's communicator.
        /// Locates the Academy in the scene, makes sure it is initialized, caches its communicator,
        /// and subscribes <paramref name="brainParameters"/> under <paramref name="behaviorName"/>.
        /// </summary>
        /// <param name="brainParameters">Brain parameters registered with the communicator for this behavior.</param>
        /// <param name="behaviorName">Name used when subscribing the brain and when sending observations.</param>
        public RemotePolicy(
            BrainParameters brainParameters,
            string behaviorName)
        {
            m_BehaviorName = behaviorName;
            // NOTE(review): FindObjectOfType returns null when the scene has no Academy, which surfaces as a
            // NullReferenceException on the next line.
            var aca = Object.FindObjectOfType<Academy>();
            aca.LazyInitialization();
            m_Communicator = aca.Communicator;
            m_Communicator.SubscribeBrain(m_BehaviorName, brainParameters);
        }

        /// <inheritdoc />
        public void RequestDecision(Agent agent)
        {
#if DEBUG
            ValidateAgentSensorShapes(agent);
#endif
            m_Communicator?.PutObservations(m_BehaviorName, agent);
        }

        /// <inheritdoc />
        public void DecideAction()
        {
            m_Communicator?.DecideBatch();
        }

        /// <summary>
        /// Check that the Agent Sensors are the same shape as the other Agents using the same Brain.
        /// If this is the first Agent being checked, its Sensor sizes will be saved.
        /// </summary>
        /// <param name="agent">The Agent to check</param>
        void ValidateAgentSensorShapes(Agent agent)
        {
            if (m_SensorShapes == null)
            {
                // First agent, save the sensor sizes
                m_SensorShapes = new List<int[]>(agent.sensors.Count);
                foreach (var sensor in agent.sensors)
                {
                    m_SensorShapes.Add(sensor.GetFloatObservationShape());
                }
                return;
            }

            // Check for compatibility with the other Agents' Sensors
            // TODO make sure this only checks once per agent
            Debug.Assert(m_SensorShapes.Count == agent.sensors.Count,
                $"Number of Sensors must match. {m_SensorShapes.Count} != {agent.sensors.Count}");

            // UnityEngine.Debug.Assert logs a failure rather than throwing, so execution continues after a
            // failed check. Bound each index by BOTH collections so a mismatch is reported instead of being
            // followed by an index-out-of-range exception.
            for (var i = 0; i < m_SensorShapes.Count && i < agent.sensors.Count; i++)
            {
                var cachedShape = m_SensorShapes[i];
                var sensorShape = agent.sensors[i].GetFloatObservationShape();
                Debug.Assert(cachedShape.Length == sensorShape.Length, "Sensor dimensions must match.");
                for (var j = 0; j < cachedShape.Length && j < sensorShape.Length; j++)
                {
                    Debug.Assert(cachedShape[j] == sensorShape[j], "Sensor sizes must match.");
                }
            }
        }

        /// <summary>
        /// Nothing to release here; the communicator was obtained from the Academy and is not disposed by
        /// this policy.
        /// </summary>
        public void Dispose()
        {
        }
    }
}