浏览代码
Barracuda inference for hybrid actions (#4611)
Barracuda inference for hybrid actions (#4611)
* TensorApplier.IApplier takes ActionBuffers instead of float[] as input argument * Model output format changed/MLA-1734-demo-provider
GitHub
4 年前
当前提交
5e5ff19b
共有 36 个文件被更改,包括 2958 次插入 和 238 次删除
-
2.yamato/com.unity.ml-agents-performance.yml
-
2.yamato/com.unity.ml-agents-test.yml
-
2.yamato/compressed-sensor-test.yml
-
2.yamato/gym-interface-test.yml
-
2.yamato/protobuf-generation-test.yml
-
2.yamato/python-ll-api-test.yml
-
2.yamato/standalone-build-test.yml
-
2.yamato/training-int-tests.yml
-
8com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs
-
17com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs
-
5com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
-
44com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs
-
136com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
-
8com.unity.ml-agents/Runtime/Inference/ModelRunner.cs
-
35com.unity.ml-agents/Runtime/Inference/TensorApplier.cs
-
26com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs
-
15com.unity.ml-agents/Runtime/Inference/TensorNames.cs
-
20com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs
-
12com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs
-
74com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorApplier.cs
-
62com.unity.ml-agents/Tests/Editor/ModelRunnerTest.cs
-
190com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs
-
2com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated.nn.meta
-
2com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated.nn.meta
-
141com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs
-
11com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs.meta
-
1001com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action.onnx
-
14com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action.onnx.meta
-
867com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr.onnx
-
14com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr.onnx.meta
-
462com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis53vec_3c_2daction.onnx
-
14com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis53vec_3c_2daction.onnx.meta
-
0/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated.nn
-
0/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated.nn
-
0/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated.nn.meta
-
0/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated.nn.meta
|
|||
using Unity.Barracuda; |
|||
|
|||
namespace Unity.MLAgents.Inference |
|||
{ |
|||
/// <summary>
|
|||
/// Barracuda Model extension methods.
|
|||
/// </summary>
|
|||
internal static class BarracudaModelExtensions |
|||
{ |
|||
/// <summary>
|
|||
/// Check if the model has continuous action outputs.
|
|||
/// </summary>
|
|||
/// <param name="model">
|
|||
/// The Barracuda engine model for loading static parameters.
|
|||
/// </param>
|
|||
/// <returns>True if the model has continuous action outputs.</returns>
|
|||
public static bool HasContinuousOutputs(this Model model) |
|||
{ |
|||
if (model.UseDeprecated()) |
|||
{ |
|||
return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0; |
|||
} |
|||
else |
|||
{ |
|||
return model.outputs.Contains(TensorNames.ContinuousActionOutput) && |
|||
(int)model.GetTensorByName(TensorNames.ContinuousActionOutputShape)[0] > 0; |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Continuous action output size of the model.
|
|||
/// </summary>
|
|||
/// <param name="model">
|
|||
/// The Barracuda engine model for loading static parameters.
|
|||
/// </param>
|
|||
/// <returns>Size of continuous action output.</returns>
|
|||
public static int ContinuousOutputSize(this Model model) |
|||
{ |
|||
if (model.UseDeprecated()) |
|||
{ |
|||
return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0 ? |
|||
(int)model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated)[0] : 0; |
|||
} |
|||
else |
|||
{ |
|||
return (int)model.GetTensorByName(TensorNames.ContinuousActionOutputShape)[0]; |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Continuous action output tensor name of the model.
|
|||
/// </summary>
|
|||
/// <param name="model">
|
|||
/// The Barracuda engine model for loading static parameters.
|
|||
/// </param>
|
|||
/// <returns>Tensor name of continuous action output.</returns>
|
|||
public static string ContinuousOutputName(this Model model) |
|||
{ |
|||
if (model.UseDeprecated()) |
|||
{ |
|||
return TensorNames.ActionOutputDeprecated; |
|||
} |
|||
else |
|||
{ |
|||
return TensorNames.ContinuousActionOutput; |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Check if the model has discrete action outputs.
|
|||
/// </summary>
|
|||
/// <param name="model">
|
|||
/// The Barracuda engine model for loading static parameters.
|
|||
/// </param>
|
|||
/// <returns>True if the model has discrete action outputs.</returns>
|
|||
public static bool HasDiscreteOutputs(this Model model) |
|||
{ |
|||
if (model.UseDeprecated()) |
|||
{ |
|||
return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] == 0; |
|||
} |
|||
else |
|||
{ |
|||
return model.outputs.Contains(TensorNames.DiscreteActionOutput) && |
|||
(int)model.GetTensorByName(TensorNames.DiscreteActionOutputShape)[0] > 0; |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Discrete action output size of the model. This is equal to the sum of the branch sizes.
|
|||
/// </summary>
|
|||
/// <param name="model">
|
|||
/// The Barracuda engine model for loading static parameters.
|
|||
/// </param>
|
|||
/// <returns>Size of discrete action output.</returns>
|
|||
public static int DiscreteOutputSize(this Model model) |
|||
{ |
|||
if (model.UseDeprecated()) |
|||
{ |
|||
return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0 ? |
|||
0 : (int)model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated)[0]; |
|||
} |
|||
else |
|||
{ |
|||
return (int)model.GetTensorByName(TensorNames.DiscreteActionOutputShape)[0]; |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Discrete action output tensor name of the model.
|
|||
/// </summary>
|
|||
/// <param name="model">
|
|||
/// The Barracuda engine model for loading static parameters.
|
|||
/// </param>
|
|||
/// <returns>Tensor name of discrete action output.</returns>
|
|||
public static string DiscreteOutputName(this Model model) |
|||
{ |
|||
if (model.UseDeprecated()) |
|||
{ |
|||
return TensorNames.ActionOutputDeprecated; |
|||
} |
|||
else |
|||
{ |
|||
return TensorNames.DiscreteActionOutput; |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Check if the model uses deprecated output fields and should be handled differently.
|
|||
/// </summary>
|
|||
/// <param name="model">
|
|||
/// The Barracuda engine model for loading static parameters.
|
|||
/// </param>
|
|||
/// <returns>True if the model uses deprecated output fields.</returns>
|
|||
public static bool UseDeprecated(this Model model) |
|||
{ |
|||
return !model.outputs.Contains(TensorNames.ContinuousActionOutput) && |
|||
!model.outputs.Contains(TensorNames.DiscreteActionOutput); |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 1193c3bef93464baca0d8ba2d6ce1754 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
1001
com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action.onnx
文件差异内容过多而无法显示
查看文件
文件差异内容过多而无法显示
查看文件
|
|||
fileFormatVersion: 2 |
|||
guid: f90bffb60a3784a2385299a321f354a6 |
|||
ScriptedImporter: |
|||
fileIDToRecycleName: |
|||
11400000: main obj |
|||
11400002: model data |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|||
script: {fileID: 11500000, guid: 683b6cb6d0a474744822c888b46772c9, type: 3} |
|||
optimizeModel: 1 |
|||
forceArbitraryBatchSize: 1 |
|||
treatErrorsAsWarnings: 0 |
|
|||
pytorch1.7:�� |
|||
� |
|||
visual_observation_0 |
|||
5network_body.visual_processors.0.conv_layers.0.weight |
|||
3network_body.visual_processors.0.conv_layers.0.bias35Conv_0"Conv* |
|||
dilations@@�* |
|||
group�* |
|||
kernel_shape@@�* |
|||
pads@ @ @ @ �* |
|||
strides@@� |
|||
1 |
|||
3536LeakyRelu_1" LeakyRelu* |
|||
alpha |
|||
�#<� |
|||
� |
|||
36 |
|||
5network_body.visual_processors.0.conv_layers.2.weight |
|||
3network_body.visual_processors.0.conv_layers.2.bias37Conv_2"Conv* |
|||
dilations@@�* |
|||
group�* |
|||
kernel_shape@@�* |
|||
pads@ @ @ @ �* |
|||
strides@@� |
|||
1 |
|||
3738LeakyRelu_3" LeakyRelu* |
|||
alpha |
|||
�#<� |
|||
>39 |
|||
Constant_4"Constant*" |
|||
value*J�������� � |
|||
|
|||
38 |
|||
3940 Reshape_5"Reshape |
|||
� |
|||
40 |
|||
/network_body.visual_processors.0.dense.0.weight |
|||
-network_body.visual_processors.0.dense.0.bias41Gemm_6"Gemm* |
|||
alpha �?�* |
|||
beta �?�* |
|||
transB� |
|||
1 |
|||
4142LeakyRelu_7" LeakyRelu* |
|||