浏览代码

Barracuda inference for hybrid actions (#4611)

* TensorApplier.IApplier takes ActionBuffers instead of float[] as input argument

* Model output format changed
/MLA-1734-demo-provider
GitHub 4 年前
当前提交
5e5ff19b
共有 36 个文件被更改,包括 2958 次插入238 次删除
  1. 2
      .yamato/com.unity.ml-agents-performance.yml
  2. 2
      .yamato/com.unity.ml-agents-test.yml
  3. 2
      .yamato/compressed-sensor-test.yml
  4. 2
      .yamato/gym-interface-test.yml
  5. 2
      .yamato/protobuf-generation-test.yml
  6. 2
      .yamato/python-ll-api-test.yml
  7. 2
      .yamato/standalone-build-test.yml
  8. 2
      .yamato/training-int-tests.yml
  9. 8
      com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs
  10. 17
      com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs
  11. 5
      com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
  12. 44
      com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs
  13. 136
      com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
  14. 8
      com.unity.ml-agents/Runtime/Inference/ModelRunner.cs
  15. 35
      com.unity.ml-agents/Runtime/Inference/TensorApplier.cs
  16. 26
      com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs
  17. 15
      com.unity.ml-agents/Runtime/Inference/TensorNames.cs
  18. 20
      com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs
  19. 12
      com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs
  20. 74
      com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorApplier.cs
  21. 62
      com.unity.ml-agents/Tests/Editor/ModelRunnerTest.cs
  22. 190
      com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs
  23. 2
      com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated.nn.meta
  24. 2
      com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated.nn.meta
  25. 141
      com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs
  26. 11
      com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs.meta
  27. 1001
      com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action.onnx
  28. 14
      com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action.onnx.meta
  29. 867
      com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr.onnx
  30. 14
      com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr.onnx.meta
  31. 462
      com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis53vec_3c_2daction.onnx
  32. 14
      com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis53vec_3c_2daction.onnx.meta
  33. 0
      /com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated.nn
  34. 0
      /com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated.nn
  35. 0
      /com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated.nn.meta
  36. 0
      /com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated.nn.meta

2
.yamato/com.unity.ml-agents-performance.yml


- ./utr --suite=editor --platform=StandaloneOSX --editor-location=.Editor --testproject=DevProject --artifacts_path=build/test-results --report-performance-data --performance-project-id=com.unity.ml-agents --zero-tests-are-ok=1
triggers:
cancel_old_ci: true
# TODO remove develop-hybrid trigger before merging to master
pull_request.target match "develop-hybrid.+" OR
pull_request.target match "release.+") AND
NOT pull_request.draft AND
(pull_request.changes.any match "com.unity.ml-agents/**" OR

2
.yamato/com.unity.ml-agents-test.yml


triggers:
cancel_old_ci: true
{% if platform.name == "mac" %}
# TODO remove develop-hybrid trigger before merging to master
pull_request.target match "develop-hybrid.+" OR
pull_request.target match "release.+") AND
NOT pull_request.draft AND
(pull_request.changes.any match "com.unity.ml-agents/**" OR

2
.yamato/compressed-sensor-test.yml


- .yamato/standalone-build-test.yml#test_mac_standalone_{{ editor.version }}
triggers:
cancel_old_ci: true
# TODO remove develop-hybrid trigger before merging to master
pull_request.target match "develop-hybrid.+" OR
pull_request.target match "release.+") AND
NOT pull_request.draft AND
(pull_request.changes.any match "com.unity.ml-agents/**" OR

2
.yamato/gym-interface-test.yml


- .yamato/standalone-build-test.yml#test_mac_standalone_{{ editor.version }}
triggers:
cancel_old_ci: true
# TODO remove develop-hybrid trigger before merging to master
pull_request.target match "develop-hybrid.+" OR
pull_request.target match "release.+") AND
NOT pull_request.draft AND
(pull_request.changes.any match "com.unity.ml-agents/**" OR

2
.yamato/protobuf-generation-test.yml


git diff -- :/ ":(exclude,top)$CS_PROTO_PATH/*.meta" > artifacts/proto.patch; exit $GIT_ERR; }
triggers:
cancel_old_ci: true
# TODO remove develop-hybrid trigger before merging to master
pull_request.target match "develop-hybrid.+" OR
pull_request.target match "release.+") AND
NOT pull_request.draft AND
(pull_request.changes.any match "protobuf-definitions/**" OR

2
.yamato/python-ll-api-test.yml


- .yamato/standalone-build-test.yml#test_mac_standalone_{{ editor.version }}
triggers:
cancel_old_ci: true
# TODO remove develop-hybrid trigger before merging to master
pull_request.target match "develop-hybrid.+" OR
pull_request.target match "release.+") AND
NOT pull_request.draft AND
(pull_request.changes.any match "com.unity.ml-agents/**" OR

2
.yamato/standalone-build-test.yml


- python -u -m ml-agents.tests.yamato.standalone_build_tests --scene=Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureCompressed.unity
triggers:
cancel_old_ci: true
# TODO remove develop-hybrid trigger before merging to master
pull_request.target match "develop-hybrid.+" OR
pull_request.target match "release.+") AND
NOT pull_request.draft AND
(pull_request.changes.any match "com.unity.ml-agents/**" OR

2
.yamato/training-int-tests.yml


- .yamato/standalone-build-test.yml#test_mac_standalone_{{ editor.version }}
triggers:
cancel_old_ci: true
# TODO remove develop-hybrid trigger before merging to master
pull_request.target match "develop-hybrid.+" OR
pull_request.target match "release.+") AND
NOT pull_request.draft AND
(pull_request.changes.any match "com.unity.ml-agents/**" OR

8
com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs


System.Array.Clear(Array, Offset, Length);
}
/// <summary>
/// Check if the segment is empty.
/// </summary>
/// <returns>
/// True when the segment has no backing array, or when the backing array has no elements.
/// </returns>
public bool IsEmpty()
{
    // Guard against a missing backing array so that asking "is this empty?"
    // can never throw a NullReferenceException.
    return Array == null || Array.Length == 0;
}
/// <inheritdoc/>
IEnumerator<T> IEnumerable<T>.GetEnumerator()
{

17
com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs


}
/// <summary>
/// Construct an <see cref="ActionBuffers"/> instance with <see cref="ActionSpec"/>. All values are initialized to zeros.
/// </summary>
/// <param name="actionSpec">The <see cref="ActionSpec"/> to send to an <see cref="IActionReceiver"/>.</param>
public ActionBuffers(ActionSpec actionSpec)
// The continuous segment wraps a fresh float array of NumContinuousActions elements and the
// discrete segment wraps a fresh int array of NumDiscreteActions elements.
: this(new ActionSegment<float>(new float[actionSpec.NumContinuousActions]),
new ActionSegment<int>(new int[actionSpec.NumDiscreteActions]))
{ }
/// <summary>
/// Create an <see cref="ActionBuffers"/> instance with ActionSpec and all actions stored as a float array.
/// </summary>
/// <param name="actionSpec"><see cref="ActionSpec"/> of the <see cref="ActionBuffers"/></param>

{
ContinuousActions.Clear();
DiscreteActions.Clear();
}
/// <summary>
/// Reports whether this <see cref="ActionBuffers"/> holds no actions at all.
/// </summary>
/// <returns>True only when both the continuous and the discrete segments are empty.</returns>
public bool IsEmpty()
{
    // A non-empty continuous segment settles the answer without consulting the discrete one.
    if (!ContinuousActions.IsEmpty())
    {
        return false;
    }

    return DiscreteActions.IsEmpty();
}
/// <inheritdoc/>

5
com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs


{
NumContinuousActions = actionSpec.NumContinuousActions,
NumDiscreteActions = actionSpec.NumDiscreteActions,
DiscreteBranchSizes = { actionSpec.BranchSizes },
if (actionSpec.BranchSizes != null)
{
actionSpecProto.DiscreteBranchSizes.AddRange(actionSpec.BranchSizes);
}
brainParametersProto.ActionSpec = actionSpecProto;
var supportHybrid = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.HybridActions;

44
com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs


using System.Collections.Generic;
using System.Linq;
using Unity.MLAgents.Inference.Utils;
using Unity.MLAgents.Actuators;
using Unity.Barracuda;
using UnityEngine;

/// </summary>
internal class ContinuousActionOutputApplier : TensorApplier.IApplier
{
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, float[]> lastActions)
readonly ActionSpec m_ActionSpec;
public ContinuousActionOutputApplier(ActionSpec actionSpec)
{
m_ActionSpec = actionSpec;
}
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
{
var actionSize = tensorProxy.shape[tensorProxy.shape.Length - 1];
var agentIndex = 0;

{
var actionValue = lastActions[agentId];
if (actionValue == null)
var actionBuffer = lastActions[agentId];
if (actionBuffer.IsEmpty())
actionValue = new float[actionSize];
lastActions[agentId] = actionValue;
actionBuffer = new ActionBuffers(m_ActionSpec);
lastActions[agentId] = actionBuffer;
var continuousBuffer = actionBuffer.ContinuousActions;
actionValue[j] = tensorProxy.data[agentIndex, j];
continuousBuffer[j] = tensorProxy.data[agentIndex, j];
}
}
agentIndex++;

readonly int[] m_ActionSize;
readonly Multinomial m_Multinomial;
readonly ITensorAllocator m_Allocator;
readonly ActionSpec m_ActionSpec;
public DiscreteActionOutputApplier(int[] actionSize, int seed, ITensorAllocator allocator)
public DiscreteActionOutputApplier(ActionSpec actionSpec, int seed, ITensorAllocator allocator)
m_ActionSize = actionSize;
m_ActionSize = actionSpec.BranchSizes;
m_ActionSpec = actionSpec;
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, float[]> lastActions)
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
{
//var tensorDataProbabilities = tensorProxy.Data as float[,];
var idActionPairList = actionIds as List<int> ?? actionIds.ToList();

{
if (lastActions.ContainsKey(agentId))
{
var actionVal = lastActions[agentId];
if (actionVal == null)
var actionBuffer = lastActions[agentId];
if (actionBuffer.IsEmpty())
actionVal = new float[m_ActionSize.Length];
lastActions[agentId] = actionVal;
actionBuffer = new ActionBuffers(m_ActionSpec);
lastActions[agentId] = actionBuffer;
var discreteBuffer = actionBuffer.DiscreteActions;
actionVal[j] = actionValues[agentIndex, j];
discreteBuffer[j] = (int)actionValues[agentIndex, j];
}
}
agentIndex++;

m_Memories = memories;
}
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, float[]> lastActions)
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
{
var agentIndex = 0;
var memorySize = (int)tensorProxy.shape[tensorProxy.shape.Length - 1];

m_Memories = memories;
}
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, float[]> lastActions)
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
{
var agentIndex = 0;
var memorySize = (int)tensorProxy.shape[tensorProxy.shape.Length - 1];

136
com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs


/// </summary>
internal class BarracudaModelParamLoader
{
enum ModelActionType
{
Unknown,
Discrete,
Continuous
}
const long k_ApiVersion = 2;
/// <summary>

foreach (var input in model.inputs)
{
if (input.shape.Length == 4)
if (input.name.StartsWith(TensorNames.VisualObservationPlaceholderPrefix))
if (input.name.StartsWith(TensorNames.VisualObservationPlaceholderPrefix))
{
count++;
}
count++;
}
}

return names.ToArray();
}
names.Add(TensorNames.ActionOutput);
if (model.HasContinuousOutputs())
{
names.Add(model.ContinuousOutputName());
}
if (model.HasDiscreteOutputs())
{
names.Add(model.DiscreteOutputName());
}
var memory = (int)model.GetTensorByName(TensorNames.MemorySize)[0];
if (memory > 0)

return failedModelChecks;
}
foreach (var constantName in TensorNames.RequiredConstants)
var modelApiVersionTensor = model.GetTensorByName(TensorNames.VersionNumber);
if (modelApiVersionTensor == null)
var tensor = model.GetTensorByName(constantName);
if (tensor == null)
{
failedModelChecks.Add($"Required constant \"{constantName}\" was not found in the model file.");
return failedModelChecks;
}
failedModelChecks.Add($"Required constant \"{TensorNames.VersionNumber}\" was not found in the model file.");
return failedModelChecks;
var modelApiVersion = (int)model.GetTensorByName(TensorNames.VersionNumber)[0];
var memorySize = (int)model.GetTensorByName(TensorNames.MemorySize)[0];
var isContinuousInt = (int)model.GetTensorByName(TensorNames.IsContinuousControl)[0];
var isContinuous = GetActionType(isContinuousInt);
var actionSize = (int)model.GetTensorByName(TensorNames.ActionOutputShape)[0];
var modelApiVersion = (int)modelApiVersionTensor[0];
if (modelApiVersion == -1)
{
failedModelChecks.Add(

return failedModelChecks;
}
var modelDiscreteActionSize = isContinuous == ModelActionType.Discrete ? actionSize : 0;
var modelContinuousActionSize = isContinuous == ModelActionType.Continuous ? actionSize : 0;
var memorySizeTensor = model.GetTensorByName(TensorNames.MemorySize);
if (memorySizeTensor == null)
{
failedModelChecks.Add($"Required constant \"{TensorNames.MemorySize}\" was not found in the model file.");
return failedModelChecks;
}
var memorySize = (int)memorySizeTensor[0];
{TensorNames.IsContinuousControl, isContinuousInt},
{TensorNames.ActionOutputShape, actionSize}
CheckInputTensorPresence(model, brainParameters, memorySize, isContinuous, sensorComponents)
CheckInputTensorPresence(model, brainParameters, memorySize, sensorComponents)
CheckOutputTensorPresence(model, memorySize))
;
CheckOutputTensorPresence(model, memorySize)
);
CheckOutputTensorShape(model, brainParameters, actuatorComponents, isContinuous, modelContinuousActionSize, modelDiscreteActionSize)
CheckOutputTensorShape(model, brainParameters, actuatorComponents)
/// <summary>
/// Maps the integer control-type value stored in the model to a ModelActionType.
/// </summary>
/// <param name="isContinuousInt">
/// The integer value in the model indicating the type of control
/// </param>
/// <returns>
/// ModelActionType.Discrete for 0, ModelActionType.Continuous for 1,
/// and ModelActionType.Unknown for any other value.
/// </returns>
static ModelActionType GetActionType(int isContinuousInt)
{
    if (isContinuousInt == 0)
    {
        return ModelActionType.Discrete;
    }

    if (isContinuousInt == 1)
    {
        return ModelActionType.Continuous;
    }

    // Anything other than 0 or 1 cannot be interpreted.
    return ModelActionType.Unknown;
}
/// <summary>
/// Given a Dictionary of node names to int values, create checks if the values have the

Model model,
BrainParameters brainParameters,
int memory,
ModelActionType isContinuous,
SensorComponent[] sensorComponents
)
{

(!tensorsNames.Contains(TensorNames.VectorObservationPlaceholder)))
{
failedModelChecks.Add(
"The model does not contain a Vector Observation Placeholder Input. " +
"The model does not contain a Vector Observation Placeholder Input. " +
"You must set the Vector Observation Space Size to 0.");
}

}
// If the model uses discrete control but does not have an input for action masks
if (isContinuous == ModelActionType.Discrete)
if (model.HasDiscreteOutputs())
{
if (!tensorsNames.Contains(TensorNames.ActionMaskPlaceholder))
{

{
var failedModelChecks = new List<string>();
// If there is no Action Output.
if (!model.outputs.Contains(TensorNames.ActionOutput))
if (!model.outputs.Contains(TensorNames.ActionOutputDeprecated) &&
!model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
!model.outputs.Contains(TensorNames.DiscreteActionOutput))
{
failedModelChecks.Add("The model does not contain an Action Output Node.");
}

BrainParameters brainParameters, TensorProxy tensorProxy,
SensorComponent[] sensorComponents, int observableAttributeTotalSize)
{
// TODO: Update this check after integrating ActionSpec into BrainParameters
var numberActionsBp = brainParameters.VectorActionSize.Length;
var numberActionsT = tensorProxy.shape[tensorProxy.shape.Length - 1];
if (numberActionsBp != numberActionsT)

static IEnumerable<string> CheckOutputTensorShape(
Model model,
BrainParameters brainParameters,
ActuatorComponent[] actuatorComponents,
ModelActionType isContinuous,
int modelContinuousActionSize, int modelSumDiscreteBranchSizes)
ActuatorComponent[] actuatorComponents)
if (isContinuous == ModelActionType.Unknown)
{
failedModelChecks.Add("Cannot infer type of Control from the provided model.");
return failedModelChecks;
}
if (isContinuous == ModelActionType.Continuous &&
brainParameters.VectorActionSpaceType != SpaceType.Continuous)
{
failedModelChecks.Add(
"Model has been trained using Continuous Control but the Brain Parameters " +
"suggest Discrete Control.");
return failedModelChecks;
}
if (isContinuous == ModelActionType.Discrete &&
brainParameters.VectorActionSpaceType != SpaceType.Discrete)
// Check the presence of action output shape
if (model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated) == null &&
model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null &&
model.GetTensorByName(TensorNames.DiscreteActionOutputShape) == null)
failedModelChecks.Add(
"Model has been trained using Discrete Control but the Brain Parameters " +
"suggest Continuous Control.");
failedModelChecks.Add("The model does not contain an Action Output Shape Node.");
// This will need to change a bit for hybrid action spaces.
if (isContinuous == ModelActionType.Continuous)
if (model.HasContinuousOutputs())
tensorTester[TensorNames.ActionOutput] = CheckContinuousActionOutputShape;
tensorTester[model.ContinuousOutputName()] = CheckContinuousActionOutputShape;
else
if (model.HasDiscreteOutputs())
tensorTester[TensorNames.ActionOutput] = CheckDiscreteActionOutputShape;
tensorTester[model.DiscreteOutputName()] = CheckDiscreteActionOutputShape;
var modelContinuousActionSize = model.ContinuousOutputSize();
var modelSumDiscreteBranchSizes = model.DiscreteOutputSize();
foreach (var name in model.outputs)
{
if (tensorTester.ContainsKey(name))

8
com.unity.ml-agents/Runtime/Inference/ModelRunner.cs


internal class ModelRunner
{
List<AgentInfoSensorsPair> m_Infos = new List<AgentInfoSensorsPair>();
Dictionary<int, float[]> m_LastActionsReceived = new Dictionary<int, float[]>();
Dictionary<int, ActionBuffers> m_LastActionsReceived = new Dictionary<int, ActionBuffers>();
List<int> m_OrderedAgentsRequestingDecisions = new List<int>();
ITensorAllocator m_TensorAllocator;

if (!m_LastActionsReceived.ContainsKey(info.episodeId))
{
m_LastActionsReceived[info.episodeId] = null;
m_LastActionsReceived[info.episodeId] = ActionBuffers.Empty;
}
if (info.done)
{

return m_Model == other && m_InferenceDevice == otherInferenceDevice;
}
public float[] GetAction(int agentId)
public ActionBuffers GetAction(int agentId)
return null;
return ActionBuffers.Empty;
}
}
}

35
com.unity.ml-agents/Runtime/Inference/TensorApplier.cs


/// </param>
/// <param name="actionIds"> List of Agents Ids that will be updated using the tensor's data</param>
/// <param name="lastActions"> Dictionary of AgentId to Actions to be updated</param>
void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, float[]> lastActions);
void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, ActionBuffers> lastActions);
}
readonly Dictionary<string, IApplier> m_Dict = new Dictionary<string, IApplier>();

Dictionary<int, List<float>> memories,
object barracudaModel = null)
{
actionSpec.CheckNotHybrid();
// If model is null, no inference to run and exception is thrown before reaching here.
if (barracudaModel == null)
{
return;
}
var model = (Model)barracudaModel;
if (model.UseDeprecated())
{
actionSpec.CheckNotHybrid();
}
m_Dict[TensorNames.ActionOutput] = new ContinuousActionOutputApplier();
var tensorName = model.ContinuousOutputName();
m_Dict[tensorName] = new ContinuousActionOutputApplier(actionSpec);
else
if (actionSpec.NumDiscreteActions > 0)
m_Dict[TensorNames.ActionOutput] =
new DiscreteActionOutputApplier(actionSpec.BranchSizes, seed, allocator);
var tensorName = model.DiscreteOutputName();
m_Dict[tensorName] = new DiscreteActionOutputApplier(actionSpec, seed, allocator);
if (barracudaModel != null)
for (var i = 0; i < model?.memories.Count; i++)
var model = (Model)barracudaModel;
for (var i = 0; i < model?.memories.Count; i++)
{
m_Dict[model.memories[i].output] =
new BarracudaMemoryOutputApplier(model.memories.Count, i, memories);
}
m_Dict[model.memories[i].output] =
new BarracudaMemoryOutputApplier(model.memories.Count, i, memories);
}
}

/// <exception cref="UnityAgentsException"> One of the tensors does not have an
/// associated applier.</exception>
public void ApplyTensors(
IEnumerable<TensorProxy> tensors, IEnumerable<int> actionIds, Dictionary<int, float[]> lastActions)
IEnumerable<TensorProxy> tensors, IEnumerable<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
{
foreach (var tensor in tensors)
{

26
com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs


Dictionary<int, List<float>> memories,
object barracudaModel = null)
{
// If model is null, no inference to run and exception is thrown before reaching here.
if (barracudaModel == null)
{
return;
}
var model = (Model)barracudaModel;
// Generator for Inputs
m_Dict[TensorNames.BatchSizePlaceholder] =
new BatchSizeGenerator(allocator);

new RecurrentInputGenerator(allocator, memories);
if (barracudaModel != null)
for (var i = 0; i < model.memories.Count; i++)
var model = (Model)barracudaModel;
for (var i = 0; i < model.memories.Count; i++)
{
m_Dict[model.memories[i].input] =
new BarracudaRecurrentInputGenerator(i, allocator, memories);
}
m_Dict[model.memories[i].input] =
new BarracudaRecurrentInputGenerator(i, allocator, memories);
}
m_Dict[TensorNames.PreviousActionPlaceholder] =

// Generators for Outputs
m_Dict[TensorNames.ActionOutput] = new BiDimensionalOutputGenerator(allocator);
if (model.HasContinuousOutputs())
{
m_Dict[model.ContinuousOutputName()] = new BiDimensionalOutputGenerator(allocator);
}
if (model.HasDiscreteOutputs())
{
m_Dict[model.DiscreteOutputName()] = new BiDimensionalOutputGenerator(allocator);
}
m_Dict[TensorNames.RecurrentOutput] = new BiDimensionalOutputGenerator(allocator);
m_Dict[TensorNames.ValueEstimateOutput] = new BiDimensionalOutputGenerator(allocator);
}

15
com.unity.ml-agents/Runtime/Inference/TensorNames.cs


public const string recurrentOutputC = "recurrent_out_c";
public const string MemorySize = "memory_size";
public const string VersionNumber = "version_number";
public const string IsContinuousControl = "is_continuous_control";
public const string ActionOutputShape = "action_output_shape";
public const string ActionOutput = "action";
public const string ContinuousActionOutputShape = "continuous_action_output_shape";
public const string DiscreteActionOutputShape = "discrete_action_output_shape";
public const string ContinuousActionOutput = "continuous_actions";
public const string DiscreteActionOutput = "discrete_actions";
public static readonly string[] RequiredConstants =
{
VersionNumber, MemorySize, IsContinuousControl, ActionOutputShape
};
// Deprecated TensorNames entries for backward compatibility
public const string IsContinuousControlDeprecated = "is_continuous_control";
public const string ActionOutputDeprecated = "action";
public const string ActionOutputShapeDeprecated = "action_output_shape";
}
}

20
com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs


/// Sensor shapes for the associated Agents. All Agents must have the same shapes for their Sensors.
/// </summary>
List<int[]> m_SensorShapes;
SpaceType m_SpaceType;
ActionSpec m_ActionSpec;
/// <inheritdoc />
public BarracudaPolicy(

{
var modelRunner = Academy.Instance.GetOrCreateModelRunner(model, actionSpec, inferenceDevice);
m_ModelRunner = modelRunner;
actionSpec.CheckNotHybrid();
m_SpaceType = actionSpec.NumContinuousActions > 0 ? SpaceType.Continuous : SpaceType.Discrete;
m_ActionSpec = actionSpec;
}
/// <inheritdoc />

/// <inheritdoc />
public ref readonly ActionBuffers DecideAction()
{
m_ModelRunner?.DecideBatch();
var actions = m_ModelRunner?.GetAction(m_AgentId);
if (m_SpaceType == SpaceType.Continuous)
if (m_ModelRunner == null)
m_LastActionBuffer = new ActionBuffers(actions, Array.Empty<int>());
return ref m_LastActionBuffer;
m_LastActionBuffer = ActionBuffers.Empty;
m_LastActionBuffer = ActionBuffers.FromDiscreteActions(actions);
else
{
m_ModelRunner?.DecideBatch();
m_LastActionBuffer = m_ModelRunner.GetAction(m_AgentId);
}
return ref m_LastActionBuffer;
}

12
com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs


}
// Adds one discrete and one continuous actuator to the same manager and expects an
// assertion to be logged when the actuators are readied for execution.
[Test]
public void TestFailOnMixedActionSpace()
{
var manager = new ActuatorManager();
// actuator1 is discrete with branch sizes {1, 2, 3, 4}; actuator2 is continuous with 3 actions.
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3, 4 }), "actuator1");
var actuator2 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator2");
manager.Add(actuator1);
manager.Add(actuator2);
// The expectation is registered before the call that should emit the assert log.
LogAssert.Expect(LogType.Assert, "Actuators on the same Agent must have the same action SpaceType.");
manager.ReadyActuatorsForExecution(new[] { actuator1, actuator2 }, 3, 10, 4);
}
[Test]
public void TestFailOnSameActuatorName()
{
var manager = new ActuatorManager();

74
com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorApplier.cs


using Unity.Barracuda;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Inference;
using Unity.MLAgents.Policies;
namespace Unity.MLAgents.Tests
{

[Test]
public void ApplyContinuousActionOutput()
{
var actionSpec = ActionSpec.MakeContinuous(3);
var inputTensor = new TensorProxy()
{
shape = new long[] { 2, 3 },

var applier = new ContinuousActionOutputApplier();
var applier = new ContinuousActionOutputApplier(actionSpec);
var actionDict = new Dictionary<int, float[]>() { { 0, null }, { 1, null } };
var actionDict = new Dictionary<int, ActionBuffers>() { { 0, ActionBuffers.Empty }, { 1, ActionBuffers.Empty } };
Assert.AreEqual(actionDict[0][0], 1);
Assert.AreEqual(actionDict[0][1], 2);
Assert.AreEqual(actionDict[0][2], 3);
Assert.AreEqual(actionDict[0].ContinuousActions[0], 1);
Assert.AreEqual(actionDict[0].ContinuousActions[1], 2);
Assert.AreEqual(actionDict[0].ContinuousActions[2], 3);
Assert.AreEqual(actionDict[1][0], 4);
Assert.AreEqual(actionDict[1][1], 5);
Assert.AreEqual(actionDict[1][2], 6);
Assert.AreEqual(actionDict[1].ContinuousActions[0], 4);
Assert.AreEqual(actionDict[1].ContinuousActions[1], 5);
Assert.AreEqual(actionDict[1].ContinuousActions[2], 6);
var actionSpec = ActionSpec.MakeDiscrete(new int[] { 2, 3 });
var inputTensor = new TensorProxy()
{
shape = new long[] { 2, 5 },

new[] { 0.5f, 22.5f, 0.1f, 5f, 1f, 4f, 5f, 6f, 7f, 8f })
};
var alloc = new TensorCachingAllocator();
var applier = new DiscreteActionOutputApplier(new[] { 2, 3 }, 0, alloc);
var applier = new DiscreteActionOutputApplier(actionSpec, 0, alloc);
var actionDict = new Dictionary<int, float[]>() { { 0, null }, { 1, null } };
var actionDict = new Dictionary<int, ActionBuffers>() { { 0, ActionBuffers.Empty }, { 1, ActionBuffers.Empty } };
Assert.AreEqual(actionDict[0][0], 1);
Assert.AreEqual(actionDict[0][1], 1);
Assert.AreEqual(actionDict[0].DiscreteActions[0], 1);
Assert.AreEqual(actionDict[0].DiscreteActions[1], 1);
Assert.AreEqual(actionDict[1][0], 1);
Assert.AreEqual(actionDict[1][1], 2);
Assert.AreEqual(actionDict[1].DiscreteActions[0], 1);
Assert.AreEqual(actionDict[1].DiscreteActions[1], 2);
alloc.Dispose();
}
/// <summary>
/// Applies a continuous output tensor and a discrete output tensor to the same
/// per-agent <see cref="ActionBuffers"/> dictionary and checks that both the
/// continuous and the discrete actions of each agent are populated.
/// </summary>
[Test]
public void ApplyHybridActionOutput()
{
    // Hybrid spec: 3 continuous actions and 2 discrete branches of sizes 2 and 3.
    var actionSpec = new ActionSpec(3, 2, new int[] { 2, 3 });

    // Two agents, three continuous values each.
    var continuousInputTensor = new TensorProxy()
    {
        shape = new long[] { 2, 3 },
        data = new Tensor(2, 3, new float[] { 1, 2, 3, 4, 5, 6 })
    };

    // Two agents, one value per discrete choice. The branch sizes sum to 2 + 3 = 5,
    // so the declared shape must be { 2, 5 } to agree with the 2x5 data tensor.
    var discreteInputTensor = new TensorProxy()
    {
        shape = new long[] { 2, 5 },
        data = new Tensor(
            2,
            5,
            new[] { 0.5f, 22.5f, 0.1f, 5f, 1f, 4f, 5f, 6f, 7f, 8f })
    };

    var continuousApplier = new ContinuousActionOutputApplier(actionSpec);
    var alloc = new TensorCachingAllocator();
    var discreteApplier = new DiscreteActionOutputApplier(actionSpec, 0, alloc);
    var agentIds = new List<int>() { 0, 1 };
    // Dictionary from AgentId to Action
    var actionDict = new Dictionary<int, ActionBuffers>() { { 0, ActionBuffers.Empty }, { 1, ActionBuffers.Empty } };

    continuousApplier.Apply(continuousInputTensor, agentIds, actionDict);
    discreteApplier.Apply(discreteInputTensor, agentIds, actionDict);

    Assert.AreEqual(actionDict[0].ContinuousActions[0], 1);
    Assert.AreEqual(actionDict[0].ContinuousActions[1], 2);
    Assert.AreEqual(actionDict[0].ContinuousActions[2], 3);
    Assert.AreEqual(actionDict[0].DiscreteActions[0], 1);
    Assert.AreEqual(actionDict[0].DiscreteActions[1], 1);

    Assert.AreEqual(actionDict[1].ContinuousActions[0], 4);
    Assert.AreEqual(actionDict[1].ContinuousActions[1], 5);
    Assert.AreEqual(actionDict[1].ContinuousActions[2], 6);
    Assert.AreEqual(actionDict[1].DiscreteActions[0], 1);
    Assert.AreEqual(actionDict[1].DiscreteActions[1], 2);
    alloc.Dispose();
}
}

62
com.unity.ml-agents/Tests/Editor/ModelRunnerTest.cs


using Unity.Barracuda;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Inference;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Policies;
namespace Unity.MLAgents.Tests

{
const string k_continuous2vis8vec2actionPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action.nn";
const string k_discrete1vis0vec_2_3action_recurrModelPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr.nn";
NNModel continuous2vis8vec2actionModel;
NNModel discrete1vis0vec_2_3action_recurrModel;
const string k_continuousONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action.onnx";
const string k_discreteONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr.onnx";
const string k_hybridONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis53vec_3c_2daction.onnx";
const string k_continuousNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated.nn";
const string k_discreteNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated.nn";
NNModel continuousONNXModel;
NNModel discreteONNXModel;
NNModel hybridONNXModel;
NNModel continuousNNModel;
NNModel discreteNNModel;
Test3DSensorComponent sensor_21_20_3;
Test3DSensorComponent sensor_20_22_3;

return ActionSpec.MakeDiscrete(2, 3);
}
// Builds the ActionSpec used with the hybrid ONNX test model.
// NOTE(review): the arguments look like (continuous count, discrete count, branch sizes) — confirm
// against the ActionSpec constructor.
ActionSpec GetHybrid0vis53vec_3c_2dActionSpec()
{
    var branchSizes = new int[] { 2 };
    var hybridSpec = new ActionSpec(3, 1, branchSizes);
    return hybridSpec;
}
continuous2vis8vec2actionModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_continuous2vis8vec2actionPath, typeof(NNModel));
discrete1vis0vec_2_3action_recurrModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_discrete1vis0vec_2_3action_recurrModelPath, typeof(NNModel));
continuousONNXModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_continuousONNXPath, typeof(NNModel));
discreteONNXModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_discreteONNXPath, typeof(NNModel));
hybridONNXModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_hybridONNXPath, typeof(NNModel));
continuousNNModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_continuousNNPath, typeof(NNModel));
discreteNNModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_discreteNNPath, typeof(NNModel));
var go = new GameObject("SensorA");
sensor_21_20_3 = go.AddComponent<Test3DSensorComponent>();
sensor_21_20_3.Sensor = new Test3DSensor("SensorA", 21, 20, 3);

[Test]
public void TestModelExist()
{
Assert.IsNotNull(continuous2vis8vec2actionModel);
Assert.IsNotNull(discrete1vis0vec_2_3action_recurrModel);
Assert.IsNotNull(continuousONNXModel);
Assert.IsNotNull(discreteONNXModel);
Assert.IsNotNull(hybridONNXModel);
Assert.IsNotNull(continuousNNModel);
Assert.IsNotNull(discreteNNModel);
var modelRunner = new ModelRunner(continuous2vis8vec2actionModel, GetContinuous2vis8vec2actionActionSpec());
var modelRunner = new ModelRunner(continuousONNXModel, GetContinuous2vis8vec2actionActionSpec());
modelRunner = new ModelRunner(discrete1vis0vec_2_3action_recurrModel, GetDiscrete1vis0vec_2_3action_recurrModelActionSpec());
modelRunner = new ModelRunner(discreteONNXModel, GetDiscrete1vis0vec_2_3action_recurrModelActionSpec());
modelRunner.Dispose();
modelRunner = new ModelRunner(hybridONNXModel, GetHybrid0vis53vec_3c_2dActionSpec());
modelRunner.Dispose();
modelRunner = new ModelRunner(continuousNNModel, GetContinuous2vis8vec2actionActionSpec());
modelRunner.Dispose();
modelRunner = new ModelRunner(discreteNNModel, GetDiscrete1vis0vec_2_3action_recurrModelActionSpec());
modelRunner.Dispose();
}

var modelRunner = new ModelRunner(continuous2vis8vec2actionModel, GetContinuous2vis8vec2actionActionSpec(), InferenceDevice.CPU);
Assert.True(modelRunner.HasModel(continuous2vis8vec2actionModel, InferenceDevice.CPU));
Assert.False(modelRunner.HasModel(continuous2vis8vec2actionModel, InferenceDevice.GPU));
Assert.False(modelRunner.HasModel(discrete1vis0vec_2_3action_recurrModel, InferenceDevice.CPU));
var modelRunner = new ModelRunner(continuousONNXModel, GetContinuous2vis8vec2actionActionSpec(), InferenceDevice.CPU);
Assert.True(modelRunner.HasModel(continuousONNXModel, InferenceDevice.CPU));
Assert.False(modelRunner.HasModel(continuousONNXModel, InferenceDevice.GPU));
Assert.False(modelRunner.HasModel(discreteONNXModel, InferenceDevice.CPU));
modelRunner.Dispose();
}

var actionSpec = GetDiscrete1vis0vec_2_3action_recurrModelActionSpec();
var modelRunner = new ModelRunner(discrete1vis0vec_2_3action_recurrModel, actionSpec);
var modelRunner = new ModelRunner(discreteONNXModel, actionSpec);
var info1 = new AgentInfo();
info1.episodeId = 1;
modelRunner.PutObservations(info1, new[] { sensor_21_20_3.CreateSensor() }.ToList());

modelRunner.DecideBatch();
Assert.IsNotNull(modelRunner.GetAction(1));
Assert.IsNotNull(modelRunner.GetAction(2));
Assert.IsNull(modelRunner.GetAction(3));
Assert.AreEqual(actionSpec.NumDiscreteActions, modelRunner.GetAction(1).Count());
Assert.IsFalse(modelRunner.GetAction(1).Equals(ActionBuffers.Empty));
Assert.IsFalse(modelRunner.GetAction(2).Equals(ActionBuffers.Empty));
Assert.IsTrue(modelRunner.GetAction(3).Equals(ActionBuffers.Empty));
Assert.AreEqual(actionSpec.NumDiscreteActions, modelRunner.GetAction(1).DiscreteActions.Length);
modelRunner.Dispose();
}
}

190
com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs


[TestFixture]
public class ParameterLoaderTest
{
const string k_continuous2vis8vec2actionPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action.nn";
const string k_discrete1vis0vec_2_3action_recurrModelPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr.nn";
NNModel continuous2vis8vec2actionModel;
NNModel discrete1vis0vec_2_3action_recurrModel;
// ONNX model with continuous/discrete action output (support hybrid action)
const string k_continuousONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action.onnx";
const string k_discreteONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr.onnx";
const string k_hybridONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis53vec_3c_2daction.onnx";
// NN model with single action output (deprecated, does not support hybrid action).
// Same BrainParameters settings as the corresponding ONNX model.
const string k_continuousNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated.nn";
const string k_discreteNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated.nn";
NNModel continuousONNXModel;
NNModel discreteONNXModel;
NNModel hybridONNXModel;
NNModel continuousNNModel;
NNModel discreteNNModel;
Test3DSensorComponent sensor_21_20_3;
Test3DSensorComponent sensor_20_22_3;

return validBrainParameters;
}
// TODO: update and enable this after integrating action spec into BrainParameters
// BrainParameters GetHybridBrainParameters()
// {
// var validBrainParameters = new BrainParameters();
// validBrainParameters.VectorObservationSize = 53;
// validBrainParameters.VectorActionSize = new[] { 2 };
// validBrainParameters.NumStackedVectorObservations = 1;
// validBrainParameters.VectorActionSpaceType = SpaceType.Discrete;
// return validBrainParameters;
// }
continuous2vis8vec2actionModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_continuous2vis8vec2actionPath, typeof(NNModel));
discrete1vis0vec_2_3action_recurrModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_discrete1vis0vec_2_3action_recurrModelPath, typeof(NNModel));
continuousONNXModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_continuousONNXPath, typeof(NNModel));
discreteONNXModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_discreteONNXPath, typeof(NNModel));
hybridONNXModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_hybridONNXPath, typeof(NNModel));
continuousNNModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_continuousNNPath, typeof(NNModel));
discreteNNModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_discreteNNPath, typeof(NNModel));
var go = new GameObject("SensorA");
sensor_21_20_3 = go.AddComponent<Test3DSensorComponent>();
sensor_21_20_3.Sensor = new Test3DSensor("SensorA", 21, 20, 3);

[Test]
public void TestModelExist()
{
Assert.IsNotNull(continuous2vis8vec2actionModel);
Assert.IsNotNull(discrete1vis0vec_2_3action_recurrModel);
Assert.IsNotNull(continuousONNXModel);
Assert.IsNotNull(discreteONNXModel);
Assert.IsNotNull(hybridONNXModel);
Assert.IsNotNull(continuousNNModel);
Assert.IsNotNull(discreteNNModel);
[Test]
public void TestGetInputTensors1()
[TestCase(true)]
[TestCase(false)]
public void TestGetInputTensorsContinuous(bool useDeprecatedNNModel)
var model = ModelLoader.Load(continuous2vis8vec2actionModel);
var model = useDeprecatedNNModel ? ModelLoader.Load(continuousNNModel) : ModelLoader.Load(continuousONNXModel);
var inputTensors = BarracudaModelParamLoader.GetInputTensors(model);
var inputNames = inputTensors.Select(x => x.name).ToList();
// Model should contain 3 inputs : vector, visual 1 and visual 2

Assert.AreEqual(0, BarracudaModelParamLoader.GetNumVisualInputs(null));
}
[Test]
public void TestGetInputTensors2()
[TestCase(true)]
[TestCase(false)]
public void TestGetInputTensorsDiscrete(bool useDeprecatedNNModel)
var model = ModelLoader.Load(discrete1vis0vec_2_3action_recurrModel);
var model = useDeprecatedNNModel ? ModelLoader.Load(discreteNNModel) : ModelLoader.Load(discreteONNXModel);
var inputTensors = BarracudaModelParamLoader.GetInputTensors(model);
var inputNames = inputTensors.Select(x => x.name).ToList();
// Model should contain 2 inputs : recurrent and visual 1

}
[Test]
public void TestGetOutputTensors1()
public void TestGetInputTensorsHybrid()
var model = ModelLoader.Load(continuous2vis8vec2actionModel);
var model = ModelLoader.Load(hybridONNXModel);
var inputTensors = BarracudaModelParamLoader.GetInputTensors(model);
var inputNames = inputTensors.Select(x => x.name).ToList();
Assert.Contains(TensorNames.VectorObservationPlaceholder, inputNames);
}
[TestCase(true)]
[TestCase(false)]
public void TestGetOutputTensorsContinuous(bool useDeprecatedNNModel)
{
var model = useDeprecatedNNModel ? ModelLoader.Load(continuousNNModel) : ModelLoader.Load(continuousONNXModel);
Assert.Contains(TensorNames.ActionOutput, outputNames);
var actionOutputName = useDeprecatedNNModel ? TensorNames.ActionOutputDeprecated : TensorNames.ContinuousActionOutput;
Assert.Contains(actionOutputName, outputNames);
[Test]
public void TestGetOutputTensors2()
[TestCase(true)]
[TestCase(false)]
public void TestGetOutputTensorsDiscrete(bool useDeprecatedNNModel)
var model = ModelLoader.Load(discrete1vis0vec_2_3action_recurrModel);
var model = useDeprecatedNNModel ? ModelLoader.Load(discreteNNModel) : ModelLoader.Load(discreteONNXModel);
Assert.Contains(TensorNames.ActionOutput, outputNames);
var actionOutputName = useDeprecatedNNModel ? TensorNames.ActionOutputDeprecated : TensorNames.DiscreteActionOutput;
Assert.Contains(actionOutputName, outputNames);
public void TestCheckModelValid1()
public void TestGetOutputTensorsHybrid()
var model = ModelLoader.Load(continuous2vis8vec2actionModel);
var model = ModelLoader.Load(hybridONNXModel);
var outputNames = BarracudaModelParamLoader.GetOutputNames(model);
Assert.AreEqual(2, outputNames.Count());
Assert.Contains(TensorNames.ContinuousActionOutput, outputNames);
Assert.Contains(TensorNames.DiscreteActionOutput, outputNames);
Assert.AreEqual(0, BarracudaModelParamLoader.GetOutputNames(null).Count());
}
[TestCase(true)]
[TestCase(false)]
public void TestCheckModelValidContinuous(bool useDeprecatedNNModel)
{
var model = useDeprecatedNNModel ? ModelLoader.Load(continuousNNModel) : ModelLoader.Load(continuousONNXModel);
var validBrainParameters = GetContinuous2vis8vec2actionBrainParameters();
var errors = BarracudaModelParamLoader.CheckModel(

Assert.AreEqual(0, errors.Count()); // There should not be any errors
}
[Test]
public void TestCheckModelValid2()
[TestCase(true)]
[TestCase(false)]
public void TestCheckModelValidDiscrete(bool useDeprecatedNNModel)
var model = ModelLoader.Load(discrete1vis0vec_2_3action_recurrModel);
var model = useDeprecatedNNModel ? ModelLoader.Load(discreteNNModel) : ModelLoader.Load(discreteONNXModel);
var validBrainParameters = GetDiscrete1vis0vec_2_3action_recurrModelBrainParameters();
var errors = BarracudaModelParamLoader.CheckModel(

Assert.AreEqual(0, errors.Count()); // There should not be any errors
}
[Test]
public void TestCheckModelThrowsVectorObservation1()
// TODO: update and enable this test after integrating action spec into BrainParameters
// [Test]
// public void TestCheckModelValidHybrid()
// {
// var model = ModelLoader.Load(hybridONNXModel);
// var validBrainParameters = GetHybridBrainParameters();
// var errors = BarracudaModelParamLoader.CheckModel(
// model, validBrainParameters,
// new SensorComponent[] { }, new ActuatorComponent[0]
// );
// Assert.AreEqual(0, errors.Count()); // There should not be any errors
// }
[TestCase(true)]
[TestCase(false)]
public void TestCheckModelThrowsVectorObservationContinuous(bool useDeprecatedNNModel)
var model = ModelLoader.Load(continuous2vis8vec2actionModel);
var model = useDeprecatedNNModel ? ModelLoader.Load(continuousNNModel) : ModelLoader.Load(continuousONNXModel);
var brainParameters = GetContinuous2vis8vec2actionBrainParameters();
brainParameters.VectorObservationSize = 9; // Invalid observation

Assert.Greater(errors.Count(), 0);
}
[Test]
public void TestCheckModelThrowsVectorObservation2()
[TestCase(true)]
[TestCase(false)]
public void TestCheckModelThrowsVectorObservationDiscrete(bool useDeprecatedNNModel)
var model = ModelLoader.Load(discrete1vis0vec_2_3action_recurrModel);
var model = useDeprecatedNNModel ? ModelLoader.Load(discreteNNModel) : ModelLoader.Load(discreteONNXModel);
var brainParameters = GetDiscrete1vis0vec_2_3action_recurrModelBrainParameters();
brainParameters.VectorObservationSize = 1; // Invalid observation

[Test]
public void TestCheckModelThrowsAction1()
// TODO: update and enable this test after integrating action spec into BrainParameters
// [Test]
// public void TestCheckModelThrowsVectorObservationHybrid()
// {
// var model = ModelLoader.Load(hybridONNXModel);
// var brainParameters = GetHybridBrainParameters();
// brainParameters.VectorObservationSize = 9; // Invalid observation
// var errors = BarracudaModelParamLoader.CheckModel(
// model, brainParameters,
// new SensorComponent[] { }, new ActuatorComponent[0]
// );
// Assert.Greater(errors.Count(), 0);
// brainParameters = GetContinuous2vis8vec2actionBrainParameters();
// brainParameters.NumStackedVectorObservations = 2;// Invalid stacking
// errors = BarracudaModelParamLoader.CheckModel(
// model, brainParameters,
// new SensorComponent[] { }, new ActuatorComponent[0]
// );
// Assert.Greater(errors.Count(), 0);
// }
[TestCase(true)]
[TestCase(false)]
public void TestCheckModelThrowsActionContinuous(bool useDeprecatedNNModel)
var model = ModelLoader.Load(continuous2vis8vec2actionModel);
var model = useDeprecatedNNModel ? ModelLoader.Load(continuousNNModel) : ModelLoader.Load(continuousONNXModel);
var brainParameters = GetContinuous2vis8vec2actionBrainParameters();
brainParameters.VectorActionSize = new[] { 3 }; // Invalid action

Assert.Greater(errors.Count(), 0);
}
[Test]
public void TestCheckModelThrowsAction2()
[TestCase(true)]
[TestCase(false)]
public void TestCheckModelThrowsActionDiscrete(bool useDeprecatedNNModel)
var model = ModelLoader.Load(discrete1vis0vec_2_3action_recurrModel);
var model = useDeprecatedNNModel ? ModelLoader.Load(discreteNNModel) : ModelLoader.Load(discreteONNXModel);
var brainParameters = GetDiscrete1vis0vec_2_3action_recurrModelBrainParameters();
brainParameters.VectorActionSize = new[] { 3, 3 }; // Invalid action

errors = BarracudaModelParamLoader.CheckModel(model, brainParameters, new SensorComponent[] { sensor_21_20_3 }, new ActuatorComponent[0]);
Assert.Greater(errors.Count(), 0);
}
// TODO: update and enable this test after integrating action spec into BrainParameters
// [Test]
// public void TestCheckModelThrowsActionHybrid()
// {
// var model = ModelLoader.Load(hybridONNXModel);
// var brainParameters = GetHybridBrainParameters();
// brainParameters.VectorActionSize = new[] { 3 }; // Invalid action
// var errors = BarracudaModelParamLoader.CheckModel(model, brainParameters, new SensorComponent[] { sensor_21_20_3, sensor_20_22_3 }, new ActuatorComponent[0]);
// Assert.Greater(errors.Count(), 0);
// brainParameters = GetContinuous2vis8vec2actionBrainParameters();
// brainParameters.VectorActionSpaceType = SpaceType.Discrete;// Invalid SpaceType
// errors = BarracudaModelParamLoader.CheckModel(model, brainParameters, new SensorComponent[] { sensor_21_20_3, sensor_20_22_3 }, new ActuatorComponent[0]);
// Assert.Greater(errors.Count(), 0);
// }
[Test]
public void TestCheckModelThrowsNoModel()

2
com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated.nn.meta


fileFormatVersion: 2
guid: a75582ff670094ff2996c1c4ab9dfd15
guid: bf4543cc3c6944794bbba065bdf90079
ScriptedImporter:
fileIDToRecycleName:
11400000: main obj

2
com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated.nn.meta


fileFormatVersion: 2
guid: 8a92fbcd96caa4ef5a93dd55c0c36705
guid: 6d6040ad621454dd5b713beb5483e347
ScriptedImporter:
fileIDToRecycleName:
11400000: main obj

141
com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs


using Unity.Barracuda;
namespace Unity.MLAgents.Inference
{
    /// <summary>
    /// Barracuda Model extension methods.
    /// These helpers hide the difference between models that expose a single,
    /// deprecated action output and models that expose separate continuous
    /// and/or discrete action outputs.
    /// </summary>
    internal static class BarracudaModelExtensions
    {
        /// <summary>
        /// Check if the model has continuous action outputs.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters.
        /// </param>
        /// <returns>True if the model has continuous action outputs.</returns>
        public static bool HasContinuousOutputs(this Model model)
        {
            if (model.UseDeprecated())
            {
                // Deprecated models carry one flag tensor saying whether the single action output is continuous.
                return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0;
            }

            // The output must be listed AND its recorded size must be positive.
            return model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
                ReadShapeOrZero(model, TensorNames.ContinuousActionOutputShape) > 0;
        }

        /// <summary>
        /// Continuous action output size of the model.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters.
        /// </param>
        /// <returns>
        /// Size of continuous action output. 0 when the model has no continuous
        /// action output shape tensor.
        /// </returns>
        public static int ContinuousOutputSize(this Model model)
        {
            if (model.UseDeprecated())
            {
                return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0 ?
                    (int)model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated)[0] : 0;
            }

            return ReadShapeOrZero(model, TensorNames.ContinuousActionOutputShape);
        }

        /// <summary>
        /// Continuous action output tensor name of the model.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters.
        /// </param>
        /// <returns>Tensor name of continuous action output.</returns>
        public static string ContinuousOutputName(this Model model)
        {
            return model.UseDeprecated()
                ? TensorNames.ActionOutputDeprecated
                : TensorNames.ContinuousActionOutput;
        }

        /// <summary>
        /// Check if the model has discrete action outputs.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters.
        /// </param>
        /// <returns>True if the model has discrete action outputs.</returns>
        public static bool HasDiscreteOutputs(this Model model)
        {
            if (model.UseDeprecated())
            {
                // In the deprecated format "not continuous" means discrete.
                return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] == 0;
            }

            // The output must be listed AND its recorded size must be positive.
            return model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
                ReadShapeOrZero(model, TensorNames.DiscreteActionOutputShape) > 0;
        }

        /// <summary>
        /// Discrete action output size of the model. This is equal to the sum of the branch sizes.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters.
        /// </param>
        /// <returns>
        /// Size of discrete action output. 0 when the model has no discrete
        /// action output shape tensor.
        /// </returns>
        public static int DiscreteOutputSize(this Model model)
        {
            if (model.UseDeprecated())
            {
                return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0 ?
                    0 : (int)model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated)[0];
            }

            return ReadShapeOrZero(model, TensorNames.DiscreteActionOutputShape);
        }

        /// <summary>
        /// Discrete action output tensor name of the model.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters.
        /// </param>
        /// <returns>Tensor name of discrete action output.</returns>
        public static string DiscreteOutputName(this Model model)
        {
            return model.UseDeprecated()
                ? TensorNames.ActionOutputDeprecated
                : TensorNames.DiscreteActionOutput;
        }

        /// <summary>
        /// Check if the model uses deprecated output fields and should be handled differently.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters.
        /// </param>
        /// <returns>
        /// True if the model uses deprecated output fields, i.e. neither the continuous
        /// nor the discrete action output is listed among the model outputs.
        /// </returns>
        public static bool UseDeprecated(this Model model)
        {
            return !model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
                !model.outputs.Contains(TensorNames.DiscreteActionOutput);
        }

        /// <summary>
        /// Read the first element of a shape tensor, treating a missing tensor as size 0.
        /// </summary>
        /// <remarks>
        /// A model that exposes only one action type may not contain the shape tensor
        /// of the other type, so the lookup result is checked before it is indexed.
        /// NOTE(review): assumes GetTensorByName yields null for an unknown tensor name;
        /// confirm against the Barracuda version in use.
        /// </remarks>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters.
        /// </param>
        /// <param name="shapeTensorName">Name of the shape tensor to look up.</param>
        /// <returns>The first element of the tensor as an int, or 0 if the tensor is absent.</returns>
        static int ReadShapeOrZero(Model model, string shapeTensorName)
        {
            var shapeTensor = model.GetTensorByName(shapeTensorName);
            return shapeTensor == null ? 0 : (int)shapeTensor[0];
        }
    }
}

11
com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs.meta


fileFormatVersion: 2
guid: 1193c3bef93464baca0d8ba2d6ce1754
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

1001
com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action.onnx
文件差异内容过多而无法显示
查看文件

14
com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action.onnx.meta


fileFormatVersion: 2
guid: f90bffb60a3784a2385299a321f354a6
ScriptedImporter:
fileIDToRecycleName:
11400000: main obj
11400002: model data
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:
script: {fileID: 11500000, guid: 683b6cb6d0a474744822c888b46772c9, type: 3}
optimizeModel: 1
forceArbitraryBatchSize: 1
treatErrorsAsWarnings: 0

867
com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr.onnx


pytorch1.7:��
�
visual_observation_0
5network_body.visual_processors.0.conv_layers.0.weight
3network_body.visual_processors.0.conv_layers.0.bias35Conv_0"Conv*
dilations@@�*
group�*
kernel_shape@@�*
pads@@@@�*
strides@@�
1
3536 LeakyRelu_1" LeakyRelu*
alpha
�#<�
�
36
5network_body.visual_processors.0.conv_layers.2.weight
3network_body.visual_processors.0.conv_layers.2.bias37Conv_2"Conv*
dilations@@�*
group�*
kernel_shape@@�*
pads@@@@�*
strides@@�
1
3738 LeakyRelu_3" LeakyRelu*
alpha
�#<�
>39
Constant_4"Constant*"
value*J�������� �
38
3940 Reshape_5"Reshape
�
40
/network_body.visual_processors.0.dense.0.weight
-network_body.visual_processors.0.dense.0.bias41Gemm_6"Gemm*
alpha�?�*
beta�?�*
transB�
1
4142 LeakyRelu_7" LeakyRelu*