using System.Collections.Generic;
using Unity.Barracuda;
using Unity.MLAgents.Actuators;

namespace Unity.MLAgents.Inference
{
    /// <summary>
    /// Mapping between the output tensor names and the method that will use the
    /// output tensors and the Agents present in the batch to update their action, memories and
    /// value estimates.
    /// A TensorApplier implements a Dictionary of strings (node names) to an Action.
    /// This action takes as input the tensor and the Dictionary of Agent to AgentInfo for
    /// the current batch.
    /// </summary>
    internal class TensorApplier
    {
        /// <summary>
        /// A tensor Applier's Execute method takes a tensor and a Dictionary of Agent to AgentInfo.
        /// Uses the data contained inside the tensor to modify the state of the Agent. The Tensors
        /// are assumed to have the batch size on the first dimension and the agents to be ordered
        /// the same way in the dictionary and in the tensor.
        /// </summary>
        public interface IApplier
        {
            /// <summary>
            /// Applies the values in the Tensor to the Agents present in the agentInfos.
            /// </summary>
            /// <param name="tensorProxy">
            /// The Tensor containing the data to be applied to the Agents.
            /// </param>
            /// <param name="actionIds">List of Agent Ids that will be updated using the tensor's data.</param>
            /// <param name="lastActions">Dictionary of AgentId to Actions to be updated.</param>
            void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions);
        }

        readonly Dictionary<string, IApplier> m_Dict = new Dictionary<string, IApplier>();

        /// <summary>
        /// Returns a new TensorApplier object.
        /// </summary>
        /// <param name="actionSpec">Description of the actions for the Agent.</param>
        /// <param name="seed">The seed the Appliers will be initialized with.</param>
        /// <param name="allocator">Tensor allocator.</param>
        /// <param name="memories">Dictionary of AgentInfo.id to memory used to pass to the inference model.</param>
        /// <param name="barracudaModel">The Barracuda model whose output tensor names determine which appliers are registered.</param>
        public TensorApplier(
            ActionSpec actionSpec,
            int seed,
            ITensorAllocator allocator,
            Dictionary<int, List<float>> memories,
            object barracudaModel = null)
        {
            // If the model is null, there is no inference to run and an exception is thrown before reaching here.
            if (barracudaModel == null)
            {
                return;
            }
            var model = (Model)barracudaModel;

            if (!model.SupportsContinuousAndDiscrete())
            {
                actionSpec.CheckAllContinuousOrDiscrete();
            }
            if (actionSpec.NumContinuousActions > 0)
            {
                var tensorName = model.ContinuousOutputName();
                m_Dict[tensorName] = new ContinuousActionOutputApplier(actionSpec);
            }
            var modelVersion = model.GetVersion();
            if (actionSpec.NumDiscreteActions > 0)
            {
                var tensorName = model.DiscreteOutputName();
                if (modelVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents1_0)
                {
                    m_Dict[tensorName] = new LegacyDiscreteActionOutputApplier(actionSpec, seed, allocator);
                }
                if (modelVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents2_0)
                {
                    m_Dict[tensorName] = new DiscreteActionOutputApplier(actionSpec, seed, allocator);
                }
            }
            m_Dict[TensorNames.RecurrentOutput] = new MemoryOutputApplier(memories);

            if (modelVersion < (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents2_0)
            {
                for (var i = 0; i < model?.memories.Count; i++)
                {
                    m_Dict[model.memories[i].output] =
                        new BarracudaMemoryOutputApplier(model.memories.Count, i, memories);
                }
            }
        }
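
        // Illustrative sketch (not produced by this file): for a hybrid-action, non-recurrent
        // model exported with the MLAgents2_0 API, the constructor above ends up with a
        // dictionary roughly like
        //
        //     m_Dict[model.ContinuousOutputName()]  -> ContinuousActionOutputApplier
        //     m_Dict[model.DiscreteOutputName()]    -> DiscreteActionOutputApplier
        //     m_Dict[TensorNames.RecurrentOutput]   -> MemoryOutputApplier
        //
        // ApplyTensors below then dispatches each inference output to its applier by tensor name.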
        /// <summary>
        /// Updates the state of the agents based on the data present in the tensor.
        /// </summary>
        /// <param name="tensors">Enumerable of tensors containing the data.</param>
        /// <param name="actionIds">List of Agent Ids that will be updated using the tensor's data.</param>
        /// <param name="lastActions">Dictionary of AgentId to Actions to be updated.</param>
        /// <exception cref="UnityAgentsException">One of the tensors does not have an
        /// associated applier.</exception>
        public void ApplyTensors(
            IReadOnlyList<TensorProxy> tensors, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
        {
            for (var tensorIndex = 0; tensorIndex < tensors.Count; tensorIndex++)
            {
                var tensor = tensors[tensorIndex];
                if (!m_Dict.ContainsKey(tensor.name))
                {
                    throw new UnityAgentsException(
                        $"Unknown TensorProxy expected as output: {tensor.name}");
                }
                m_Dict[tensor.name].Apply(tensor, actionIds, lastActions);
            }
        }
    }
}
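
// Usage sketch (assumption: the actual call site is the inference runner, which is not part of
// this file). A caller would construct one TensorApplier per model and route every inference
// output through it after each batched decision step, e.g.:
//
//     var applier = new TensorApplier(actionSpec, seed, allocator, memories, barracudaModel);
//     // ... run Barracuda inference and collect the output TensorProxy instances ...
//     applier.ApplyTensors(inferenceOutputs, orderedAgentIds, lastActionsByAgentId);
//
// Here `inferenceOutputs` (IReadOnlyList<TensorProxy>), `orderedAgentIds` (IList<int>, in the
// same order as the batch dimension of each tensor) and `lastActionsByAgentId`
// (Dictionary<int, ActionBuffers>) are hypothetical caller-side variables.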