浏览代码

Merge remote-tracking branch 'origin/v2-staging' into csharp-obs-spec

/v2-staging-rebase
Chris Elion 4 年前
当前提交
a362b3d9
共有 13 个文件被更改,包括 696 次插入和 220 次删除
  1. 5
      .pre-commit-config.yaml
  2. 2
      com.unity.ml-agents/CHANGELOG.md
  3. 16
      com.unity.ml-agents/Editor/BehaviorParametersEditor.cs
  4. 44
      com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs
  5. 70
      com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs
  6. 452
      com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
  7. 10
      com.unity.ml-agents/Runtime/Inference/TensorApplier.cs
  8. 92
      com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs
  9. 41
      com.unity.ml-agents/Tests/Editor/DiscreteActionOutputApplierTest.cs
  10. 77
      com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorApplier.cs
  11. 2
      ml-agents/mlagents/trainers/torch/distributions.py
  12. 61
      ml-agents/mlagents/trainers/torch/model_serialization.py
  13. 44
      ml-agents/mlagents/trainers/torch/networks.py

5
.pre-commit-config.yaml


name: validate release links
language: script
entry: utils/validate_release_links.py
- id: dotnet-format
name: dotnet-format
language: script
entry: utils/run_dotnet_format.py
types: [c#]

2
com.unity.ml-agents/CHANGELOG.md


### Minor Changes
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
- The `.onnx` model's input names have changed. All input placeholders will now use the prefix `obs_` removing the distinction between visual and vector observations. Models created with this version will not be usable with previous versions of the package (#5080)
- The `.onnx` model's discrete action output now contains the discrete action values and not the logits. Models created with this version will not be usable with previous versions of the package (#5080)
#### ml-agents / ml-agents-envs / gym-unity (Python)
### Bug Fixes

16
com.unity.ml-agents/Editor/BehaviorParametersEditor.cs


using Unity.MLAgents.Policies;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Sensors.Reflection;
using CheckTypeEnum = Unity.MLAgents.Inference.BarracudaModelParamLoader.FailedCheck.CheckTypeEnum;
namespace Unity.MLAgents.Editor
{

{
if (check != null)
{
EditorGUILayout.HelpBox(check, MessageType.Warning);
switch (check.CheckType)
{
case CheckTypeEnum.Info:
EditorGUILayout.HelpBox(check.Message, MessageType.Info);
break;
case CheckTypeEnum.Warning:
EditorGUILayout.HelpBox(check.Message, MessageType.Warning);
break;
case CheckTypeEnum.Error:
EditorGUILayout.HelpBox(check.Message, MessageType.Error);
break;
default:
break;
}
}
}
}

44
com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs


}
/// <summary>
/// The Applier for the Discrete Action output tensor.
/// </summary>
internal class DiscreteActionOutputApplier : TensorApplier.IApplier
{
readonly ActionSpec m_ActionSpec;
public DiscreteActionOutputApplier(ActionSpec actionSpec, int seed, ITensorAllocator allocator)
{
m_ActionSpec = actionSpec;
}
public void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
{
var agentIndex = 0;
var actionSize = tensorProxy.shape[tensorProxy.shape.Length - 1];
for (var i = 0; i < actionIds.Count; i++)
{
var agentId = actionIds[i];
if (lastActions.ContainsKey(agentId))
{
var actionBuffer = lastActions[agentId];
if (actionBuffer.IsEmpty())
{
actionBuffer = new ActionBuffers(m_ActionSpec);
lastActions[agentId] = actionBuffer;
}
var discreteBuffer = actionBuffer.DiscreteActions;
for (var j = 0; j < actionSize; j++)
{
discreteBuffer[j] = (int)tensorProxy.data[agentIndex, j];
}
}
agentIndex++;
}
}
}
/// <summary>
internal class DiscreteActionOutputApplier : TensorApplier.IApplier
internal class LegacyDiscreteActionOutputApplier : TensorApplier.IApplier
{
readonly int[] m_ActionSize;
readonly Multinomial m_Multinomial;

public DiscreteActionOutputApplier(ActionSpec actionSpec, int seed, ITensorAllocator allocator)
public LegacyDiscreteActionOutputApplier(ActionSpec actionSpec, int seed, ITensorAllocator allocator)
{
m_ActionSize = actionSpec.BranchSizes;
m_Multinomial = new Multinomial(seed);

70
com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs


using System.Collections.Generic;
using System.Linq;
using Unity.Barracuda;
using FailedCheck = Unity.MLAgents.Inference.BarracudaModelParamLoader.FailedCheck;
namespace Unity.MLAgents.Inference
{

names.Sort();
return names.ToArray();
}
/// <summary>
/// Get the version of the model.
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters.
/// </param>
/// <returns>The api version of the model</returns>
public static int GetVersion(this Model model)
    => (int)model.GetTensorByName(TensorNames.VersionNumber)[0];
/// <summary>

else
{
return model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
(int)model.GetTensorByName(TensorNames.DiscreteActionOutputShape)[0] > 0;
(int)model.DiscreteOutputSize() > 0;
/// This method gets the tensor representing the list of branch size and returns the
/// sum of all the elements in the Tensor.
/// - In version 1.X this tensor contains a single number, the sum of all branch
/// size values.
/// - In version 2.X this tensor contains a 1D Tensor with each element corresponding
/// to a branch size.
/// Since this method does the sum of all elements in the tensor, the output
/// will be the same on both 1.X and 2.X.
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters.

else
{
var discreteOutputShape = model.GetTensorByName(TensorNames.DiscreteActionOutputShape);
return discreteOutputShape == null ? 0 : (int)discreteOutputShape[0];
if (discreteOutputShape == null)
{
return 0;
}
else
{
int result = 0;
for (int i = 0; i < discreteOutputShape.length; i++)
{
result += (int)discreteOutputShape[i];
}
return result;
}
}
}

/// <param name="failedModelChecks">Output list of failure messages</param>
///
/// <returns>True if the model contains all the expected tensors.</returns>
public static bool CheckExpectedTensors(this Model model, List<string> failedModelChecks)
public static bool CheckExpectedTensors(this Model model, List<FailedCheck> failedModelChecks)
failedModelChecks.Add($"Required constant \"{TensorNames.VersionNumber}\" was not found in the model file.");
failedModelChecks.Add(
FailedCheck.Warning($"Required constant \"{TensorNames.VersionNumber}\" was not found in the model file.")
);
return false;
}

{
failedModelChecks.Add($"Required constant \"{TensorNames.MemorySize}\" was not found in the model file.");
failedModelChecks.Add(
FailedCheck.Warning($"Required constant \"{TensorNames.MemorySize}\" was not found in the model file.")
);
return false;
}

!model.outputs.Contains(TensorNames.DiscreteActionOutput))
{
failedModelChecks.Add("The model does not contain any Action Output Node.");
failedModelChecks.Add(
FailedCheck.Warning("The model does not contain any Action Output Node.")
);
return false;
}

if (model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated) == null)
{
failedModelChecks.Add("The model does not contain any Action Output Shape Node.");
failedModelChecks.Add(
FailedCheck.Warning("The model does not contain any Action Output Shape Node.")
);
failedModelChecks.Add($"Required constant \"{TensorNames.IsContinuousControlDeprecated}\" was not found in the model file. " +
"This is only required for model that uses a deprecated model format.");
failedModelChecks.Add(
FailedCheck.Warning($"Required constant \"{TensorNames.IsContinuousControlDeprecated}\" was " +
"not found in the model file. " +
"This is only required for model that uses a deprecated model format.")
);
return false;
}
}

model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null)
{
failedModelChecks.Add("The model uses continuous action but does not contain Continuous Action Output Shape Node.");
failedModelChecks.Add(
FailedCheck.Warning("The model uses continuous action but does not contain Continuous Action Output Shape Node.")
);
failedModelChecks.Add("The model uses discrete action but does not contain Discrete Action Output Shape Node.");
failedModelChecks.Add(
FailedCheck.Warning("The model uses discrete action but does not contain Discrete Action Output Shape Node.")
);
return false;
}
}

452
com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs


/// </summary>
internal class BarracudaModelParamLoader
{
const long k_ApiVersion = 2;
// Supported model API versions. The numeric values match the value stored in
// the model's VersionNumber tensor, so they can be compared directly against
// the number read from a loaded model.
internal enum ModelApiVersion
{
    MLAgents1_0 = 2,
    MLAgents2_0 = 3,
    // Aliases marking the inclusive supported range for version checks.
    MinSupportedVersion = MLAgents1_0,
    MaxSupportedVersion = MLAgents2_0
}
/// <summary>
/// The result of a model compatibility check that did not pass, carrying a
/// severity level and a human-readable message.
/// </summary>
internal class FailedCheck
{
    public enum CheckTypeEnum
    {
        Info = 0,
        Warning = 1,
        Error = 2
    }

    public CheckTypeEnum CheckType;
    public string Message;

    /// <summary>Builds an informational (lowest severity) check result.</summary>
    public static FailedCheck Info(string message)
        => new FailedCheck { CheckType = CheckTypeEnum.Info, Message = message };

    /// <summary>Builds a warning check result.</summary>
    public static FailedCheck Warning(string message)
        => new FailedCheck { CheckType = CheckTypeEnum.Warning, Message = message };

    /// <summary>Builds an error (highest severity) check result.</summary>
    public static FailedCheck Error(string message)
        => new FailedCheck { CheckType = CheckTypeEnum.Error, Message = message };
}
/// <summary>
/// Factory for the ModelParamLoader : Creates a ModelParamLoader and runs the checks

/// <param name="actuatorComponents">Attached actuator components</param>
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes.</param>
/// <param name="behaviorType">BehaviorType or the Agent to check.</param>
/// <returns>The list the error messages of the checks that failed</returns>
public static IEnumerable<string> CheckModel(Model model, BrainParameters brainParameters,
ISensor[] sensors, ActuatorComponent[] actuatorComponents,
/// <returns>A IEnumerable of the checks that failed</returns>
public static IEnumerable<FailedCheck> CheckModel(
Model model,
BrainParameters brainParameters,
ISensor[] sensors,
ActuatorComponent[] actuatorComponents,
BehaviorType behaviorType = BehaviorType.Default)
BehaviorType behaviorType = BehaviorType.Default
)
List<string> failedModelChecks = new List<string>();
List<FailedCheck> failedModelChecks = new List<FailedCheck>();
if (model == null)
{
var errorMsg = "There is no model for this Brain; cannot run inference. ";

{
errorMsg += "(But can still train)";
}
failedModelChecks.Add(errorMsg);
failedModelChecks.Add(FailedCheck.Info(errorMsg));
return failedModelChecks;
}

return failedModelChecks;
}
var modelApiVersion = (int)model.GetTensorByName(TensorNames.VersionNumber)[0];
if (modelApiVersion == -1)
var modelApiVersion = model.GetVersion();
if (modelApiVersion < (int)ModelApiVersion.MinSupportedVersion || modelApiVersion > (int)ModelApiVersion.MaxSupportedVersion)
"Model was not trained using the right version of ML-Agents. " +
"Cannot use this model.");
return failedModelChecks;
}
if (modelApiVersion != k_ApiVersion)
{
failedModelChecks.Add(
$"Version of the trainer the model was trained with ({modelApiVersion}) " +
$"is not compatible with the Brain's version ({k_ApiVersion}).");
FailedCheck.Warning($"Version of the trainer the model was trained with ({modelApiVersion}) " +
$"is not compatible with the current range of supported versions: " +
$"({(int)ModelApiVersion.MinSupportedVersion} to {(int)ModelApiVersion.MaxSupportedVersion}).")
);
return failedModelChecks;
}

failedModelChecks.Add($"Missing node in the model provided : {TensorNames.MemorySize}");
failedModelChecks.Add(FailedCheck.Warning($"Missing node in the model provided : {TensorNames.MemorySize}"
));
if (modelApiVersion == (int)ModelApiVersion.MLAgents1_0)
{
failedModelChecks.AddRange(
CheckInputTensorPresenceLegacy(model, brainParameters, memorySize, sensors)
);
failedModelChecks.AddRange(
CheckInputTensorShapeLegacy(model, brainParameters, sensors, observableAttributeTotalSize)
);
}
else if (modelApiVersion == (int)ModelApiVersion.MLAgents2_0)
{
failedModelChecks.AddRange(
CheckInputTensorPresence(model, brainParameters, memorySize, sensors)
);
failedModelChecks.AddRange(
CheckInputTensorShape(model, brainParameters, sensors, observableAttributeTotalSize)
);
}
CheckInputTensorPresence(model, brainParameters, memorySize, sensors)
CheckOutputTensorShape(model, brainParameters, actuatorComponents)
);
failedModelChecks.AddRange(
CheckInputTensorShape(model, brainParameters, sensors, observableAttributeTotalSize)
);
failedModelChecks.AddRange(
CheckOutputTensorShape(model, brainParameters, actuatorComponents)
);
return failedModelChecks;
}

/// present in the BrainParameters.
/// present in the BrainParameters. Tests the models created with the API of version 1.X
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters

/// </param>
/// <param name="sensors">Array of attached sensor components</param>
/// <returns>
/// A IEnumerable of string corresponding to the failed input presence checks.
/// A IEnumerable of the checks that failed
static IEnumerable<string> CheckInputTensorPresence(
static IEnumerable<FailedCheck> CheckInputTensorPresenceLegacy(
Model model,
BrainParameters brainParameters,
int memory,

var failedModelChecks = new List<string>();
var failedModelChecks = new List<FailedCheck>();
var tensorsNames = model.GetInputNames();
// If there is no Vector Observation Input but the Brain Parameters expect one.

failedModelChecks.Add(
"The model does not contain a Vector Observation Placeholder Input. " +
"You must set the Vector Observation Space Size to 0.");
FailedCheck.Warning("The model does not contain a Vector Observation Placeholder Input. " +
"You must set the Vector Observation Space Size to 0.")
);
}
// If there are not enough Visual Observation Input compared to what the

TensorNames.GetVisualObservationName(visObsIndex)))
{
failedModelChecks.Add(
"The model does not contain a Visual Observation Placeholder Input " +
$"for sensor component {visObsIndex} ({sensor.GetType().Name}).");
FailedCheck.Warning("The model does not contain a Visual Observation Placeholder Input " +
$"for sensor component {visObsIndex} ({sensor.GetType().Name}).")
);
}
visObsIndex++;
}

TensorNames.GetObservationName(sensorIndex)))
{
failedModelChecks.Add(
"The model does not contain an Observation Placeholder Input " +
$"for sensor component {sensorIndex} ({sensor.GetType().Name}).");
FailedCheck.Warning("The model does not contain an Observation Placeholder Input " +
$"for sensor component {sensorIndex} ({sensor.GetType().Name}).")
);
}
}

if (expectedVisualObs > visObsIndex)
{
failedModelChecks.Add(
$"The model expects {expectedVisualObs} visual inputs," +
$" but only found {visObsIndex} visual sensors."
);
FailedCheck.Warning($"The model expects {expectedVisualObs} visual inputs," +
$" but only found {visObsIndex} visual sensors.")
);
}
// If the model has a non-negative memory size but requires a recurrent input
if (memory > 0)
{
if (!tensorsNames.Any(x => x.EndsWith("_h")) ||
!tensorsNames.Any(x => x.EndsWith("_c")))
{
failedModelChecks.Add(
FailedCheck.Warning("The model does not contain a Recurrent Input Node but has memory_size.")
);
}
}
// If the model uses discrete control but does not have an input for action masks
if (model.HasDiscreteOutputs())
{
if (!tensorsNames.Contains(TensorNames.ActionMaskPlaceholder))
{
failedModelChecks.Add(
FailedCheck.Warning("The model does not contain an Action Mask but is using Discrete Control.")
);
}
}
return failedModelChecks;
}
/// <summary>
/// Generates failed checks that correspond to inputs expected by the model that are not
/// present in the BrainParameters.
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters
/// </param>
/// <param name="brainParameters">
/// The BrainParameters that are used verify the compatibility with the InferenceEngine
/// </param>
/// <param name="memory">
/// The memory size that the model is expecting.
/// </param>
/// <param name="sensors">Array of attached sensor components</param>
/// <returns>
/// A IEnumerable of the checks that failed
/// </returns>
static IEnumerable<FailedCheck> CheckInputTensorPresence(
Model model,
BrainParameters brainParameters,
int memory,
ISensor[] sensors
)
{
var failedModelChecks = new List<FailedCheck>();
var tensorsNames = model.GetInputNames();
for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++)
{
if (!tensorsNames.Contains(
TensorNames.GetObservationName(sensorIndex)))
{
var sensor = sensors[sensorIndex];
failedModelChecks.Add(
FailedCheck.Warning("The model does not contain an Observation Placeholder Input " +
$"for sensor component {sensorIndex} ({sensor.GetType().Name}).")
);
}
}
// If the model has a non-negative memory size but requires a recurrent input

!tensorsNames.Any(x => x.EndsWith("_c")))
{
failedModelChecks.Add(
"The model does not contain a Recurrent Input Node but has memory_size.");
FailedCheck.Warning("The model does not contain a Recurrent Input Node but has memory_size.")
);
}
}

if (!tensorsNames.Contains(TensorNames.ActionMaskPlaceholder))
{
failedModelChecks.Add(
"The model does not contain an Action Mask but is using Discrete Control.");
FailedCheck.Warning("The model does not contain an Action Mask but is using Discrete Control.")
);
}
}
return failedModelChecks;

/// </param>
/// <param name="memory">The memory size that the model is expecting/</param>
/// <returns>
/// A IEnumerable of string corresponding to the failed output presence checks.
/// A IEnumerable of the checks that failed
static IEnumerable<string> CheckOutputTensorPresence(Model model, int memory)
static IEnumerable<FailedCheck> CheckOutputTensorPresence(Model model, int memory)
var failedModelChecks = new List<string>();
var failedModelChecks = new List<FailedCheck>();
// If there is no Recurrent Output but the model is Recurrent.
if (memory > 0)

!memOutputs.Any(x => x.EndsWith("_c")))
{
failedModelChecks.Add(
"The model does not contain a Recurrent Output Node but has memory_size.");
FailedCheck.Warning("The model does not contain a Recurrent Output Node but has memory_size.")
);
}
}
return failedModelChecks;

/// If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.
/// </returns>
static string CheckVisualObsShape(
static FailedCheck CheckVisualObsShape(
TensorProxy tensorProxy, ISensor sensor)
{
var shape = sensor.GetObservationSpec().Shape;

var pixelT = tensorProxy.Channels;
if ((widthBp != widthT) || (heightBp != heightT) || (pixelBp != pixelT))
{
return $"The visual Observation of the model does not match. " +
return FailedCheck.Warning($"The visual Observation of the model does not match. " +
$"was expecting [?x{widthT}x{heightT}x{pixelT}].";
$"was expecting [?x{widthT}x{heightT}x{pixelT}] for the {sensor.GetName()} Sensor."
);
}
return null;
}

/// If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.
/// </returns>
static string CheckRankTwoObsShape(
static FailedCheck CheckRankTwoObsShape(
TensorProxy tensorProxy, ISensor sensor)
{
var shape = sensor.GetObservationSpec().Shape;

var dim2T = tensorProxy.Width;
var dim3T = tensorProxy.Height;
return $"An Observation of the model does not match. " +
var proxyDimStr = $"[?x{dim1T}x{dim2T}]";
if (dim3T > 1)
{
proxyDimStr = $"[?x{dim3T}x{dim2T}x{dim1T}]";
}
return FailedCheck.Warning($"An Observation of the model does not match. " +
$"was expecting [?x{dim1T}x{dim2T}].";
$"was expecting {proxyDimStr} for the {sensor.GetName()} Sensor."
);
}
return null;
}
/// <summary>
/// Checks that the shape of the rank 1 observation input placeholder is the same as the corresponding sensor.
/// </summary>
/// <param name="tensorProxy">The tensor that is expected by the model</param>
/// <param name="sensor">The sensor that produces the observation.</param>
/// <returns>
/// If the Check failed, returns a FailedCheck containing information about why the
/// check failed. If the check passed, returns null.
/// </returns>
static FailedCheck CheckRankOneObsShape(
    TensorProxy tensorProxy, ISensor sensor)
{
    var sensorSize = sensor.GetObservationSpec().Shape[0];
    var channels = tensorProxy.Channels;
    var width = tensorProxy.Width;
    var height = tensorProxy.Height;

    if (sensorSize == channels)
    {
        return null;
    }

    // Describe the model-side shape as specifically as the proxy dimensions allow:
    // a height > 1 implies a 4D layout, a width > 1 a 3D layout, otherwise 2D.
    string expectedDims;
    if (height > 1)
    {
        expectedDims = $"[?x{height}x{width}x{channels}]";
    }
    else if (width > 1)
    {
        expectedDims = $"[?x{channels}x{width}]";
    }
    else
    {
        expectedDims = $"[?x{channels}]";
    }
    return FailedCheck.Warning($"An Observation of the model does not match. " +
        $"Received TensorProxy of shape [?x{sensorSize}] but " +
        $"was expecting {expectedDims} for the {sensor.GetName()} Sensor."
    );
}

/// the model and the BrainParameters.
/// the model and the BrainParameters. Tests the models created with the API of version 1.X
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters

/// </param>
/// <param name="sensors">Attached sensors</param>
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes.</param>
/// <returns>The list the error messages of the checks that failed</returns>
static IEnumerable<string> CheckInputTensorShape(
/// <returns>A IEnumerable of the checks that failed</returns>
static IEnumerable<FailedCheck> CheckInputTensorShapeLegacy(
var failedModelChecks = new List<string>();
var failedModelChecks = new List<FailedCheck>();
new Dictionary<string, Func<BrainParameters, TensorProxy, ISensor[], int, string>>()
new Dictionary<string, Func<BrainParameters, TensorProxy, ISensor[], int, FailedCheck>>()
{TensorNames.VectorObservationPlaceholder, CheckVectorObsShape},
{TensorNames.VectorObservationPlaceholder, CheckVectorObsShapeLegacy},
{TensorNames.PreviousActionPlaceholder, CheckPreviousActionShape},
{TensorNames.RandomNormalEpsilonPlaceholder, ((bp, tensor, scs, i) => null)},
{TensorNames.ActionMaskPlaceholder, ((bp, tensor, scs, i) => null)},

if (!tensor.name.Contains("visual_observation"))
{
failedModelChecks.Add(
"Model requires an unknown input named : " + tensor.name);
FailedCheck.Warning("Model contains an unexpected input named : " + tensor.name)
);
}
}
else

/// <summary>
/// Checks that the shape of the Vector Observation input placeholder is the same in the
/// model and in the Brain Parameters.
/// model and in the Brain Parameters. Tests the models created with the API of version 1.X
/// </summary>
/// <param name="brainParameters">
/// The BrainParameters that are used verify the compatibility with the InferenceEngine

/// If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.
/// </returns>
static string CheckVectorObsShape(
static FailedCheck CheckVectorObsShapeLegacy(
BrainParameters brainParameters, TensorProxy tensorProxy, ISensor[] sensors,
int observableAttributeTotalSize)
{

}
sensorSizes += "]";
return $"Vector Observation Size of the model does not match. Was expecting {totalVecObsSizeT} " +
return FailedCheck.Warning(
$"Vector Observation Size of the model does not match. Was expecting {totalVecObsSizeT} " +
$"Sensor sizes: {sensorSizes}.";
$"Sensor sizes: {sensorSizes}."
);
/// <summary>
/// Generates failed checks that correspond to inputs shapes incompatibilities between
/// the model and the BrainParameters.
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters
/// </param>
/// <param name="brainParameters">
/// The BrainParameters that are used verify the compatibility with the InferenceEngine
/// </param>
/// <param name="sensors">Attached sensors</param>
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes.</param>
/// <returns>A IEnumerable of the checks that failed</returns>
static IEnumerable<FailedCheck> CheckInputTensorShape(
    Model model, BrainParameters brainParameters, ISensor[] sensors,
    int observableAttributeTotalSize)
{
    var failedModelChecks = new List<FailedCheck>();
    // Maps each known input tensor name to its shape check. Inputs whose shape
    // needs no validation map to a lambda that always passes (returns null).
    var tensorTester =
        new Dictionary<string, Func<BrainParameters, TensorProxy, ISensor[], int, FailedCheck>>()
        {
            {TensorNames.PreviousActionPlaceholder, CheckPreviousActionShape},
            {TensorNames.RandomNormalEpsilonPlaceholder, ((bp, tensor, scs, i) => null)},
            {TensorNames.ActionMaskPlaceholder, ((bp, tensor, scs, i) => null)},
            {TensorNames.SequenceLengthPlaceholder, ((bp, tensor, scs, i) => null)},
            {TensorNames.RecurrentInPlaceholder, ((bp, tensor, scs, i) => null)},
        };
    // Recurrent memory inputs have no shape requirement.
    foreach (var mem in model.memories)
    {
        tensorTester[mem.input] = ((bp, tensor, scs, i) => null);
    }
    // Register one shape check per sensor, dispatching on the observation rank.
    for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++)
    {
        var sens = sensors[sensorIndex];
        switch (sens.GetObservationSpec().NumDimensions)
        {
            case 3:
                tensorTester[TensorNames.GetObservationName(sensorIndex)] =
                    (bp, tensor, scs, i) => CheckVisualObsShape(tensor, sens);
                break;
            case 2:
                tensorTester[TensorNames.GetObservationName(sensorIndex)] =
                    (bp, tensor, scs, i) => CheckRankTwoObsShape(tensor, sens);
                break;
            case 1:
                tensorTester[TensorNames.GetObservationName(sensorIndex)] =
                    (bp, tensor, scs, i) => CheckRankOneObsShape(tensor, sens);
                break;
        }
    }
    // If the model expects an input but it is not in this list
    foreach (var tensor in model.GetInputTensors())
    {
        // TryGetValue avoids the double lookup of ContainsKey + indexer.
        if (!tensorTester.TryGetValue(tensor.name, out var tester))
        {
            failedModelChecks.Add(FailedCheck.Warning("Model contains an unexpected input named : " + tensor.name
            ));
        }
        else
        {
            var error = tester.Invoke(brainParameters, tensor, sensors, observableAttributeTotalSize);
            if (error != null)
            {
                failedModelChecks.Add(error);
            }
        }
    }
    return failedModelChecks;
}
/// <summary>
/// Checks that the shape of the Previous Vector Action input placeholder is the same in the
/// model and in the Brain Parameters.

/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes (unused).</param>
/// <returns>If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.</returns>
static string CheckPreviousActionShape(
static FailedCheck CheckPreviousActionShape(
BrainParameters brainParameters, TensorProxy tensorProxy,
ISensor[] sensors, int observableAttributeTotalSize)
{

{
return "Previous Action Size of the model does not match. " +
$"Received {numberActionsBp} but was expecting {numberActionsT}.";
return FailedCheck.Warning("Previous Action Size of the model does not match. " +
$"Received {numberActionsBp} but was expecting {numberActionsT}."
);
}
return null;
}

/// </param>
/// <param name="actuatorComponents">Array of attached actuator components.</param>
/// <returns>
/// A IEnumerable of string corresponding to the incompatible shapes between model
/// A IEnumerable of error messages corresponding to the incompatible shapes between model
static IEnumerable<string> CheckOutputTensorShape(
static IEnumerable<FailedCheck> CheckOutputTensorShape(
var failedModelChecks = new List<string>();
var failedModelChecks = new List<FailedCheck>();
// If the model expects an output but it is not in this list
var modelContinuousActionSize = model.ContinuousOutputSize();

failedModelChecks.Add(continuousError);
}
var modelSumDiscreteBranchSizes = model.DiscreteOutputSize();
var discreteError = CheckDiscreteActionOutputShape(brainParameters, actuatorComponents, modelSumDiscreteBranchSizes);
FailedCheck discreteError = null;
var modelApiVersion = model.GetVersion();
if (modelApiVersion == (int)ModelApiVersion.MLAgents1_0)
{
var modelSumDiscreteBranchSizes = model.DiscreteOutputSize();
discreteError = CheckDiscreteActionOutputShapeLegacy(brainParameters, actuatorComponents, modelSumDiscreteBranchSizes);
}
if (modelApiVersion == (int)ModelApiVersion.MLAgents2_0)
{
var modelDiscreteBranches = model.GetTensorByName(TensorNames.DiscreteActionOutputShape);
discreteError = CheckDiscreteActionOutputShape(brainParameters, actuatorComponents, modelDiscreteBranches);
}
if (discreteError != null)
{
failedModelChecks.Add(discreteError);

/// <summary>
/// Checks that the shape of the discrete action output is the same in the
/// model and in the Brain Parameters.
/// </summary>
/// <param name="brainParameters">
/// The BrainParameters that are used verify the compatibility with the InferenceEngine
/// </param>
/// <param name="actuatorComponents">Array of attached actuator components.</param>
/// <param name="modelDiscreteBranches"> The Tensor of branch sizes.
/// </param>
/// <returns>
/// If the Check failed, returns a FailedCheck containing information about why the
/// check failed. If the check passed, returns null.
/// </returns>
static FailedCheck CheckDiscreteActionOutputShape(
    BrainParameters brainParameters, ActuatorComponent[] actuatorComponents, Tensor modelDiscreteBranches)
{
    // Expected branch sizes: the brain's own branches followed by every actuator's.
    var expectedBranches = brainParameters.ActionSpec.BranchSizes.ToList();
    foreach (var actuator in actuatorComponents)
    {
        expectedBranches.AddRange(actuator.ActionSpec.BranchSizes);
    }

    // A missing branch-size tensor is treated as zero branches.
    var modelBranchCount = modelDiscreteBranches?.length ?? 0;
    if (modelBranchCount != expectedBranches.Count)
    {
        return FailedCheck.Warning("Discrete Action Size of the model does not match. The BrainParameters expect " +
            $"{expectedBranches.Count} branches but the model contains {modelBranchCount}."
        );
    }

    for (var branch = 0; branch < modelBranchCount; branch++)
    {
        if (modelDiscreteBranches[branch] != expectedBranches[branch])
        {
            return FailedCheck.Warning($"The number of Discrete Actions of branch {branch} does not match. " +
                $"Was expecting {expectedBranches[branch]} but the model contains {modelDiscreteBranches[branch]} "
            );
        }
    }
    return null;
}
/// <summary>
/// Checks that the shape of the discrete action output is the same in the
/// model and in the Brain Parameters. Tests the models created with the API of version 1.X
/// </summary>
/// <param name="brainParameters">
/// The BrainParameters that are used verify the compatibility with the InferenceEngine
/// </param>
/// <param name="actuatorComponents">Array of attached actuator components.</param>
/// <param name="modelSumDiscreteBranchSizes">
/// The size of the discrete action output that is expected by the model.
/// </param>

/// </returns>
static string CheckDiscreteActionOutputShape(
static FailedCheck CheckDiscreteActionOutputShapeLegacy(
BrainParameters brainParameters, ActuatorComponent[] actuatorComponents, int modelSumDiscreteBranchSizes)
{
// TODO: check each branch size instead of sum of branch sizes

if (modelSumDiscreteBranchSizes != sumOfDiscreteBranchSizes)
{
return "Discrete Action Size of the model does not match. The BrainParameters expect " +
$"{sumOfDiscreteBranchSizes} but the model contains {modelSumDiscreteBranchSizes}.";
return FailedCheck.Warning("Discrete Action Size of the model does not match. The BrainParameters expect " +
$"{sumOfDiscreteBranchSizes} but the model contains {modelSumDiscreteBranchSizes}."
);
}
return null;
}

/// </param>
/// <returns>If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.</returns>
static string CheckContinuousActionOutputShape(
static FailedCheck CheckContinuousActionOutputShape(
BrainParameters brainParameters, ActuatorComponent[] actuatorComponents, int modelContinuousActionSize)
{
var numContinuousActions = brainParameters.ActionSpec.NumContinuousActions;

if (modelContinuousActionSize != numContinuousActions)
{
return "Continuous Action Size of the model does not match. The BrainParameters and ActuatorComponents expect " +
$"{numContinuousActions} but the model contains {modelContinuousActionSize}.";
return FailedCheck.Warning(
"Continuous Action Size of the model does not match. The BrainParameters and ActuatorComponents expect " +
$"{numContinuousActions} but the model contains {modelContinuousActionSize}."
);
}
return null;
}

10
com.unity.ml-agents/Runtime/Inference/TensorApplier.cs


if (actionSpec.NumDiscreteActions > 0)
{
var tensorName = model.DiscreteOutputName();
m_Dict[tensorName] = new DiscreteActionOutputApplier(actionSpec, seed, allocator);
var modelVersion = model.GetVersion();
if (modelVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents1_0)
{
m_Dict[tensorName] = new LegacyDiscreteActionOutputApplier(actionSpec, seed, allocator);
}
if (modelVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents2_0)
{
m_Dict[tensorName] = new DiscreteActionOutputApplier(actionSpec, seed, allocator);
}
}
m_Dict[TensorNames.RecurrentOutput] = new MemoryOutputApplier(memories);

92
com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs


}
readonly Dictionary<string, IGenerator> m_Dict = new Dictionary<string, IGenerator>();
int m_ApiVersion;
/// <summary>
/// Returns a new TensorGenerators object.

return;
}
var model = (Model)barracudaModel;
m_ApiVersion = model.GetVersion();
// Generator for Inputs
m_Dict[TensorNames.BatchSizePlaceholder] =

public void InitializeObservations(List<ISensor> sensors, ITensorAllocator allocator)
{
// Loop through the sensors on a representative agent.
// All vector observations use a shared ObservationGenerator since they are concatenated.
// All other observations use a unique ObservationInputGenerator
var visIndex = 0;
ObservationGenerator vecObsGen = null;
for (var sensorIndex = 0; sensorIndex < sensors.Count; sensorIndex++)
if (m_ApiVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents1_0)
var sensor = sensors[sensorIndex];
var shape = sensor.GetObservationSpec().Shape;
var rank = shape.Length;
ObservationGenerator obsGen = null;
string obsGenName = null;
switch (rank)
// Loop through the sensors on a representative agent.
// All vector observations use a shared ObservationGenerator since they are concatenated.
// All other observations use a unique ObservationInputGenerator
var visIndex = 0;
ObservationGenerator vecObsGen = null;
for (var sensorIndex = 0; sensorIndex < sensors.Count; sensorIndex++)
case 1:
if (vecObsGen == null)
{
vecObsGen = new ObservationGenerator(allocator);
}
obsGen = vecObsGen;
obsGenName = TensorNames.VectorObservationPlaceholder;
break;
case 2:
// If the tensor is of rank 2, we use the index of the sensor
// to create the name
obsGen = new ObservationGenerator(allocator);
obsGenName = TensorNames.GetObservationName(sensorIndex);
break;
case 3:
// If the tensor is of rank 3, we use the "visual observation
// index", which only counts the rank 3 sensors
obsGen = new ObservationGenerator(allocator);
obsGenName = TensorNames.GetVisualObservationName(visIndex);
visIndex++;
break;
default:
throw new UnityAgentsException(
$"Sensor {sensor.GetName()} have an invalid rank {rank}");
var sensor = sensors[sensorIndex];
var rank = sensor.GetObservationSpec().NumDimensions;
ObservationGenerator obsGen = null;
string obsGenName = null;
switch (rank)
{
case 1:
if (vecObsGen == null)
{
vecObsGen = new ObservationGenerator(allocator);
}
obsGen = vecObsGen;
obsGenName = TensorNames.VectorObservationPlaceholder;
break;
case 2:
// If the tensor is of rank 2, we use the index of the sensor
// to create the name
obsGen = new ObservationGenerator(allocator);
obsGenName = TensorNames.GetObservationName(sensorIndex);
break;
case 3:
// If the tensor is of rank 3, we use the "visual observation
// index", which only counts the rank 3 sensors
obsGen = new ObservationGenerator(allocator);
obsGenName = TensorNames.GetVisualObservationName(visIndex);
visIndex++;
break;
default:
throw new UnityAgentsException(
$"Sensor {sensor.GetName()} have an invalid rank {rank}");
}
obsGen.AddSensorIndex(sensorIndex);
m_Dict[obsGenName] = obsGen;
obsGen.AddSensorIndex(sensorIndex);
m_Dict[obsGenName] = obsGen;
}
if (m_ApiVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents2_0)
{
for (var sensorIndex = 0; sensorIndex < sensors.Count; sensorIndex++)
{
var obsGen = new ObservationGenerator(allocator);
var obsGenName = TensorNames.GetObservationName(sensorIndex);
obsGen.AddSensorIndex(sensorIndex);
m_Dict[obsGenName] = obsGen;
}
}
}

41
com.unity.ml-agents/Tests/Editor/DiscreteActionOutputApplierTest.cs


namespace Unity.MLAgents.Tests
{
public class DiscreteActionOutputApplierTest
{
[Test]

var applier = new DiscreteActionOutputApplier(actionSpec, 2020, null);
var agentIds = new List<int> { 42, 1337 };
var actionBuffers = new Dictionary<int, ActionBuffers>();
actionBuffers[42] = new ActionBuffers(actionSpec);
actionBuffers[1337] = new ActionBuffers(actionSpec);
var actionTensor = new TensorProxy
{
data = new Tensor(
2,
2,
new[]
{
2.0f, // Agent 0, branch 0
1.0f, // Agent 0, branch 1
0.0f, // Agent 1, branch 0
0.0f // Agent 1, branch 1
}),
shape = new long[] { 2, 2 },
valueType = TensorProxy.TensorType.FloatingPoint
};
applier.Apply(actionTensor, agentIds, actionBuffers);
Assert.AreEqual(2, actionBuffers[42].DiscreteActions[0]);
Assert.AreEqual(1, actionBuffers[42].DiscreteActions[1]);
Assert.AreEqual(0, actionBuffers[1337].DiscreteActions[0]);
Assert.AreEqual(0, actionBuffers[1337].DiscreteActions[1]);
}
}
public class LegacyDiscreteActionOutputApplierTest
{
[Test]
public void TestDiscreteApply()
{
var actionSpec = ActionSpec.MakeDiscrete(3, 2);
const float smallLogProb = -1000.0f;
const float largeLogProb = -1.0f;

valueType = TensorProxy.TensorType.FloatingPoint
};
var applier = new DiscreteActionOutputApplier(actionSpec, 2020, null);
var applier = new LegacyDiscreteActionOutputApplier(actionSpec, 2020, null);
var agentIds = new List<int> { 42, 1337 };
var actionBuffers = new Dictionary<int, ActionBuffers>();
actionBuffers[42] = new ActionBuffers(actionSpec);

77
com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorApplier.cs


}
[Test]
public void ApplyDiscreteActionOutput()
public void ApplyDiscreteActionOutputLegacy()
{
var actionSpec = ActionSpec.MakeDiscrete(2, 3);
var inputTensor = new TensorProxy()

new[] { 0.5f, 22.5f, 0.1f, 5f, 1f, 4f, 5f, 6f, 7f, 8f })
};
var alloc = new TensorCachingAllocator();
var applier = new LegacyDiscreteActionOutputApplier(actionSpec, 0, alloc);
var agentIds = new List<int>() { 0, 1 };
// Dictionary from AgentId to Action
var actionDict = new Dictionary<int, ActionBuffers>() { { 0, ActionBuffers.Empty }, { 1, ActionBuffers.Empty } };
applier.Apply(inputTensor, agentIds, actionDict);
Assert.AreEqual(actionDict[0].DiscreteActions[0], 1);
Assert.AreEqual(actionDict[0].DiscreteActions[1], 1);
Assert.AreEqual(actionDict[1].DiscreteActions[0], 1);
Assert.AreEqual(actionDict[1].DiscreteActions[1], 2);
alloc.Dispose();
}
[Test]
public void ApplyDiscreteActionOutput()
{
var actionSpec = ActionSpec.MakeDiscrete(2, 3);
var inputTensor = new TensorProxy()
{
shape = new long[] { 2, 2 },
data = new Tensor(
2,
2,
new[] { 1f, 1f, 1f, 2f }),
};
var alloc = new TensorCachingAllocator();
var applier = new DiscreteActionOutputApplier(actionSpec, 0, alloc);
var agentIds = new List<int>() { 0, 1 };

}
[Test]
public void ApplyHybridActionOutput()
public void ApplyHybridActionOutputLegacy()
{
var actionSpec = new ActionSpec(3, new[] { 2, 3 });
var continuousInputTensor = new TensorProxy()

2,
5,
new[] { 0.5f, 22.5f, 0.1f, 5f, 1f, 4f, 5f, 6f, 7f, 8f })
};
var continuousApplier = new ContinuousActionOutputApplier(actionSpec);
var alloc = new TensorCachingAllocator();
var discreteApplier = new LegacyDiscreteActionOutputApplier(actionSpec, 0, alloc);
var agentIds = new List<int>() { 0, 1 };
// Dictionary from AgentId to Action
var actionDict = new Dictionary<int, ActionBuffers>() { { 0, ActionBuffers.Empty }, { 1, ActionBuffers.Empty } };
continuousApplier.Apply(continuousInputTensor, agentIds, actionDict);
discreteApplier.Apply(discreteInputTensor, agentIds, actionDict);
Assert.AreEqual(actionDict[0].ContinuousActions[0], 1);
Assert.AreEqual(actionDict[0].ContinuousActions[1], 2);
Assert.AreEqual(actionDict[0].ContinuousActions[2], 3);
Assert.AreEqual(actionDict[0].DiscreteActions[0], 1);
Assert.AreEqual(actionDict[0].DiscreteActions[1], 1);
Assert.AreEqual(actionDict[1].ContinuousActions[0], 4);
Assert.AreEqual(actionDict[1].ContinuousActions[1], 5);
Assert.AreEqual(actionDict[1].ContinuousActions[2], 6);
Assert.AreEqual(actionDict[1].DiscreteActions[0], 1);
Assert.AreEqual(actionDict[1].DiscreteActions[1], 2);
alloc.Dispose();
}
[Test]
public void ApplyHybridActionOutput()
{
var actionSpec = new ActionSpec(3, new[] { 2, 3 });
var continuousInputTensor = new TensorProxy()
{
shape = new long[] { 2, 3 },
data = new Tensor(2, 3, new float[] { 1, 2, 3, 4, 5, 6 })
};
var discreteInputTensor = new TensorProxy()
{
shape = new long[] { 2, 2 },
data = new Tensor(
2,
2,
new[] { 1f, 1f, 1f, 2f }),
};
var continuousApplier = new ContinuousActionOutputApplier(actionSpec);
var alloc = new TensorCachingAllocator();

2
ml-agents/mlagents/trainers/torch/distributions.py


).unsqueeze(-1)
def exported_model_output(self):
return self.all_log_prob()
return self.sample()
class GaussianDistribution(nn.Module):

61
ml-agents/mlagents/trainers/torch/model_serialization.py


from typing import Tuple
import threading
from mlagents.torch_utils import torch

observation_specs = self.policy.behavior_spec.observation_specs
batch_dim = [1]
seq_len_dim = [1]
vec_obs_size = 0
for obs_spec in observation_specs:
if len(obs_spec.shape) == 1:
vec_obs_size += obs_spec.shape[0]
num_vis_obs = sum(
1 for obs_spec in observation_specs if len(obs_spec.shape) == 3
)
dummy_vec_obs = [torch.zeros(batch_dim + [vec_obs_size])]
# create input shape of NCHW
# (It's NHWC in observation_specs.shape)
dummy_vis_obs = [
num_obs = len(observation_specs)
dummy_obs = [
batch_dim + [obs_spec.shape[2], obs_spec.shape[0], obs_spec.shape[1]]
batch_dim + list(ModelSerializer._get_onnx_shape(obs_spec.shape))
if len(obs_spec.shape) == 3
]
dummy_var_len_obs = [
torch.zeros(batch_dim + [obs_spec.shape[0], obs_spec.shape[1]])
for obs_spec in observation_specs
if len(obs_spec.shape) == 2
]
dummy_masks = torch.ones(

batch_dim + seq_len_dim + [self.policy.export_memory_size]
)
self.dummy_input = (
dummy_vec_obs,
dummy_vis_obs,
dummy_var_len_obs,
dummy_masks,
dummy_memories,
)
self.dummy_input = (dummy_obs, dummy_masks, dummy_memories)
self.input_names = [TensorNames.vector_observation_placeholder]
for i in range(num_vis_obs):
self.input_names.append(TensorNames.get_visual_observation_name(i))
for i, obs_spec in enumerate(observation_specs):
if len(obs_spec.shape) == 2:
self.input_names.append(TensorNames.get_observation_name(i))
self.input_names = [TensorNames.get_observation_name(i) for i in range(num_obs)]
self.input_names += [
TensorNames.action_mask_placeholder,
TensorNames.recurrent_in_placeholder,

TensorNames.discrete_action_output_shape,
]
self.dynamic_axes.update({TensorNames.discrete_action_output: {0: "batch"}})
if (
self.policy.behavior_spec.action_spec.continuous_size == 0
or self.policy.behavior_spec.action_spec.discrete_size == 0
):
self.output_names += [
TensorNames.action_output_deprecated,
TensorNames.is_continuous_control_deprecated,
TensorNames.action_output_shape_deprecated,
]
self.dynamic_axes.update(
{TensorNames.action_output_deprecated: {0: "batch"}}
)
@staticmethod
def _get_onnx_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]:
"""
Converts the shape of an observation to be compatible with the NCHW format
of ONNX
"""
if len(shape) == 3:
return shape[2], shape[0], shape[1]
return shape
def export_policy_model(self, output_filepath: str) -> None:
"""

44
ml-agents/mlagents/trainers/torch/networks.py


@abc.abstractmethod
def forward(
self,
vec_inputs: List[torch.Tensor],
vis_inputs: List[torch.Tensor],
var_len_inputs: List[torch.Tensor],
inputs: List[torch.Tensor],
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
) -> Tuple[Union[int, torch.Tensor], ...]:

class SimpleActor(nn.Module, Actor):
MODEL_EXPORT_VERSION = 3
def __init__(
self,
observation_specs: List[ObservationSpec],

super().__init__()
self.action_spec = action_spec
self.version_number = torch.nn.Parameter(
torch.Tensor([2.0]), requires_grad=False
torch.Tensor([self.MODEL_EXPORT_VERSION]), requires_grad=False
)
self.is_continuous_int_deprecated = torch.nn.Parameter(
torch.Tensor([int(self.action_spec.is_continuous())]), requires_grad=False

)
# TODO: export list of branch sizes instead of sum
torch.Tensor([sum(self.action_spec.discrete_branches)]), requires_grad=False
torch.Tensor([self.action_spec.discrete_branches]), requires_grad=False
)
self.act_size_vector_deprecated = torch.nn.Parameter(
torch.Tensor(

def forward(
self,
vec_inputs: List[torch.Tensor],
vis_inputs: List[torch.Tensor],
var_len_inputs: List[torch.Tensor],
inputs: List[torch.Tensor],
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
) -> Tuple[Union[int, torch.Tensor], ...]:

At this moment, torch.onnx.export() doesn't accept None as tensor to be exported,
so the size of return tuple varies with action spec.
"""
# This code will convert the vec and vis obs into a list of inputs for the network
concatenated_vec_obs = vec_inputs[0]
inputs = []
start = 0
end = 0
vis_index = 0
var_len_index = 0
for i, enc in enumerate(self.network_body.processors):
if isinstance(enc, VectorInput):
# This is a vec_obs
vec_size = self.network_body.embedding_sizes[i]
end = start + vec_size
inputs.append(concatenated_vec_obs[:, start:end])
start = end
elif isinstance(enc, EntityEmbedding):
inputs.append(var_len_inputs[var_len_index])
var_len_index += 1
else: # visual input
inputs.append(vis_inputs[vis_index])
vis_index += 1
# End of code to convert the vec and vis obs into a list of inputs for the network
encoding, memories_out = self.network_body(
inputs, memories=memories, sequence_length=1
)

export_out += [cont_action_out, self.continuous_act_size_vector]
if self.action_spec.discrete_size > 0:
export_out += [disc_action_out, self.discrete_act_size_vector]
# Only export deprecated nodes with non-hybrid action spec
if self.action_spec.continuous_size == 0 or self.action_spec.discrete_size == 0:
export_out += [
action_out_deprecated,
self.is_continuous_int_deprecated,
self.act_size_vector_deprecated,
]
if self.network_body.memory_size > 0:
export_out += [memories_out]
return tuple(export_out)

正在加载...
取消
保存