浏览代码
Barracuda inference for hybrid actions (#4611)
Barracuda inference for hybrid actions (#4611)
* TensorApplier.IApplier takes ActionBuffers instead of float[] as input argument * Model output format changed/MLA-1734-demo-provider
GitHub
4 年前
当前提交
5e5ff19b
共有 36 个文件被更改,包括 2958 次插入 和 238 次删除
-
2.yamato/com.unity.ml-agents-performance.yml
-
2.yamato/com.unity.ml-agents-test.yml
-
2.yamato/compressed-sensor-test.yml
-
2.yamato/gym-interface-test.yml
-
2.yamato/protobuf-generation-test.yml
-
2.yamato/python-ll-api-test.yml
-
2.yamato/standalone-build-test.yml
-
2.yamato/training-int-tests.yml
-
8com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs
-
17com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs
-
5com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
-
44com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs
-
136com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
-
8com.unity.ml-agents/Runtime/Inference/ModelRunner.cs
-
35com.unity.ml-agents/Runtime/Inference/TensorApplier.cs
-
26com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs
-
15com.unity.ml-agents/Runtime/Inference/TensorNames.cs
-
20com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs
-
12com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs
-
74com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorApplier.cs
-
62com.unity.ml-agents/Tests/Editor/ModelRunnerTest.cs
-
190com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs
-
2com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated.nn.meta
-
2com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated.nn.meta
-
141com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs
-
11com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs.meta
-
1001com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action.onnx
-
14com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action.onnx.meta
-
867com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr.onnx
-
14com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr.onnx.meta
-
462com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis53vec_3c_2daction.onnx
-
14com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis53vec_3c_2daction.onnx.meta
-
0/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated.nn
-
0/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated.nn
-
0/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated.nn.meta
-
0/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated.nn.meta
|
|||
using Unity.Barracuda; |
|||
|
|||
namespace Unity.MLAgents.Inference |
|||
{ |
|||
/// <summary>
|
|||
/// Barracuda Model extension methods.
|
|||
/// </summary>
|
|||
internal static class BarracudaModelExtensions |
|||
{ |
|||
/// <summary>
|
|||
/// Check if the model has continuous action outputs.
|
|||
/// </summary>
|
|||
/// <param name="model">
|
|||
/// The Barracuda engine model for loading static parameters.
|
|||
/// </param>
|
|||
/// <returns>True if the model has continuous action outputs.</returns>
|
|||
public static bool HasContinuousOutputs(this Model model) |
|||
{ |
|||
if (model.UseDeprecated()) |
|||
{ |
|||
return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0; |
|||
} |
|||
else |
|||
{ |
|||
return model.outputs.Contains(TensorNames.ContinuousActionOutput) && |
|||
(int)model.GetTensorByName(TensorNames.ContinuousActionOutputShape)[0] > 0; |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Continuous action output size of the model.
|
|||
/// </summary>
|
|||
/// <param name="model">
|
|||
/// The Barracuda engine model for loading static parameters.
|
|||
/// </param>
|
|||
/// <returns>Size of continuous action output.</returns>
|
|||
public static int ContinuousOutputSize(this Model model) |
|||
{ |
|||
if (model.UseDeprecated()) |
|||
{ |
|||
return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0 ? |
|||
(int)model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated)[0] : 0; |
|||
} |
|||
else |
|||
{ |
|||
return (int)model.GetTensorByName(TensorNames.ContinuousActionOutputShape)[0]; |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Continuous action output tensor name of the model.
|
|||
/// </summary>
|
|||
/// <param name="model">
|
|||
/// The Barracuda engine model for loading static parameters.
|
|||
/// </param>
|
|||
/// <returns>Tensor name of continuous action output.</returns>
|
|||
public static string ContinuousOutputName(this Model model) |
|||
{ |
|||
if (model.UseDeprecated()) |
|||
{ |
|||
return TensorNames.ActionOutputDeprecated; |
|||
} |
|||
else |
|||
{ |
|||
return TensorNames.ContinuousActionOutput; |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Check if the model has discrete action outputs.
|
|||
/// </summary>
|
|||
/// <param name="model">
|
|||
/// The Barracuda engine model for loading static parameters.
|
|||
/// </param>
|
|||
/// <returns>True if the model has discrete action outputs.</returns>
|
|||
public static bool HasDiscreteOutputs(this Model model) |
|||
{ |
|||
if (model.UseDeprecated()) |
|||
{ |
|||
return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] == 0; |
|||
} |
|||
else |
|||
{ |
|||
return model.outputs.Contains(TensorNames.DiscreteActionOutput) && |
|||
(int)model.GetTensorByName(TensorNames.DiscreteActionOutputShape)[0] > 0; |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Discrete action output size of the model. This is equal to the sum of the branch sizes.
|
|||
/// </summary>
|
|||
/// <param name="model">
|
|||
/// The Barracuda engine model for loading static parameters.
|
|||
/// </param>
|
|||
/// <returns>Size of discrete action output.</returns>
|
|||
public static int DiscreteOutputSize(this Model model) |
|||
{ |
|||
if (model.UseDeprecated()) |
|||
{ |
|||
return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0 ? |
|||
0 : (int)model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated)[0]; |
|||
} |
|||
else |
|||
{ |
|||
return (int)model.GetTensorByName(TensorNames.DiscreteActionOutputShape)[0]; |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Discrete action output tensor name of the model.
|
|||
/// </summary>
|
|||
/// <param name="model">
|
|||
/// The Barracuda engine model for loading static parameters.
|
|||
/// </param>
|
|||
/// <returns>Tensor name of discrete action output.</returns>
|
|||
public static string DiscreteOutputName(this Model model) |
|||
{ |
|||
if (model.UseDeprecated()) |
|||
{ |
|||
return TensorNames.ActionOutputDeprecated; |
|||
} |
|||
else |
|||
{ |
|||
return TensorNames.DiscreteActionOutput; |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Check if the model uses deprecated output fields and should be handled differently.
|
|||
/// </summary>
|
|||
/// <param name="model">
|
|||
/// The Barracuda engine model for loading static parameters.
|
|||
/// </param>
|
|||
/// <returns>True if the model uses deprecated output fields.</returns>
|
|||
public static bool UseDeprecated(this Model model) |
|||
{ |
|||
return !model.outputs.Contains(TensorNames.ContinuousActionOutput) && |
|||
!model.outputs.Contains(TensorNames.DiscreteActionOutput); |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 1193c3bef93464baca0d8ba2d6ce1754 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
1001
com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action.onnx
文件差异内容过多而无法显示
查看文件
文件差异内容过多而无法显示
查看文件
|
|||
fileFormatVersion: 2 |
|||
guid: f90bffb60a3784a2385299a321f354a6 |
|||
ScriptedImporter: |
|||
fileIDToRecycleName: |
|||
11400000: main obj |
|||
11400002: model data |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|||
script: {fileID: 11500000, guid: 683b6cb6d0a474744822c888b46772c9, type: 3} |
|||
optimizeModel: 1 |
|||
forceArbitraryBatchSize: 1 |
|||
treatErrorsAsWarnings: 0 |
|
|||
pytorch1.7:�� |
|||
� |
|||
visual_observation_0 |
|||
5network_body.visual_processors.0.conv_layers.0.weight |
|||
3network_body.visual_processors.0.conv_layers.0.bias35Conv_0"Conv* |
|||
dilations@@�* |
|||
group�* |
|||
kernel_shape@@�* |
|||
pads@ @ @ @ �* |
|||
strides@@� |
|||
1 |
|||
3536LeakyRelu_1" LeakyRelu* |
|||
alpha |
|||
�#<� |
|||
� |
|||
36 |
|||
5network_body.visual_processors.0.conv_layers.2.weight |
|||
3network_body.visual_processors.0.conv_layers.2.bias37Conv_2"Conv* |
|||
dilations@@�* |
|||
group�* |
|||
kernel_shape@@�* |
|||
pads@ @ @ @ �* |
|||
strides@@� |
|||
1 |
|||
3738LeakyRelu_3" LeakyRelu* |
|||
alpha |
|||
�#<� |
|||
>39 |
|||
Constant_4"Constant*" |
|||
value*J�������� � |
|||
|
|||
38 |
|||
3940 Reshape_5"Reshape |
|||
� |
|||
40 |
|||
/network_body.visual_processors.0.dense.0.weight |
|||
-network_body.visual_processors.0.dense.0.bias41Gemm_6"Gemm* |
|||
alpha �?�* |
|||
beta �?�* |
|||
transB� |
|||
1 |
|||
4142LeakyRelu_7" LeakyRelu* |
|||
alpha |
|||
�#<� |
|||
0 |
|||
4243Concat_8"Concat* |
|||
axis���������� |
|||
� |
|||
43 |
|||
/network_body.linear_encoder.seq_layers.0.weight |
|||
-network_body.linear_encoder.seq_layers.0.bias44Gemm_9"Gemm* |
|||
alpha �?�* |
|||
beta �?�* |
|||
transB� |
|||
|
|||
4445 |
|||
Sigmoid_10"Sigmoid |
|||
|
|||
44 |
|||
4546Mul_11"Mul |
|||
� |
|||
46 |
|||
/network_body.linear_encoder.seq_layers.2.weight |
|||
-network_body.linear_encoder.seq_layers.2.bias47Gemm_12"Gemm* |
|||
alpha �?�* |
|||
beta �?�* |
|||
transB� |
|||
|
|||
4748 |
|||
Sigmoid_13"Sigmoid |
|||
|
|||
47 |
|||
4849Mul_14"Mul |
|||
L |
|||
action_masks50Slice_15"Slice* |
|||
axes@�* |
|||
ends@�* |
|||
starts@ � |
|||
L |
|||
action_masks51Slice_16"Slice* |
|||
axes@�* |
|||
ends@�* |
|||
starts@� |
|||
� |
|||
49 |
|||
/action_model._distributions.0.branches.0.weight |
|||
-action_model._distributions.0.branches.0.bias52Gemm_17"Gemm* |
|||
alpha �?�* |
|||
beta �?�* |
|||
transB� |
|||
* |
|||
5253 |
|||
Softmax_18"Softmax* |
|||
axis� |
|||
|
|||
53 |
|||
5054Mul_19"Mul |
|||
H |
|||
5455ReduceSum_20" ReduceSum* |
|||
axes@����������* |
|||
keepdims � |
|||
. |
|||
5556Unsqueeze_21" Unsqueeze* |
|||
axes@� |
|||
|
|||
54 |
|||
5657Div_22"Div |
|||
158Constant_23"Constant* |
|||
value*J���3� |
|||
|
|||
57 |
|||
5859Add_24"Add |
|||
|
|||
5960Log_25"Log |
|||
* |
|||
6061 |
|||
Softmax_26"Softmax* |
|||
axis� |
|||
� |
|||
49 |
|||
/action_model._distributions.0.branches.1.weight |
|||
-action_model._distributions.0.branches.1.bias62Gemm_27"Gemm* |
|||
alpha �?�* |
|||
beta �?�* |
|||
transB� |
|||
* |
|||
6263 |
|||
Softmax_28"Softmax* |
|||
axis� |
|||
|
|||
63 |
|||
5164Mul_29"Mul |
|||
H |
|||
6465ReduceSum_30" ReduceSum* |
|||
axes@����������* |
|||
keepdims � |
|||
. |
|||
6566Unsqueeze_31" Unsqueeze* |
|||
axes@� |
|||
|
|||
64 |
|||
6667Div_32"Div |
|||
168Constant_33"Constant* |
|||
value*J���3� |
|||
|
|||
67 |
|||
6869Add_34"Add |
|||
|
|||
6970Log_35"Log |
|||
* |
|||
7071 |
|||
Softmax_36"Softmax* |
|||
axis� |
|||
# |
|||
71discrete_actionsLog_37"Log |
|||
|
|||
6173Log_38"Log |
|||
|
|||
7174Log_39"Log |
|||
0 |
|||
73 |
|||
74action Concat_40"Concat* |
|||
axis� |
|||
<memory_sizeConstant_41"Constant* |
|||
value* |
|||
J �torch-jit-export*=B-action_model._distributions.0.branches.0.biasJ *��B/action_model._distributions.0.branches.0.weightJ����I�<�/-<_#
��c��
��<Ǵw<ę�.�*<��;"�<a~�;lc�<Z4��M���=�x��S"O�e�5���8���:!,G�.m�d�W�t�������^<U������;wd:�ި;S�< |