using System.Collections.Generic;
using Unity.Barracuda;
using Unity.MLAgents.Actuators;
using System.Linq;
using Unity.MLAgents.Inference.Utils;
using UnityEngine;

namespace Unity.MLAgents.Inference
{
    /// <summary>
    /// Mapping between the output tensor names and the applier that uses those
    /// output tensors to update the actions of the Agents present in the batch.
    /// Internally this is a Dictionary of strings (node names) to
    /// <see cref="TensorApplier.IApplier"/> instances; each applier receives the tensor,
    /// the ids of the agents in the batch and the Dictionary of agent id to last actions.
    /// </summary>
    internal class TrainingForwardTensorApplier
    {
        readonly Dictionary<string, TensorApplier.IApplier> m_Dict =
            new Dictionary<string, TensorApplier.IApplier>();

        /// <summary>
        /// Returns a new TrainingForwardTensorApplier object.
        /// </summary>
        /// <param name="actionSpec">Description of the actions for the Agent. Only specs with no
        /// continuous actions and exactly one discrete branch are supported.</param>
        /// <param name="seed">The seed the Appliers will be initialized with.</param>
        /// <param name="allocator">Tensor allocator.</param>
        /// <param name="barracudaModel">The model whose outputs will be applied. When null, no
        /// appliers are registered and the action spec is not validated.</param>
        /// <exception cref="UnityAgentsException">Thrown when <paramref name="actionSpec"/> has
        /// continuous actions or does not have exactly one discrete branch.</exception>
        public TrainingForwardTensorApplier(
            ActionSpec actionSpec,
            int seed,
            ITensorAllocator allocator,
            object barracudaModel = null)
        {
            // A null model means there is no inference to run, so nothing is registered.
            if (barracudaModel == null)
            {
                return;
            }

            if (actionSpec.NumContinuousActions > 0)
            {
                throw new UnityAgentsException("Cannot do continuous actions");
            }

            if (actionSpec.NumDiscreteActions != 1)
            {
                throw new UnityAgentsException("Cannot do multi discrete actions, only single discrete");
            }

            m_Dict[TensorNames.TrainingOutput] = new MaxActionOutputApplier(actionSpec, seed, allocator);
        }

        /// <summary>
        /// Updates the actions of the agents based on the data present in the tensors.
        /// </summary>
        /// <param name="tensors">Output tensors containing the data; each tensor's name selects
        /// the applier used for it.</param>
        /// <param name="actionIds">List of agent ids that will be updated using the tensors' data.</param>
        /// <param name="lastActions">Dictionary of agent id to actions to be updated.</param>
        /// <exception cref="UnityAgentsException">One of the tensors does not have an
        /// associated applier.</exception>
        public void ApplyTensors(
            IReadOnlyList<TensorProxy> tensors,
            IList<int> actionIds,
            Dictionary<int, ActionBuffers> lastActions)
        {
            for (var tensorIndex = 0; tensorIndex < tensors.Count; tensorIndex++)
            {
                var tensor = tensors[tensorIndex];
                TensorApplier.IApplier applier;
                if (!m_Dict.TryGetValue(tensor.name, out applier))
                {
                    throw new UnityAgentsException(
                        $"Unknown tensorProxy expected as output : {tensor.name}");
                }

                applier.Apply(tensor, actionIds, lastActions);
            }
        }
    }

    /// <summary>
    /// Applier that writes, for each agent, the index of the largest value in that agent's
    /// row of the output tensor (an argmax) into the first discrete action slot.
    /// </summary>
    internal class MaxActionOutputApplier : TensorApplier.IApplier
    {
        readonly ActionSpec m_ActionSpec;

        /// <summary>
        /// Creates an applier for the given action spec.
        /// </summary>
        /// <param name="actionSpec">Action spec used to allocate buffers for agents whose last
        /// actions are empty.</param>
        /// <param name="seed">Currently unused by this applier.</param>
        /// <param name="allocator">Currently unused by this applier.</param>
        public MaxActionOutputApplier(ActionSpec actionSpec, int seed, ITensorAllocator allocator)
        {
            m_ActionSpec = actionSpec;
        }

        /// <summary>
        /// For every agent id in <paramref name="actionIds"/> that has an entry in
        /// <paramref name="lastActions"/>, stores the argmax of tensor row i (where i is the
        /// agent's position in <paramref name="actionIds"/>) as discrete action 0.
        /// Empty action buffers are replaced by a fresh buffer built from the action spec.
        /// </summary>
        /// <param name="tensorProxy">Output tensor; its last dimension is the number of
        /// candidate actions.</param>
        /// <param name="actionIds">Agent ids, in the same order as the tensor rows.</param>
        /// <param name="lastActions">Dictionary of agent id to actions to be updated.</param>
        public void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
        {
            var actionSpaceSize = tensorProxy.shape[tensorProxy.shape.Length - 1];
            for (var agentIndex = 0; agentIndex < actionIds.Count; agentIndex++)
            {
                var agentId = actionIds[agentIndex];
                ActionBuffers actionBuffer;
                if (!lastActions.TryGetValue(agentId, out actionBuffer))
                {
                    // Agents without an entry are skipped, but their row is still consumed.
                    continue;
                }

                if (actionBuffer.IsEmpty())
                {
                    actionBuffer = new ActionBuffers(m_ActionSpec);
                    lastActions[agentId] = actionBuffer;
                }

                // ActionSegment wraps the underlying array, so writing through this copy
                // updates the buffer stored in lastActions.
                var discreteBuffer = actionBuffer.DiscreteActions;
                discreteBuffer[0] = ArgMax(tensorProxy, agentIndex, actionSpaceSize);
            }
        }

        /// <summary>
        /// Returns the column index of the largest value in the given tensor row. Values are
        /// compared as floats (no truncation); ties resolve to the lowest index, NaN entries are
        /// never selected, and 0 is returned when no column beats negative infinity.
        /// </summary>
        static int ArgMax(TensorProxy tensorProxy, int row, long columns)
        {
            var maxIndex = 0;
            var maxValue = float.NegativeInfinity;
            for (var column = 0; column < columns; column++)
            {
                var value = tensorProxy.data[row, column];
                if (value > maxValue)
                {
                    maxValue = value;
                    maxIndex = column;
                }
            }

            return maxIndex;
        }
    }
}