using System.Collections.Generic;
using Unity.Barracuda;
using MLAgents.Policies;
namespace MLAgents.Inference
{
/// <summary>
/// Mapping between the output tensor names and the <see cref="IApplier"/> that consumes each
/// output tensor to update the Agents present in the batch (their actions and memories).
/// Internally a TensorApplier holds a Dictionary of strings (output node names) to an
/// <see cref="IApplier"/>; each applier takes as input the tensor, the ids of the agents in the
/// current batch, and the dictionary of last actions to update.
/// </summary>
internal class TensorApplier
{
    /// <summary>
    /// An applier takes an output tensor and uses the data it contains to update the state of
    /// the Agents in the batch. The tensors are assumed to have the batch size on the first
    /// dimension and the agents to be ordered the same way in <c>actionIds</c> and in the tensor.
    /// </summary>
    public interface IApplier
    {
        /// <summary>
        /// Applies the values in the Tensor to the Agents identified by <paramref name="actionIds"/>.
        /// </summary>
        /// <param name="tensorProxy">The Tensor containing the data to be applied to the Agents.</param>
        /// <param name="actionIds">List of Agents Ids that will be updated using the tensor's data.</param>
        /// <param name="lastActions">Dictionary of AgentId to Actions to be updated.</param>
        void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, float[]> lastActions);
    }

    // Output node name -> applier responsible for consuming the tensor with that name.
    readonly Dictionary<string, IApplier> m_Dict = new Dictionary<string, IApplier>();

    /// <summary>
    /// Returns a new TensorApplier with the appliers required by the given brain registered.
    /// </summary>
    /// <param name="bp">The BrainParameters used to determine what Appliers will be used
    /// (continuous vs. discrete action output).</param>
    /// <param name="seed">The seed the Appliers will be initialized with (passed to the
    /// discrete action applier).</param>
    /// <param name="allocator">Tensor allocator.</param>
    /// <param name="memories">Dictionary of AgentInfo.id to memory used to pass to the
    /// inference model.</param>
    /// <param name="barracudaModel">Optional Barracuda <c>Model</c>, passed as <c>object</c>.
    /// When non-null, one memory applier is registered for each memory declared by the model,
    /// keyed by that memory's output name.</param>
    /// <exception cref="System.InvalidCastException"><paramref name="barracudaModel"/> is
    /// non-null but is not a Barracuda <c>Model</c>.</exception>
    public TensorApplier(
        BrainParameters bp,
        int seed,
        ITensorAllocator allocator,
        Dictionary<int, List<float>> memories,
        object barracudaModel = null)
    {
        if (bp.VectorActionSpaceType == SpaceType.Continuous)
        {
            m_Dict[TensorNames.ActionOutput] = new ContinuousActionOutputApplier();
        }
        else
        {
            m_Dict[TensorNames.ActionOutput] =
                new DiscreteActionOutputApplier(bp.VectorActionSize, seed, allocator);
        }
        m_Dict[TensorNames.RecurrentOutput] = new MemoryOutputApplier(memories);

        if (barracudaModel != null)
        {
            // Direct cast (not `as`/`is`) so an object of the wrong type fails loudly instead of
            // silently skipping memory registration. After this cast `model` cannot be null.
            var model = (Model)barracudaModel;
            var memoryCount = model.memories.Count;
            for (var i = 0; i < memoryCount; i++)
            {
                m_Dict[model.memories[i].output] =
                    new BarracudaMemoryOutputApplier(memoryCount, i, memories);
            }
        }
    }

    /// <summary>
    /// Updates the state of the agents based on the data present in the tensors.
    /// Tensors are processed in enumeration order; if an unknown tensor is encountered, the
    /// tensors preceding it have already been applied when the exception is thrown.
    /// </summary>
    /// <param name="tensors">Enumerable of tensors containing the data.</param>
    /// <param name="actionIds">List of Agents Ids that will be updated using the tensor's data.</param>
    /// <param name="lastActions">Dictionary of AgentId to Actions to be updated.</param>
    /// <exception cref="UnityAgentsException">One of the tensors does not have an
    /// associated applier.</exception>
    public void ApplyTensors(
        IEnumerable<TensorProxy> tensors, IEnumerable<int> actionIds, Dictionary<int, float[]> lastActions)
    {
        foreach (var tensor in tensors)
        {
            // Single dictionary lookup instead of ContainsKey followed by the indexer.
            IApplier applier;
            if (!m_Dict.TryGetValue(tensor.name, out applier))
            {
                throw new UnityAgentsException(
                    $"Unknown tensorProxy expected as output : {tensor.name}");
            }
            applier.Apply(tensor, actionIds, lastActions);
        }
    }
}
}