Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多选择25个主题 主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
 
 
 
 
 

115 行
5.2 KiB

using System.Collections.Generic;
using Unity.Barracuda;
using Unity.MLAgents.Actuators;
namespace Unity.MLAgents.Inference
{
/// <summary>
/// Mapping between the output tensor names and the method that will use the
/// output tensors and the Agents present in the batch to update their action, memories and
/// value estimates.
/// A TensorApplier implements a Dictionary of strings (node names) to an Action.
/// This action takes as input the tensor and the Dictionary of Agent to AgentInfo for
/// the current batch.
/// </summary>
internal class TensorApplier
{
/// <summary>
/// A tensor Applier's Execute method takes a tensor and a Dictionary of Agent to AgentInfo.
/// Uses the data contained inside the tensor to modify the state of the Agent. The Tensors
/// are assumed to have the batch size on the first dimension and the agents to be ordered
/// the same way in the dictionary and in the tensor.
/// </summary>
public interface IApplier
{
/// <summary>
/// Applies the values in the Tensor to the Agents present in the agentInfos
/// </summary>
/// <param name="tensorProxy">
/// The Tensor containing the data to be applied to the Agents
/// </param>
/// <param name="actionIds"> List of Agents Ids that will be updated using the tensor's data</param>
/// <param name="lastActions"> Dictionary of AgentId to Actions to be updated</param>
void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions);
}
readonly Dictionary<string, IApplier> m_Dict = new Dictionary<string, IApplier>();
/// <summary>
/// Returns a new TensorAppliers object.
/// </summary>
/// <param name="actionSpec"> Description of the actions for the Agent.</param>
/// <param name="seed"> The seed the Appliers will be initialized with.</param>
/// <param name="allocator"> Tensor allocator</param>
/// <param name="memories">Dictionary of AgentInfo.id to memory used to pass to the inference model.</param>
/// <param name="barracudaModel"></param>
public TensorApplier(
ActionSpec actionSpec,
int seed,
ITensorAllocator allocator,
Dictionary<int, List<float>> memories,
object barracudaModel = null)
{
// If model is null, no inference to run and exception is thrown before reaching here.
if (barracudaModel == null)
{
return;
}
var model = (Model)barracudaModel;
if (!model.SupportsContinuousAndDiscrete())
{
actionSpec.CheckAllContinuousOrDiscrete();
}
if (actionSpec.NumContinuousActions > 0)
{
var tensorName = model.ContinuousOutputName();
m_Dict[tensorName] = new ContinuousActionOutputApplier(actionSpec);
}
if (actionSpec.NumDiscreteActions > 0)
{
var tensorName = model.DiscreteOutputName();
var modelVersion = model.GetVersion();
if (modelVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents1_0)
{
m_Dict[tensorName] = new LegacyDiscreteActionOutputApplier(actionSpec, seed, allocator);
}
if (modelVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents2_0)
{
m_Dict[tensorName] = new DiscreteActionOutputApplier(actionSpec, seed, allocator);
}
}
m_Dict[TensorNames.RecurrentOutput] = new MemoryOutputApplier(memories);
for (var i = 0; i < model?.memories.Count; i++)
{
m_Dict[model.memories[i].output] =
new BarracudaMemoryOutputApplier(model.memories.Count, i, memories);
}
}
/// <summary>
/// Updates the state of the agents based on the data present in the tensor.
/// </summary>
/// <param name="tensors"> Enumerable of tensors containing the data.</param>
/// <param name="actionIds"> List of Agents Ids that will be updated using the tensor's data</param>
/// <param name="lastActions"> Dictionary of AgentId to Actions to be updated</param>
/// <exception cref="UnityAgentsException"> One of the tensor does not have an
/// associated applier.</exception>
public void ApplyTensors(
IReadOnlyList<TensorProxy> tensors, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
{
for (var tensorIndex = 0; tensorIndex < tensors.Count; tensorIndex++)
{
var tensor = tensors[tensorIndex];
if (!m_Dict.ContainsKey(tensor.name))
{
throw new UnityAgentsException(
$"Unknown tensorProxy expected as output : {tensor.name}");
}
m_Dict[tensor.name].Apply(tensor, actionIds, lastActions);
}
}
}
}