using System.Collections.Generic;
using Unity.Barracuda;
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Inference
{
    /// <summary>
    /// Mapping between Tensor names and generators.
    /// A TensorGenerator holds a Dictionary from strings (tensor/node names) to an
    /// <see cref="IGenerator"/>. Each generator receives the tensor, the current batch size
    /// and the list of agent infos and sensors for the current batch, and reshapes and fills
    /// the data of the tensor based on the data of the batch.
    /// When the TensorProxy is an Input to the model, the shape of the Tensor will be modified
    /// depending on the current batch size and the data of the Tensor will be filled using the
    /// agent infos and sensors of the batch.
    /// When the TensorProxy is an Output of the model, only the shape of the Tensor will be
    /// modified using the current batch size. The data will be pre-filled with zeros.
    /// </summary>
    internal class TensorGenerator
    {
        /// <summary>
        /// Fills (and, where needed, reshapes) a single named tensor for the current batch.
        /// </summary>
        public interface IGenerator
        {
            /// <summary>
            /// Modifies the data inside a Tensor according to the information contained in the
            /// AgentInfos contained in the current batch.
            /// </summary>
            /// <param name="tensorProxy">The tensor whose data and shape will be modified.</param>
            /// <param name="batchSize">The number of agents present in the current batch.</param>
            /// <param name="infos">
            /// List of AgentInfos and sensors containing the information that will be used to
            /// populate the tensor's data.
            /// </param>
            void Generate(
                TensorProxy tensorProxy, int batchSize, IList<AgentInfoSensorsPair> infos);
        }

        // Tensor name -> generator responsible for filling that tensor.
        readonly Dictionary<string, IGenerator> m_Dict = new Dictionary<string, IGenerator>();

        /// <summary>
        /// Creates a TensorGenerator and registers the generators for the model's
        /// non-observation inputs and for its outputs. Observation inputs are registered
        /// separately by <see cref="InitializeObservations"/>.
        /// </summary>
        /// <param name="seed">The seed the random generators will be initialized with.</param>
        /// <param name="allocator">Tensor allocator shared by all generators.</param>
        /// <param name="memories">
        /// Dictionary of AgentInfo.id to memory for use in the inference model.
        /// </param>
        /// <param name="barracudaModel">
        /// The Barracuda <see cref="Model"/>, passed as an object. When null, no generators are
        /// registered. Any other value is cast to <see cref="Model"/>.
        /// </param>
        public TensorGenerator(
            int seed,
            ITensorAllocator allocator,
            Dictionary<int, List<float>> memories,
            object barracudaModel = null)
        {
            // If model is null, no inference to run and exception is thrown before reaching here.
            if (barracudaModel == null)
            {
                return;
            }
            var model = (Model)barracudaModel;

            // Generators for Inputs
            m_Dict[TensorNames.BatchSizePlaceholder] =
                new BatchSizeGenerator(allocator);
            m_Dict[TensorNames.SequenceLengthPlaceholder] =
                new SequenceLengthGenerator(allocator);
            m_Dict[TensorNames.RecurrentInPlaceholder] =
                new RecurrentInputGenerator(allocator, memories);

            // One generator per Barracuda memory input, each given its memory index.
            for (var i = 0; i < model.memories.Count; i++)
            {
                m_Dict[model.memories[i].input] =
                    new BarracudaRecurrentInputGenerator(i, allocator, memories);
            }

            m_Dict[TensorNames.PreviousActionPlaceholder] =
                new PreviousActionInputGenerator(allocator);
            m_Dict[TensorNames.ActionMaskPlaceholder] =
                new ActionMaskInputGenerator(allocator);
            m_Dict[TensorNames.RandomNormalEpsilonPlaceholder] =
                new RandomNormalInputGenerator(seed, allocator);

            // Generators for Outputs
            if (model.HasContinuousOutputs())
            {
                m_Dict[model.ContinuousOutputName()] = new BiDimensionalOutputGenerator(allocator);
            }
            if (model.HasDiscreteOutputs())
            {
                m_Dict[model.DiscreteOutputName()] = new BiDimensionalOutputGenerator(allocator);
            }
            m_Dict[TensorNames.RecurrentOutput] = new BiDimensionalOutputGenerator(allocator);
            m_Dict[TensorNames.ValueEstimateOutput] = new BiDimensionalOutputGenerator(allocator);
        }

        /// <summary>
        /// Registers an observation generator for every sensor, keyed by the input tensor name
        /// the sensor feeds:
        /// rank-1 sensors share one generator under the vector observation placeholder,
        /// rank-2 sensors are named by their sensor index, and
        /// rank-3 sensors are named by a visual index that counts only rank-3 sensors.
        /// </summary>
        /// <param name="sensors">The sensors of a representative agent.</param>
        /// <param name="allocator">Tensor allocator for the created generators.</param>
        /// <exception cref="UnityAgentsException">A sensor has a rank other than 1, 2 or 3.</exception>
        public void InitializeObservations(List<ISensor> sensors, ITensorAllocator allocator)
        {
            // Loop through the sensors on a representative agent.
            // All vector observations use a shared ObservationGenerator since they are concatenated.
            // All other observations use a unique ObservationGenerator.
            var visIndex = 0;
            ObservationGenerator vecObsGen = null;
            for (var sensorIndex = 0; sensorIndex < sensors.Count; sensorIndex++)
            {
                var sensor = sensors[sensorIndex];
                var shape = sensor.GetObservationShape();
                var rank = shape.Length;
                ObservationGenerator obsGen = null;
                string obsGenName = null;
                switch (rank)
                {
                    case 1:
                        if (vecObsGen == null)
                        {
                            vecObsGen = new ObservationGenerator(allocator);
                        }
                        obsGen = vecObsGen;
                        obsGenName = TensorNames.VectorObservationPlaceholder;
                        break;
                    case 2:
                        // If the tensor is of rank 2, we use the index of the sensor
                        // to create the name
                        obsGen = new ObservationGenerator(allocator);
                        obsGenName = TensorNames.GetObservationName(sensorIndex);
                        break;
                    case 3:
                        // If the tensor is of rank 3, we use the "visual observation
                        // index", which only counts the rank 3 sensors
                        obsGen = new ObservationGenerator(allocator);
                        obsGenName = TensorNames.GetVisualObservationName(visIndex);
                        visIndex++;
                        break;
                    default:
                        throw new UnityAgentsException(
                            $"Sensor {sensor.GetName()} have an invalid rank {rank}");
                }
                obsGen.AddSensorIndex(sensorIndex);
                m_Dict[obsGenName] = obsGen;
            }
        }

        /// <summary>
        /// Populates the data of the tensor inputs given the data contained in the current batch
        /// of agents.
        /// </summary>
        /// <param name="tensors">Tensors that will be modified.</param>
        /// <param name="currentBatchSize">The number of agents present in the current batch.</param>
        /// <param name="infos">
        /// List of AgentInfos and Sensors that contains the data that will be used to modify
        /// the tensors.
        /// </param>
        /// <exception cref="UnityAgentsException">
        /// One of the tensors does not have an associated generator.
        /// </exception>
        public void GenerateTensors(
            IReadOnlyList<TensorProxy> tensors, int currentBatchSize, IList<AgentInfoSensorsPair> infos)
        {
            for (var tensorIndex = 0; tensorIndex < tensors.Count; tensorIndex++)
            {
                var tensor = tensors[tensorIndex];
                // Single dictionary lookup instead of ContainsKey followed by the indexer.
                if (!m_Dict.TryGetValue(tensor.name, out var generator))
                {
                    throw new UnityAgentsException(
                        $"Unknown tensorProxy expected as input : {tensor.name}");
                }
                generator.Generate(tensor, currentBatchSize, infos);
            }
        }
    }
}