using System.Collections.Generic; using System.Linq; using Unity.MLAgents.Inference.Utils; using Unity.MLAgents.Actuators; using Unity.Barracuda; using UnityEngine; namespace Unity.MLAgents.Inference { /// /// The Applier for the Continuous Action output tensor. Tensor is assumed to contain the /// continuous action data of the agents in the batch. /// internal class ContinuousActionOutputApplier : TensorApplier.IApplier { readonly ActionSpec m_ActionSpec; public ContinuousActionOutputApplier(ActionSpec actionSpec) { m_ActionSpec = actionSpec; } public void Apply(TensorProxy tensorProxy, IList actionIds, Dictionary lastActions) { var actionSize = tensorProxy.shape[tensorProxy.shape.Length - 1]; var agentIndex = 0; for (var i = 0; i < actionIds.Count; i++) { var agentId = actionIds[i]; if (lastActions.ContainsKey(agentId)) { var actionBuffer = lastActions[agentId]; if (actionBuffer.IsEmpty()) { actionBuffer = new ActionBuffers(m_ActionSpec); lastActions[agentId] = actionBuffer; } var continuousBuffer = actionBuffer.ContinuousActions; for (var j = 0; j < actionSize; j++) { continuousBuffer[j] = tensorProxy.data[agentIndex, j]; } } agentIndex++; } } } /// /// The Applier for the Discrete Action output tensor. /// internal class DiscreteActionOutputApplier : TensorApplier.IApplier { readonly ActionSpec m_ActionSpec; public DiscreteActionOutputApplier(ActionSpec actionSpec, int seed, ITensorAllocator allocator) { m_ActionSpec = actionSpec; } public void Apply(TensorProxy tensorProxy, IList actionIds, Dictionary lastActions) { var agentIndex = 0; var actionSize = tensorProxy.shape[tensorProxy.shape.Length - 1]; for (var i = 0; i < actionIds.Count; i++) { var agentId = actionIds[i]; if (lastActions.ContainsKey(agentId)) { var actionBuffer = lastActions[agentId]; if (actionBuffer.IsEmpty()) { actionBuffer = new ActionBuffers(m_ActionSpec); lastActions[agentId] = actionBuffer; } var discreteBuffer = actionBuffer.DiscreteActions; for (var j = 0; j < actionSize; j++) { discreteBuffer[j] = (int)tensorProxy.data[agentIndex, j]; } } agentIndex++; } } } /// /// The Applier for the Discrete Action output tensor. Uses multinomial to sample discrete /// actions from the logits contained in the tensor. /// internal class LegacyDiscreteActionOutputApplier : TensorApplier.IApplier { readonly int[] m_ActionSize; readonly Multinomial m_Multinomial; readonly ActionSpec m_ActionSpec; readonly int[] m_StartActionIndices; readonly float[] m_CdfBuffer; public LegacyDiscreteActionOutputApplier(ActionSpec actionSpec, int seed, ITensorAllocator allocator) { m_ActionSize = actionSpec.BranchSizes; m_Multinomial = new Multinomial(seed); m_ActionSpec = actionSpec; m_StartActionIndices = Utilities.CumSum(m_ActionSize); // Scratch space for computing the cumulative distribution function. // In order to reuse it, make it the size of the largest branch. var largestBranch = Mathf.Max(m_ActionSize); m_CdfBuffer = new float[largestBranch]; } public void Apply(TensorProxy tensorProxy, IList actionIds, Dictionary lastActions) { var agentIndex = 0; for (var i = 0; i < actionIds.Count; i++) { var agentId = actionIds[i]; if (lastActions.ContainsKey(agentId)) { var actionBuffer = lastActions[agentId]; if (actionBuffer.IsEmpty()) { actionBuffer = new ActionBuffers(m_ActionSpec); lastActions[agentId] = actionBuffer; } var discreteBuffer = actionBuffer.DiscreteActions; for (var j = 0; j < m_ActionSize.Length; j++) { ComputeCdf(tensorProxy, agentIndex, m_StartActionIndices[j], m_ActionSize[j]); discreteBuffer[j] = m_Multinomial.Sample(m_CdfBuffer, m_ActionSize[j]); } } agentIndex++; } } /// /// Compute the cumulative distribution function for a given agent's action /// given the log-probabilities. /// The results are stored in m_CdfBuffer, which is the size of the largest action's number of branches. /// /// /// Index of the agent being considered /// Offset into the tensor's channel. /// internal void ComputeCdf(TensorProxy logProbs, int batch, int channelOffset, int branchSize) { // Find the class maximum var maxProb = float.NegativeInfinity; for (var cls = 0; cls < branchSize; ++cls) { maxProb = Mathf.Max(logProbs.data[batch, cls + channelOffset], maxProb); } // Sum the log probabilities and compute CDF var sumProb = 0.0f; for (var cls = 0; cls < branchSize; ++cls) { sumProb += Mathf.Exp(logProbs.data[batch, cls + channelOffset] - maxProb); m_CdfBuffer[cls] = sumProb; } } } /// /// The Applier for the Memory output tensor. Tensor is assumed to contain the new /// memory data of the agents in the batch. /// internal class MemoryOutputApplier : TensorApplier.IApplier { Dictionary> m_Memories; public MemoryOutputApplier( Dictionary> memories) { m_Memories = memories; } public void Apply(TensorProxy tensorProxy, IList actionIds, Dictionary lastActions) { var agentIndex = 0; var memorySize = tensorProxy.data.width; for (var i = 0; i < actionIds.Count; i++) { var agentId = actionIds[i]; List memory; if (!m_Memories.TryGetValue(agentId, out memory) || memory.Count < memorySize) { memory = new List(); memory.AddRange(Enumerable.Repeat(0f, memorySize)); } for (var j = 0; j < memorySize; j++) { memory[j] = tensorProxy.data[agentIndex, 0, j, 0]; } m_Memories[agentId] = memory; agentIndex++; } } } internal class BarracudaMemoryOutputApplier : TensorApplier.IApplier { readonly int m_MemoriesCount; readonly int m_MemoryIndex; Dictionary> m_Memories; public BarracudaMemoryOutputApplier( int memoriesCount, int memoryIndex, Dictionary> memories) { m_MemoriesCount = memoriesCount; m_MemoryIndex = memoryIndex; m_Memories = memories; } public void Apply(TensorProxy tensorProxy, IList actionIds, Dictionary lastActions) { var agentIndex = 0; var memorySize = (int)tensorProxy.shape[tensorProxy.shape.Length - 1]; for (var i = 0; i < actionIds.Count; i++) { var agentId = actionIds[i]; List memory; if (!m_Memories.TryGetValue(agentId, out memory) || memory.Count < memorySize * m_MemoriesCount) { memory = new List(); memory.AddRange(Enumerable.Repeat(0f, memorySize * m_MemoriesCount)); } for (var j = 0; j < memorySize; j++) { memory[memorySize * m_MemoryIndex + j] = tensorProxy.data[agentIndex, j]; } m_Memories[agentId] = memory; agentIndex++; } } } }