浏览代码

Merge branch 'develop-hybrid-actions-csharp' into develop-actionmodel-csharp

/develop/actionmodel-csharp
Andrew Cohen 4 年前
当前提交
b0c02ee0
共有 58 个文件被更改,包括 3824 次插入532 次删除
  1. 2
      .yamato/com.unity.ml-agents-performance.yml
  2. 2
      .yamato/com.unity.ml-agents-test.yml
  3. 2
      .yamato/compressed-sensor-test.yml
  4. 2
      .yamato/gym-interface-test.yml
  5. 2
      .yamato/protobuf-generation-test.yml
  6. 2
      .yamato/python-ll-api-test.yml
  7. 2
      .yamato/standalone-build-test.yml
  8. 2
      .yamato/training-int-tests.yml
  9. 6
      com.unity.ml-agents/Runtime/Academy.cs
  10. 8
      com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs
  11. 3
      com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs
  12. 4
      com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs
  13. 57
      com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs
  14. 44
      com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
  15. 5
      com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs
  16. 348
      com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/BrainParameters.cs
  17. 44
      com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs
  18. 44
      com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs
  19. 237
      com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
  20. 12
      com.unity.ml-agents/Runtime/Inference/ModelRunner.cs
  21. 35
      com.unity.ml-agents/Runtime/Inference/TensorApplier.cs
  22. 26
      com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs
  23. 15
      com.unity.ml-agents/Runtime/Inference/TensorNames.cs
  24. 20
      com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs
  25. 14
      com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs
  26. 12
      com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs
  27. 74
      com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorApplier.cs
  28. 62
      com.unity.ml-agents/Tests/Editor/ModelRunnerTest.cs
  29. 212
      com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs
  30. 2
      com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated.nn.meta
  31. 2
      com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated.nn.meta
  32. 82
      ml-agents-envs/mlagents_envs/communicator_objects/brain_parameters_pb2.py
  33. 45
      ml-agents-envs/mlagents_envs/communicator_objects/brain_parameters_pb2.pyi
  34. 13
      ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py
  35. 6
      ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi
  36. 16
      ml-agents-envs/mlagents_envs/environment.py
  37. 8
      ml-agents-envs/mlagents_envs/mock_communicator.py
  38. 9
      ml-agents-envs/mlagents_envs/rpc_utils.py
  39. 8
      ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py
  40. 8
      ml-agents/mlagents/trainers/tests/tensorflow/test_simple_rl.py
  41. 8
      ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py
  42. 23
      ml-agents/mlagents/trainers/torch/action_model.py
  43. 24
      ml-agents/mlagents/trainers/torch/model_serialization.py
  44. 44
      ml-agents/mlagents/trainers/torch/networks.py
  45. 14
      protobuf-definitions/proto/mlagents_envs/communicator_objects/brain_parameters.proto
  46. 3
      protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto
  47. 360
      com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs
  48. 11
      com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs.meta
  49. 1001
      com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action.onnx
  50. 14
      com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action.onnx.meta
  51. 867
      com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr.onnx
  52. 14
      com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr.onnx.meta
  53. 462
      com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis53vec_3c_2daction.onnx
  54. 14
      com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis53vec_3c_2daction.onnx.meta
  55. 0
      /com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated.nn.meta
  56. 0
      /com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated.nn
  57. 0
      /com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated.nn.meta
  58. 0
      /com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated.nn

2
.yamato/com.unity.ml-agents-performance.yml


- ./utr --suite=editor --platform=StandaloneOSX --editor-location=.Editor --testproject=DevProject --artifacts_path=build/test-results --report-performance-data --performance-project-id=com.unity.ml-agents --zero-tests-are-ok=1
triggers:
cancel_old_ci: true
# TODO remove develop-hybrid trigger before merging to master
pull_request.target match "develop-hybrid.+" OR
pull_request.target match "release.+") AND
NOT pull_request.draft AND
(pull_request.changes.any match "com.unity.ml-agents/**" OR

2
.yamato/com.unity.ml-agents-test.yml


triggers:
cancel_old_ci: true
{% if platform.name == "mac" %}
# TODO remove develop-hybrid trigger before merging to master
pull_request.target match "develop-hybrid.+" OR
pull_request.target match "release.+") AND
NOT pull_request.draft AND
(pull_request.changes.any match "com.unity.ml-agents/**" OR

2
.yamato/compressed-sensor-test.yml


- .yamato/standalone-build-test.yml#test_mac_standalone_{{ editor.version }}
triggers:
cancel_old_ci: true
# TODO remove develop-hybrid trigger before merging to master
pull_request.target match "develop-hybrid.+" OR
pull_request.target match "release.+") AND
NOT pull_request.draft AND
(pull_request.changes.any match "com.unity.ml-agents/**" OR

2
.yamato/gym-interface-test.yml


- .yamato/standalone-build-test.yml#test_mac_standalone_{{ editor.version }}
triggers:
cancel_old_ci: true
# TODO remove develop-hybrid trigger before merging to master
pull_request.target match "develop-hybrid.+" OR
pull_request.target match "release.+") AND
NOT pull_request.draft AND
(pull_request.changes.any match "com.unity.ml-agents/**" OR

2
.yamato/protobuf-generation-test.yml


git diff -- :/ ":(exclude,top)$CS_PROTO_PATH/*.meta" > artifacts/proto.patch; exit $GIT_ERR; }
triggers:
cancel_old_ci: true
# TODO remove develop-hybrid trigger before merging to master
pull_request.target match "develop-hybrid.+" OR
pull_request.target match "release.+") AND
NOT pull_request.draft AND
(pull_request.changes.any match "protobuf-definitions/**" OR

2
.yamato/python-ll-api-test.yml


- .yamato/standalone-build-test.yml#test_mac_standalone_{{ editor.version }}
triggers:
cancel_old_ci: true
# TODO remove develop-hybrid trigger before merging to master
pull_request.target match "develop-hybrid.+" OR
pull_request.target match "release.+") AND
NOT pull_request.draft AND
(pull_request.changes.any match "com.unity.ml-agents/**" OR

2
.yamato/standalone-build-test.yml


- python -u -m ml-agents.tests.yamato.standalone_build_tests --scene=Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureCompressed.unity
triggers:
cancel_old_ci: true
# TODO remove develop-hybrid trigger before merging to master
pull_request.target match "develop-hybrid.+" OR
pull_request.target match "release.+") AND
NOT pull_request.draft AND
(pull_request.changes.any match "com.unity.ml-agents/**" OR

2
.yamato/training-int-tests.yml


- .yamato/standalone-build-test.yml#test_mac_standalone_{{ editor.version }}
triggers:
cancel_old_ci: true
# TODO remove develop-hybrid trigger before merging to master
pull_request.target match "develop-hybrid.+" OR
pull_request.target match "release.+") AND
NOT pull_request.draft AND
(pull_request.changes.any match "com.unity.ml-agents/**" OR

6
com.unity.ml-agents/Runtime/Academy.cs


/// <term>1.2.0</term>
/// <description>Support compression mapping for stacked compressed observations.</description>
/// </item>
/// <item>
/// <term>1.3.0</term>
/// <description>Support hybrid action spaces.</description>
/// </item>
const string k_ApiVersion = "1.2.0";
const string k_ApiVersion = "1.3.0";
/// <summary>
/// Unity package version of com.unity.ml-agents.

8
com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs


System.Array.Clear(Array, Offset, Length);
}
/// <summary>
/// Check if the segment is empty.
/// </summary>
/// <returns>
/// <c>true</c> when the segment has no backing array or the backing array has no elements;
/// otherwise <c>false</c>.
/// </returns>
public bool IsEmpty()
{
    // Guard against a segment whose backing array was never assigned (e.g. a
    // default-initialized segment); previously this dereferenced null and threw.
    return Array == null || Array.Length == 0;
}
/// <inheritdoc/>
IEnumerator<T> IEnumerable<T>.GetEnumerator()
{

3
com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs


{
if (NumContinuousActions > 0 && NumDiscreteActions > 0)
{
throw new UnityAgentsException("ActionSpecs must be all continuous or all discrete.");
throw new UnityAgentsException("Hybrid action spaces not supported by the trainer. " +
"ActionSpecs must be all continuous or all discrete.");
}
}
}

4
com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs


Debug.Assert(
!m_Actuators[i].Name.Equals(m_Actuators[i + 1].Name),
"Actuator names must be unique.");
var first = m_Actuators[i].ActionSpec;
var second = m_Actuators[i + 1].ActionSpec;
Debug.Assert(first.NumContinuousActions > 0 == second.NumContinuousActions > 0,
"Actuators on the same Agent must have the same action SpaceType.");
}
}

57
com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs


}
/// <summary>
/// Construct an <see cref="ActionBuffers"/> instance with <see cref="ActionSpec"/>. All values are initialized to zeros.
/// </summary>
/// <param name="actionSpec">The <see cref="ActionSpec"/> to send to an <see cref="IActionReceiver"/>.</param>
public ActionBuffers(ActionSpec actionSpec)
: this(new ActionSegment<float>(new float[actionSpec.NumContinuousActions]),
new ActionSegment<int>(new int[actionSpec.NumDiscreteActions]))
{ }
/// <summary>
/// Build an <see cref="ActionBuffers"/> from a flat float array laid out as all continuous
/// actions followed by all discrete actions.
/// </summary>
/// <param name="actionSpec">The <see cref="ActionSpec"/> giving the number of continuous and discrete actions.</param>
/// <param name="actions">
/// The flat action array. When <c>null</c>, an <see cref="ActionBuffers"/> made of two empty segments is returned.
/// </param>
/// <returns>
/// An <see cref="ActionBuffers"/> whose continuous segment is a view over the leading part of
/// <paramref name="actions"/> and whose discrete segment is a new int array holding the
/// trailing values cast to <c>int</c>.
/// </returns>
internal static ActionBuffers FromActionSpec(ActionSpec actionSpec, float[] actions)
{
    if (actions == null)
    {
        return new ActionBuffers(ActionSegment<float>.Empty, ActionSegment<int>.Empty);
    }

    var continuousCount = actionSpec.NumContinuousActions;
    var discreteCount = actionSpec.NumDiscreteActions;
    var expectedLength = continuousCount + discreteCount;

    Debug.Assert(actions.Length == expectedLength,
        $"The length of '{nameof(actions)}' does not match the total size of ActionSpec.\n" +
        $"{nameof(actions)}.Length: {actions.Length}\n" +
        $"{nameof(actionSpec)}: {expectedLength}");

    // Continuous actions are exposed as a view over the caller's array (no copy).
    var hasContinuous = continuousCount > 0;
    var continuousSegment = hasContinuous
        ? new ActionSegment<float>(actions, 0, continuousCount)
        : ActionSegment<float>.Empty;

    // Discrete values start right after the continuous ones (or at 0 when there are none).
    var discreteStart = hasContinuous ? continuousCount : 0;

    var discreteSegment = ActionSegment<int>.Empty;
    if (discreteCount > 0)
    {
        // Discrete actions need their own int storage, so they are copied and truncated via cast.
        var discreteValues = new int[discreteCount];
        for (var index = 0; index < discreteCount; index++)
        {
            discreteValues[index] = (int)actions[discreteStart + index];
        }
        discreteSegment = new ActionSegment<int>(discreteValues);
    }

    return new ActionBuffers(continuousSegment, discreteSegment);
}
/// <summary>
/// Clear the <see cref="ContinuousActions"/> and <see cref="DiscreteActions"/> segments to be all zeros.
/// </summary>
public void Clear()

}
/// <summary>
/// Check if the <see cref="ActionBuffers"/> is empty.
/// </summary>
/// <returns>
/// <c>true</c> only when both <see cref="ContinuousActions"/> and <see cref="DiscreteActions"/> report empty.
/// </returns>
public bool IsEmpty() => ContinuousActions.IsEmpty() && DiscreteActions.IsEmpty();
/// <inheritdoc/>

44
com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs


{
var brainParametersProto = new BrainParametersProto
{
VectorActionSize = { bp.VectorActionSize },
VectorActionSpaceType = (SpaceTypeProto)bp.VectorActionSpaceType,
VectorActionSizeDeprecated = { bp.VectorActionSize },
VectorActionSpaceTypeDeprecated = (SpaceTypeProto)bp.VectorActionSpaceType,
brainParametersProto.VectorActionDescriptions.AddRange(bp.VectorActionDescriptions);
brainParametersProto.VectorActionDescriptionsDeprecated.AddRange(bp.VectorActionDescriptions);
}
return brainParametersProto;
}

/// <param name="isTraining">Whether or not the Brain is training.</param>
public static BrainParametersProto ToBrainParametersProto(this ActionSpec actionSpec, string name, bool isTraining)
{
actionSpec.CheckNotHybrid();
if (actionSpec.NumContinuousActions > 0)
var actionSpecProto = new ActionSpecProto
{
NumContinuousActions = actionSpec.NumContinuousActions,
NumDiscreteActions = actionSpec.NumDiscreteActions,
};
if (actionSpec.BranchSizes != null)
brainParametersProto.VectorActionSize.Add(actionSpec.NumContinuousActions);
brainParametersProto.VectorActionSpaceType = SpaceTypeProto.Continuous;
actionSpecProto.DiscreteBranchSizes.AddRange(actionSpec.BranchSizes);
else if (actionSpec.NumDiscreteActions > 0)
brainParametersProto.ActionSpec = actionSpecProto;
var supportHybrid = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.HybridActions;
if (!supportHybrid)
brainParametersProto.VectorActionSize.AddRange(actionSpec.BranchSizes);
brainParametersProto.VectorActionSpaceType = SpaceTypeProto.Discrete;
actionSpec.CheckNotHybrid();
if (actionSpec.NumContinuousActions > 0)
{
brainParametersProto.VectorActionSizeDeprecated.Add(actionSpec.NumContinuousActions);
brainParametersProto.VectorActionSpaceTypeDeprecated = SpaceTypeProto.Continuous;
}
else if (actionSpec.NumDiscreteActions > 0)
{
brainParametersProto.VectorActionSizeDeprecated.AddRange(actionSpec.BranchSizes);
brainParametersProto.VectorActionSpaceTypeDeprecated = SpaceTypeProto.Discrete;
}
}
// TODO handle ActionDescriptions?

{
var bp = new BrainParameters
{
VectorActionSize = bpp.VectorActionSize.ToArray(),
VectorActionDescriptions = bpp.VectorActionDescriptions.ToArray(),
VectorActionSpaceType = (SpaceType)bpp.VectorActionSpaceType
VectorActionSize = bpp.VectorActionSizeDeprecated.ToArray(),
VectorActionDescriptions = bpp.VectorActionDescriptionsDeprecated.ToArray(),
VectorActionSpaceType = (SpaceType)bpp.VectorActionSpaceTypeDeprecated
};
return bp;
}

BaseRLCapabilities = proto.BaseRLCapabilities,
ConcatenatedPngObservations = proto.ConcatenatedPngObservations,
CompressedChannelMapping = proto.CompressedChannelMapping,
HybridActions = proto.HybridActions,
};
}

BaseRLCapabilities = rlCaps.BaseRLCapabilities,
ConcatenatedPngObservations = rlCaps.ConcatenatedPngObservations,
CompressedChannelMapping = rlCaps.CompressedChannelMapping,
HybridActions = rlCaps.HybridActions,
};
}

5
com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs


public bool BaseRLCapabilities;
public bool ConcatenatedPngObservations;
public bool CompressedChannelMapping;
public bool HybridActions;
public UnityRLCapabilities(bool baseRlCapabilities = true, bool concatenatedPngObservations = true, bool compressedChannelMapping = true)
public UnityRLCapabilities(bool baseRlCapabilities = true, bool concatenatedPngObservations = true,
bool compressedChannelMapping = true, bool hybridActions = true)
HybridActions = hybridActions;
}
/// <summary>

348
com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/BrainParameters.cs


"CjltbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2JyYWluX3Bh",
"cmFtZXRlcnMucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGjNtbGFnZW50",
"c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL3NwYWNlX3R5cGUucHJvdG8i",
"2QEKFEJyYWluUGFyYW1ldGVyc1Byb3RvEhoKEnZlY3Rvcl9hY3Rpb25fc2l6",
"ZRgDIAMoBRIiChp2ZWN0b3JfYWN0aW9uX2Rlc2NyaXB0aW9ucxgFIAMoCRJG",
"Chh2ZWN0b3JfYWN0aW9uX3NwYWNlX3R5cGUYBiABKA4yJC5jb21tdW5pY2F0",
"b3Jfb2JqZWN0cy5TcGFjZVR5cGVQcm90bxISCgpicmFpbl9uYW1lGAcgASgJ",
"EhMKC2lzX3RyYWluaW5nGAggASgISgQIARACSgQIAhADSgQIBBAFQiWqAiJV",
"bml0eS5NTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
"iwEKD0FjdGlvblNwZWNQcm90bxIeChZudW1fY29udGludW91c19hY3Rpb25z",
"GAEgASgFEhwKFG51bV9kaXNjcmV0ZV9hY3Rpb25zGAIgASgFEh0KFWRpc2Ny",
"ZXRlX2JyYW5jaF9zaXplcxgDIAMoBRIbChNhY3Rpb25fZGVzY3JpcHRpb25z",
"GAQgAygJIrYCChRCcmFpblBhcmFtZXRlcnNQcm90bxIlCh12ZWN0b3JfYWN0",
"aW9uX3NpemVfZGVwcmVjYXRlZBgDIAMoBRItCiV2ZWN0b3JfYWN0aW9uX2Rl",
"c2NyaXB0aW9uc19kZXByZWNhdGVkGAUgAygJElEKI3ZlY3Rvcl9hY3Rpb25f",
"c3BhY2VfdHlwZV9kZXByZWNhdGVkGAYgASgOMiQuY29tbXVuaWNhdG9yX29i",
"amVjdHMuU3BhY2VUeXBlUHJvdG8SEgoKYnJhaW5fbmFtZRgHIAEoCRITCgtp",
"c190cmFpbmluZxgIIAEoCBI6CgthY3Rpb25fc3BlYxgJIAEoCzIlLmNvbW11",
"bmljYXRvcl9vYmplY3RzLkFjdGlvblNwZWNQcm90b0oECAEQAkoECAIQA0oE",
"CAQQBUIlqgIiVW5pdHkuTUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0c2IG",
"cHJvdG8z"));
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.BrainParametersProto), global::Unity.MLAgents.CommunicatorObjects.BrainParametersProto.Parser, new[]{ "VectorActionSize", "VectorActionDescriptions", "VectorActionSpaceType", "BrainName", "IsTraining" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ActionSpecProto), global::Unity.MLAgents.CommunicatorObjects.ActionSpecProto.Parser, new[]{ "NumContinuousActions", "NumDiscreteActions", "DiscreteBranchSizes", "ActionDescriptions" }, null, null, null),
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.BrainParametersProto), global::Unity.MLAgents.CommunicatorObjects.BrainParametersProto.Parser, new[]{ "VectorActionSizeDeprecated", "VectorActionDescriptionsDeprecated", "VectorActionSpaceTypeDeprecated", "BrainName", "IsTraining", "ActionSpec" }, null, null, null)
}));
}
#endregion

/// <summary>
/// Protobuf message carrying an action spec: the number of continuous actions, the number of
/// discrete actions, the discrete branch sizes, and the action descriptions.
/// NOTE(review): this looks like protoc-generated code; prefer changing the .proto definition
/// and regenerating over hand edits - confirm.
/// </summary>
internal sealed partial class ActionSpecProto : pb::IMessage<ActionSpecProto> {
private static readonly pb::MessageParser<ActionSpecProto> _parser = new pb::MessageParser<ActionSpecProto>(() => new ActionSpecProto());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pb::MessageParser<ActionSpecProto> Parser { get { return _parser; } }
// Descriptor lookup: this message is read from index 0 of the file descriptor's message types.
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pbr::MessageDescriptor Descriptor {
get { return global::Unity.MLAgents.CommunicatorObjects.BrainParametersReflection.Descriptor.MessageTypes[0]; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public ActionSpecProto() {
OnConstruction();
}
// Partial hook invoked by the parameterless constructor; no implementation is visible here.
partial void OnConstruction();
// Copy constructor: scalar fields are copied, repeated fields and unknown fields are cloned.
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public ActionSpecProto(ActionSpecProto other) : this() {
numContinuousActions_ = other.numContinuousActions_;
numDiscreteActions_ = other.numDiscreteActions_;
discreteBranchSizes_ = other.discreteBranchSizes_.Clone();
actionDescriptions_ = other.actionDescriptions_.Clone();
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public ActionSpecProto Clone() {
return new ActionSpecProto(this);
}
/// <summary>Field number for the "num_continuous_actions" field.</summary>
public const int NumContinuousActionsFieldNumber = 1;
private int numContinuousActions_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int NumContinuousActions {
get { return numContinuousActions_; }
set {
numContinuousActions_ = value;
}
}
/// <summary>Field number for the "num_discrete_actions" field.</summary>
public const int NumDiscreteActionsFieldNumber = 2;
private int numDiscreteActions_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int NumDiscreteActions {
get { return numDiscreteActions_; }
set {
numDiscreteActions_ = value;
}
}
/// <summary>Field number for the "discrete_branch_sizes" field.</summary>
public const int DiscreteBranchSizesFieldNumber = 3;
// 26 == (3 << 3) | 2: field 3 with the length-delimited wire type (packed encoding).
private static readonly pb::FieldCodec<int> _repeated_discreteBranchSizes_codec
= pb::FieldCodec.ForInt32(26);
private readonly pbc::RepeatedField<int> discreteBranchSizes_ = new pbc::RepeatedField<int>();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<int> DiscreteBranchSizes {
get { return discreteBranchSizes_; }
}
/// <summary>Field number for the "action_descriptions" field.</summary>
public const int ActionDescriptionsFieldNumber = 4;
// 34 == (4 << 3) | 2: field 4 with the length-delimited wire type.
private static readonly pb::FieldCodec<string> _repeated_actionDescriptions_codec
= pb::FieldCodec.ForString(34);
private readonly pbc::RepeatedField<string> actionDescriptions_ = new pbc::RepeatedField<string>();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<string> ActionDescriptions {
get { return actionDescriptions_; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as ActionSpecProto);
}
// Field-by-field equality, including the unknown-field set.
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool Equals(ActionSpecProto other) {
if (ReferenceEquals(other, null)) {
return false;
}
if (ReferenceEquals(other, this)) {
return true;
}
if (NumContinuousActions != other.NumContinuousActions) return false;
if (NumDiscreteActions != other.NumDiscreteActions) return false;
if(!discreteBranchSizes_.Equals(other.discreteBranchSizes_)) return false;
if(!actionDescriptions_.Equals(other.actionDescriptions_)) return false;
return Equals(_unknownFields, other._unknownFields);
}
// XOR-combines the hashes of the fields compared by Equals; zero-valued scalars are skipped.
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override int GetHashCode() {
int hash = 1;
if (NumContinuousActions != 0) hash ^= NumContinuousActions.GetHashCode();
if (NumDiscreteActions != 0) hash ^= NumDiscreteActions.GetHashCode();
hash ^= discreteBranchSizes_.GetHashCode();
hash ^= actionDescriptions_.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
return hash;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}
// Serializes the message; scalar fields equal to 0 are not written.
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
if (NumContinuousActions != 0) {
// 8 == (1 << 3) | 0: field 1, varint wire type.
output.WriteRawTag(8);
output.WriteInt32(NumContinuousActions);
}
if (NumDiscreteActions != 0) {
// 16 == (2 << 3) | 0: field 2, varint wire type.
output.WriteRawTag(16);
output.WriteInt32(NumDiscreteActions);
}
discreteBranchSizes_.WriteTo(output, _repeated_discreteBranchSizes_codec);
actionDescriptions_.WriteTo(output, _repeated_actionDescriptions_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
}
// Mirrors WriteTo: each written scalar costs 1 tag byte plus its encoded value size.
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int CalculateSize() {
int size = 0;
if (NumContinuousActions != 0) {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumContinuousActions);
}
if (NumDiscreteActions != 0) {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumDiscreteActions);
}
size += discreteBranchSizes_.CalculateSize(_repeated_discreteBranchSizes_codec);
size += actionDescriptions_.CalculateSize(_repeated_actionDescriptions_codec);
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
return size;
}
// Merges another message into this one: non-zero scalars overwrite, repeated fields are appended.
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(ActionSpecProto other) {
if (other == null) {
return;
}
if (other.NumContinuousActions != 0) {
NumContinuousActions = other.NumContinuousActions;
}
if (other.NumDiscreteActions != 0) {
NumDiscreteActions = other.NumDiscreteActions;
}
discreteBranchSizes_.Add(other.discreteBranchSizes_);
actionDescriptions_.Add(other.actionDescriptions_);
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}
// Parses fields from the input stream until ReadTag returns 0; unrecognized tags are kept
// in the unknown-field set.
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(pb::CodedInputStream input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
break;
case 8: {
NumContinuousActions = input.ReadInt32();
break;
}
case 16: {
NumDiscreteActions = input.ReadInt32();
break;
}
// Field 3 is accepted in both packed (26, length-delimited) and unpacked (24, varint) form.
case 26:
case 24: {
discreteBranchSizes_.AddEntriesFrom(input, _repeated_discreteBranchSizes_codec);
break;
}
case 34: {
actionDescriptions_.AddEntriesFrom(input, _repeated_actionDescriptions_codec);
break;
}
}
}
}
}
internal sealed partial class BrainParametersProto : pb::IMessage<BrainParametersProto> {
private static readonly pb::MessageParser<BrainParametersProto> _parser = new pb::MessageParser<BrainParametersProto>(() => new BrainParametersProto());
private pb::UnknownFieldSet _unknownFields;

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pbr::MessageDescriptor Descriptor {
get { return global::Unity.MLAgents.CommunicatorObjects.BrainParametersReflection.Descriptor.MessageTypes[0]; }
get { return global::Unity.MLAgents.CommunicatorObjects.BrainParametersReflection.Descriptor.MessageTypes[1]; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public BrainParametersProto(BrainParametersProto other) : this() {
vectorActionSize_ = other.vectorActionSize_.Clone();
vectorActionDescriptions_ = other.vectorActionDescriptions_.Clone();
vectorActionSpaceType_ = other.vectorActionSpaceType_;
vectorActionSizeDeprecated_ = other.vectorActionSizeDeprecated_.Clone();
vectorActionDescriptionsDeprecated_ = other.vectorActionDescriptionsDeprecated_.Clone();
vectorActionSpaceTypeDeprecated_ = other.vectorActionSpaceTypeDeprecated_;
ActionSpec = other.actionSpec_ != null ? other.ActionSpec.Clone() : null;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

}
/// <summary>Field number for the "vector_action_size" field.</summary>
public const int VectorActionSizeFieldNumber = 3;
private static readonly pb::FieldCodec<int> _repeated_vectorActionSize_codec
/// <summary>Field number for the "vector_action_size_deprecated" field.</summary>
public const int VectorActionSizeDeprecatedFieldNumber = 3;
private static readonly pb::FieldCodec<int> _repeated_vectorActionSizeDeprecated_codec
private readonly pbc::RepeatedField<int> vectorActionSize_ = new pbc::RepeatedField<int>();
private readonly pbc::RepeatedField<int> vectorActionSizeDeprecated_ = new pbc::RepeatedField<int>();
/// <summary>
/// mark as deprecated in communicator v0.22.0
/// </summary>
public pbc::RepeatedField<int> VectorActionSize {
get { return vectorActionSize_; }
public pbc::RepeatedField<int> VectorActionSizeDeprecated {
get { return vectorActionSizeDeprecated_; }
/// <summary>Field number for the "vector_action_descriptions" field.</summary>
public const int VectorActionDescriptionsFieldNumber = 5;
private static readonly pb::FieldCodec<string> _repeated_vectorActionDescriptions_codec
/// <summary>Field number for the "vector_action_descriptions_deprecated" field.</summary>
public const int VectorActionDescriptionsDeprecatedFieldNumber = 5;
private static readonly pb::FieldCodec<string> _repeated_vectorActionDescriptionsDeprecated_codec
private readonly pbc::RepeatedField<string> vectorActionDescriptions_ = new pbc::RepeatedField<string>();
private readonly pbc::RepeatedField<string> vectorActionDescriptionsDeprecated_ = new pbc::RepeatedField<string>();
/// <summary>
/// mark as deprecated in communicator v0.22.0
/// </summary>
public pbc::RepeatedField<string> VectorActionDescriptions {
get { return vectorActionDescriptions_; }
public pbc::RepeatedField<string> VectorActionDescriptionsDeprecated {
get { return vectorActionDescriptionsDeprecated_; }
/// <summary>Field number for the "vector_action_space_type" field.</summary>
public const int VectorActionSpaceTypeFieldNumber = 6;
private global::Unity.MLAgents.CommunicatorObjects.SpaceTypeProto vectorActionSpaceType_ = 0;
/// <summary>Field number for the "vector_action_space_type_deprecated" field.</summary>
public const int VectorActionSpaceTypeDeprecatedFieldNumber = 6;
private global::Unity.MLAgents.CommunicatorObjects.SpaceTypeProto vectorActionSpaceTypeDeprecated_ = 0;
/// <summary>
/// mark as deprecated in communicator v0.22.0
/// </summary>
public global::Unity.MLAgents.CommunicatorObjects.SpaceTypeProto VectorActionSpaceType {
get { return vectorActionSpaceType_; }
public global::Unity.MLAgents.CommunicatorObjects.SpaceTypeProto VectorActionSpaceTypeDeprecated {
get { return vectorActionSpaceTypeDeprecated_; }
vectorActionSpaceType_ = value;
vectorActionSpaceTypeDeprecated_ = value;
}
}

}
}
/// <summary>Field number for the "action_spec" field.</summary>
public const int ActionSpecFieldNumber = 9;
private global::Unity.MLAgents.CommunicatorObjects.ActionSpecProto actionSpec_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public global::Unity.MLAgents.CommunicatorObjects.ActionSpecProto ActionSpec {
get { return actionSpec_; }
set {
actionSpec_ = value;
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as BrainParametersProto);

if (ReferenceEquals(other, this)) {
return true;
}
if(!vectorActionSize_.Equals(other.vectorActionSize_)) return false;
if(!vectorActionDescriptions_.Equals(other.vectorActionDescriptions_)) return false;
if (VectorActionSpaceType != other.VectorActionSpaceType) return false;
if(!vectorActionSizeDeprecated_.Equals(other.vectorActionSizeDeprecated_)) return false;
if(!vectorActionDescriptionsDeprecated_.Equals(other.vectorActionDescriptionsDeprecated_)) return false;
if (VectorActionSpaceTypeDeprecated != other.VectorActionSpaceTypeDeprecated) return false;
if (!object.Equals(ActionSpec, other.ActionSpec)) return false;
return Equals(_unknownFields, other._unknownFields);
}

hash ^= vectorActionSize_.GetHashCode();
hash ^= vectorActionDescriptions_.GetHashCode();
if (VectorActionSpaceType != 0) hash ^= VectorActionSpaceType.GetHashCode();
hash ^= vectorActionSizeDeprecated_.GetHashCode();
hash ^= vectorActionDescriptionsDeprecated_.GetHashCode();
if (VectorActionSpaceTypeDeprecated != 0) hash ^= VectorActionSpaceTypeDeprecated.GetHashCode();
if (actionSpec_ != null) hash ^= ActionSpec.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
vectorActionSize_.WriteTo(output, _repeated_vectorActionSize_codec);
vectorActionDescriptions_.WriteTo(output, _repeated_vectorActionDescriptions_codec);
if (VectorActionSpaceType != 0) {
vectorActionSizeDeprecated_.WriteTo(output, _repeated_vectorActionSizeDeprecated_codec);
vectorActionDescriptionsDeprecated_.WriteTo(output, _repeated_vectorActionDescriptionsDeprecated_codec);
if (VectorActionSpaceTypeDeprecated != 0) {
output.WriteEnum((int) VectorActionSpaceType);
output.WriteEnum((int) VectorActionSpaceTypeDeprecated);
}
if (BrainName.Length != 0) {
output.WriteRawTag(58);

output.WriteRawTag(64);
output.WriteBool(IsTraining);
}
if (actionSpec_ != null) {
output.WriteRawTag(74);
output.WriteMessage(ActionSpec);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}

public int CalculateSize() {
int size = 0;
size += vectorActionSize_.CalculateSize(_repeated_vectorActionSize_codec);
size += vectorActionDescriptions_.CalculateSize(_repeated_vectorActionDescriptions_codec);
if (VectorActionSpaceType != 0) {
size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) VectorActionSpaceType);
size += vectorActionSizeDeprecated_.CalculateSize(_repeated_vectorActionSizeDeprecated_codec);
size += vectorActionDescriptionsDeprecated_.CalculateSize(_repeated_vectorActionDescriptionsDeprecated_codec);
if (VectorActionSpaceTypeDeprecated != 0) {
size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) VectorActionSpaceTypeDeprecated);
}
if (BrainName.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(BrainName);

}
if (actionSpec_ != null) {
size += 1 + pb::CodedOutputStream.ComputeMessageSize(ActionSpec);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();

if (other == null) {
return;
}
vectorActionSize_.Add(other.vectorActionSize_);
vectorActionDescriptions_.Add(other.vectorActionDescriptions_);
if (other.VectorActionSpaceType != 0) {
VectorActionSpaceType = other.VectorActionSpaceType;
vectorActionSizeDeprecated_.Add(other.vectorActionSizeDeprecated_);
vectorActionDescriptionsDeprecated_.Add(other.vectorActionDescriptionsDeprecated_);
if (other.VectorActionSpaceTypeDeprecated != 0) {
VectorActionSpaceTypeDeprecated = other.VectorActionSpaceTypeDeprecated;
}
if (other.BrainName.Length != 0) {
BrainName = other.BrainName;

}
if (other.actionSpec_ != null) {
if (actionSpec_ == null) {
actionSpec_ = new global::Unity.MLAgents.CommunicatorObjects.ActionSpecProto();
}
ActionSpec.MergeFrom(other.ActionSpec);
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

break;
case 26:
case 24: {
vectorActionSize_.AddEntriesFrom(input, _repeated_vectorActionSize_codec);
vectorActionSizeDeprecated_.AddEntriesFrom(input, _repeated_vectorActionSizeDeprecated_codec);
vectorActionDescriptions_.AddEntriesFrom(input, _repeated_vectorActionDescriptions_codec);
vectorActionDescriptionsDeprecated_.AddEntriesFrom(input, _repeated_vectorActionDescriptionsDeprecated_codec);
vectorActionSpaceType_ = (global::Unity.MLAgents.CommunicatorObjects.SpaceTypeProto) input.ReadEnum();
vectorActionSpaceTypeDeprecated_ = (global::Unity.MLAgents.CommunicatorObjects.SpaceTypeProto) input.ReadEnum();
break;
}
case 58: {

case 64: {
IsTraining = input.ReadBool();
break;
}
case 74: {
if (actionSpec_ == null) {
actionSpec_ = new global::Unity.MLAgents.CommunicatorObjects.ActionSpecProto();
}
input.ReadMessage(actionSpec_);
break;
}
}

44
com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs


byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjVtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2NhcGFiaWxp",
"dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMifQoYVW5pdHlSTENh",
"cGFiaWxpdGllc1Byb3RvEhoKEmJhc2VSTENhcGFiaWxpdGllcxgBIAEoCBIj",
"Chtjb25jYXRlbmF0ZWRQbmdPYnNlcnZhdGlvbnMYAiABKAgSIAoYY29tcHJl",
"c3NlZENoYW5uZWxNYXBwaW5nGAMgASgIQiWqAiJVbml0eS5NTEFnZW50cy5D",
"b21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
"dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMilAEKGFVuaXR5UkxD",
"YXBhYmlsaXRpZXNQcm90bxIaChJiYXNlUkxDYXBhYmlsaXRpZXMYASABKAgS",
"IwobY29uY2F0ZW5hdGVkUG5nT2JzZXJ2YXRpb25zGAIgASgIEiAKGGNvbXBy",
"ZXNzZWRDaGFubmVsTWFwcGluZxgDIAEoCBIVCg1oeWJyaWRBY3Rpb25zGAQg",
"ASgIQiWqAiJVbml0eS5NTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZw",
"cm90bzM="));
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations", "CompressedChannelMapping" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations", "CompressedChannelMapping", "HybridActions" }, null, null, null)
}));
}
#endregion

baseRLCapabilities_ = other.baseRLCapabilities_;
concatenatedPngObservations_ = other.concatenatedPngObservations_;
compressedChannelMapping_ = other.compressedChannelMapping_;
hybridActions_ = other.hybridActions_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

}
}
/// <summary>Field number for the "hybridActions" field.</summary>
public const int HybridActionsFieldNumber = 4;
// Backing store for HybridActions. The serialization code in this class treats
// false as "not set": the field is skipped when writing, sizing and hashing.
private bool hybridActions_;
/// <summary>
/// Support for hybrid action spaces (discrete + continuous).
/// Written to the output stream as a bool under raw tag 32, and only when true.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool HybridActions {
get { return hybridActions_; }
set {
hybridActions_ = value;
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as UnityRLCapabilitiesProto);

if (BaseRLCapabilities != other.BaseRLCapabilities) return false;
if (ConcatenatedPngObservations != other.ConcatenatedPngObservations) return false;
if (CompressedChannelMapping != other.CompressedChannelMapping) return false;
if (HybridActions != other.HybridActions) return false;
return Equals(_unknownFields, other._unknownFields);
}

if (BaseRLCapabilities != false) hash ^= BaseRLCapabilities.GetHashCode();
if (ConcatenatedPngObservations != false) hash ^= ConcatenatedPngObservations.GetHashCode();
if (CompressedChannelMapping != false) hash ^= CompressedChannelMapping.GetHashCode();
if (HybridActions != false) hash ^= HybridActions.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}

output.WriteRawTag(24);
output.WriteBool(CompressedChannelMapping);
}
if (HybridActions != false) {
output.WriteRawTag(32);
output.WriteBool(HybridActions);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}

size += 1 + 1;
}
if (CompressedChannelMapping != false) {
size += 1 + 1;
}
if (HybridActions != false) {
size += 1 + 1;
}
if (_unknownFields != null) {

if (other.CompressedChannelMapping != false) {
CompressedChannelMapping = other.CompressedChannelMapping;
}
if (other.HybridActions != false) {
HybridActions = other.HybridActions;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

}
case 24: {
CompressedChannelMapping = input.ReadBool();
break;
}
case 32: {
HybridActions = input.ReadBool();
break;
}
}

44
com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs


using System.Collections.Generic;
using System.Linq;
using Unity.MLAgents.Inference.Utils;
using Unity.MLAgents.Actuators;
using Unity.Barracuda;
using UnityEngine;

/// </summary>
internal class ContinuousActionOutputApplier : TensorApplier.IApplier
{
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, float[]> lastActions)
readonly ActionSpec m_ActionSpec;
public ContinuousActionOutputApplier(ActionSpec actionSpec)
{
m_ActionSpec = actionSpec;
}
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
{
var actionSize = tensorProxy.shape[tensorProxy.shape.Length - 1];
var agentIndex = 0;

{
var actionValue = lastActions[agentId];
if (actionValue == null)
var actionBuffer = lastActions[agentId];
if (actionBuffer.IsEmpty())
actionValue = new float[actionSize];
lastActions[agentId] = actionValue;
actionBuffer = new ActionBuffers(m_ActionSpec);
lastActions[agentId] = actionBuffer;
var continuousBuffer = actionBuffer.ContinuousActions;
actionValue[j] = tensorProxy.data[agentIndex, j];
continuousBuffer[j] = tensorProxy.data[agentIndex, j];
}
}
agentIndex++;

readonly int[] m_ActionSize;
readonly Multinomial m_Multinomial;
readonly ITensorAllocator m_Allocator;
readonly ActionSpec m_ActionSpec;
public DiscreteActionOutputApplier(int[] actionSize, int seed, ITensorAllocator allocator)
public DiscreteActionOutputApplier(ActionSpec actionSpec, int seed, ITensorAllocator allocator)
m_ActionSize = actionSize;
m_ActionSize = actionSpec.BranchSizes;
m_ActionSpec = actionSpec;
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, float[]> lastActions)
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
{
//var tensorDataProbabilities = tensorProxy.Data as float[,];
var idActionPairList = actionIds as List<int> ?? actionIds.ToList();

{
if (lastActions.ContainsKey(agentId))
{
var actionVal = lastActions[agentId];
if (actionVal == null)
var actionBuffer = lastActions[agentId];
if (actionBuffer.IsEmpty())
actionVal = new float[m_ActionSize.Length];
lastActions[agentId] = actionVal;
actionBuffer = new ActionBuffers(m_ActionSpec);
lastActions[agentId] = actionBuffer;
var discreteBuffer = actionBuffer.DiscreteActions;
actionVal[j] = actionValues[agentIndex, j];
discreteBuffer[j] = (int)actionValues[agentIndex, j];
}
}
agentIndex++;

m_Memories = memories;
}
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, float[]> lastActions)
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
{
var agentIndex = 0;
var memorySize = (int)tensorProxy.shape[tensorProxy.shape.Length - 1];

m_Memories = memories;
}
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, float[]> lastActions)
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
{
var agentIndex = 0;
var memorySize = (int)tensorProxy.shape[tensorProxy.shape.Length - 1];

237
com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs


/// </summary>
internal class BarracudaModelParamLoader
{
/// <summary>
/// Type of control a model was trained with, as decoded by GetActionType from the
/// integer stored in the model's TensorNames.IsContinuousControl constant.
/// </summary>
enum ModelActionType
{
/// <summary>The model value was neither 0 nor 1.</summary>
Unknown,
/// <summary>The model value was 0.</summary>
Discrete,
/// <summary>The model value was 1.</summary>
Continuous
}
/// <summary>
/// Builds the list of input tensors the model expects: one proxy per model input,
/// followed by one proxy per recurrent memory input, then sorted by tensor name.
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters. May be null.
/// </param>
/// <returns>
/// The expected input TensorProxy instances ordered by name, or an empty list when
/// <paramref name="model"/> is null. Every proxy has a floating point value type
/// and no data attached.
/// </returns>
public static IReadOnlyList<TensorProxy> GetInputTensors(Model model)
{
    var inputProxies = new List<TensorProxy>();
    if (model != null)
    {
        // Regular model inputs: widen each shape dimension to long.
        inputProxies.AddRange(model.inputs.Select(modelInput => new TensorProxy
        {
            name = modelInput.name,
            valueType = TensorProxy.TensorType.FloatingPoint,
            data = null,
            shape = modelInput.shape.Select(dimension => (long)dimension).ToArray()
        }));

        // Recurrent memories are fed back in under their input name.
        inputProxies.AddRange(model.memories.Select(memory => new TensorProxy
        {
            name = memory.input,
            valueType = TensorProxy.TensorType.FloatingPoint,
            data = null,
            shape = TensorUtils.TensorShapeFromBarracuda(memory.shape)
        }));

        inputProxies.Sort((left, right) => left.name.CompareTo(right.name));
    }
    return inputProxies;
}
/// <summary>
/// Counts the model inputs that look like visual observations: inputs whose shape
/// has exactly four dimensions and whose name starts with
/// TensorNames.VisualObservationPlaceholderPrefix.
/// </summary>
/// <param name="model">The Barracuda engine model to inspect. May be null.</param>
/// <returns>The number of matching inputs, or 0 when the model is null.</returns>
public static int GetNumVisualInputs(Model model)
{
    if (model == null)
    {
        return 0;
    }

    return model.inputs.Count(candidate =>
        candidate.shape.Length == 4 &&
        candidate.name.StartsWith(TensorNames.VisualObservationPlaceholderPrefix));
}
/// <summary>
/// Collects the names of the output tensors that are expected to be present in the
/// Model: the action output, plus every memory output when the model's memory size
/// constant is greater than zero.
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters. May be null.
/// </param>
/// <returns>
/// The expected output names in sorted order, or an empty array when the model is null.
/// </returns>
public static string[] GetOutputNames(Model model)
{
    if (model == null)
    {
        return new string[0];
    }

    var outputNames = new List<string> { TensorNames.ActionOutput };

    // Memory outputs are only expected from recurrent models.
    var memorySize = (int)model.GetTensorByName(TensorNames.MemorySize)[0];
    if (memorySize > 0)
    {
        outputNames.AddRange(model.memories.Select(memory => memory.output));
    }

    outputNames.Sort();
    return outputNames.ToArray();
}
/// <summary>
/// Factory for the ModelParamLoader : Creates a ModelParamLoader and runs the checks
/// on it.
/// </summary>

return failedModelChecks;
}
foreach (var constantName in TensorNames.RequiredConstants)
var hasExpectedTensors = model.CheckExpectedTensors(failedModelChecks);
if (!hasExpectedTensors)
var tensor = model.GetTensorByName(constantName);
if (tensor == null)
{
failedModelChecks.Add($"Required constant \"{constantName}\" was not found in the model file.");
return failedModelChecks;
}
return failedModelChecks;
var memorySize = (int)model.GetTensorByName(TensorNames.MemorySize)[0];
var isContinuousInt = (int)model.GetTensorByName(TensorNames.IsContinuousControl)[0];
var isContinuous = GetActionType(isContinuousInt);
var actionSize = (int)model.GetTensorByName(TensorNames.ActionOutputShape)[0];
if (modelApiVersion == -1)
{
failedModelChecks.Add(

return failedModelChecks;
}
var modelDiscreteActionSize = isContinuous == ModelActionType.Discrete ? actionSize : 0;
var modelContinuousActionSize = isContinuous == ModelActionType.Continuous ? actionSize : 0;
var memorySize = (int)model.GetTensorByName(TensorNames.MemorySize)[0];
if (memorySize == -1)
{
failedModelChecks.Add($"Missing node in the model provided : {TensorNames.MemorySize}");
return failedModelChecks;
}
CheckIntScalarPresenceHelper(new Dictionary<string, int>()
{
{TensorNames.MemorySize, memorySize},
{TensorNames.IsContinuousControl, isContinuousInt},
{TensorNames.ActionOutputShape, actionSize}
})
CheckInputTensorPresence(model, brainParameters, memorySize, sensorComponents)
CheckInputTensorPresence(model, brainParameters, memorySize, isContinuous, sensorComponents)
CheckOutputTensorPresence(model, memorySize)
failedModelChecks.AddRange(
CheckOutputTensorPresence(model, memorySize))
;
CheckOutputTensorShape(model, brainParameters, actuatorComponents, isContinuous, modelContinuousActionSize, modelDiscreteActionSize)
CheckOutputTensorShape(model, brainParameters, actuatorComponents)
/// <summary>
/// Maps the integer control-type value stored in the model to a ModelActionType.
/// </summary>
/// <param name="isContinuousInt">
/// The integer value in the model indicating the type of control.
/// </param>
/// <returns>
/// ModelActionType.Discrete for 0, ModelActionType.Continuous for 1, and
/// ModelActionType.Unknown for any other value.
/// </returns>
static ModelActionType GetActionType(int isContinuousInt)
{
    if (isContinuousInt == 0)
    {
        return ModelActionType.Discrete;
    }

    if (isContinuousInt == 1)
    {
        return ModelActionType.Continuous;
    }

    return ModelActionType.Unknown;
}
/// <summary>
/// Reports every node whose value is the sentinel -1, which marks a scalar that
/// could not be found in the model.
/// </summary>
/// <param name="requiredScalarFields">Mapping from node names to int values.</param>
/// <returns>
/// One error message per entry whose value is -1, in the dictionary's enumeration
/// order; empty when no entry has that value.
/// </returns>
static IEnumerable<string> CheckIntScalarPresenceHelper(
    Dictionary<string, int> requiredScalarFields)
{
    return requiredScalarFields
        .Where(entry => entry.Value == -1)
        .Select(entry => $"Missing node in the model provided : {entry.Key}")
        .ToList();
}
/// <summary>
/// Generates failed checks that correspond to inputs expected by the model that are not
/// present in the BrainParameters.
/// </summary>

Model model,
BrainParameters brainParameters,
int memory,
ModelActionType isContinuous,
var tensorsNames = GetInputTensors(model).Select(x => x.name).ToList();
var tensorsNames = model.GetInputNames();
// If there is no Vector Observation Input but the Brain Parameters expect one.
if ((brainParameters.VectorObservationSize != 0) &&

"The model does not contain a Vector Observation Placeholder Input. " +
"The model does not contain a Vector Observation Placeholder Input. " +
"You must set the Vector Observation Space Size to 0.");
}

visObsIndex++;
}
var expectedVisualObs = GetNumVisualInputs(model);
var expectedVisualObs = model.GetNumVisualInputs();
// Check if there's not enough visual sensors (too many would be handled above)
if (expectedVisualObs > visObsIndex)
{

}
// If the model uses discrete control but does not have an input for action masks
if (isContinuous == ModelActionType.Discrete)
if (model.HasDiscreteOutputs())
{
if (!tensorsNames.Contains(TensorNames.ActionMaskPlaceholder))
{

static IEnumerable<string> CheckOutputTensorPresence(Model model, int memory)
{
var failedModelChecks = new List<string>();
// If there is no Action Output.
if (!model.outputs.Contains(TensorNames.ActionOutput))
{
failedModelChecks.Add("The model does not contain an Action Output Node.");
}
// If there is no Recurrent Output but the model is Recurrent.
if (memory > 0)

}
// If the model expects an input but it is not in this list
foreach (var tensor in GetInputTensors(model))
foreach (var tensor in model.GetInputTensors())
{
if (!tensorTester.ContainsKey(tensor.name))
{

BrainParameters brainParameters, TensorProxy tensorProxy,
SensorComponent[] sensorComponents, int observableAttributeTotalSize)
{
// TODO: Update this check after integrating ActionSpec into BrainParameters
var numberActionsBp = brainParameters.VectorActionSize.Length;
var numberActionsT = tensorProxy.shape[tensorProxy.shape.Length - 1];
if (numberActionsBp != numberActionsT)

static IEnumerable<string> CheckOutputTensorShape(
Model model,
BrainParameters brainParameters,
ActuatorComponent[] actuatorComponents,
ModelActionType isContinuous,
int modelContinuousActionSize, int modelSumDiscreteBranchSizes)
ActuatorComponent[] actuatorComponents)
if (isContinuous == ModelActionType.Unknown)
{
failedModelChecks.Add("Cannot infer type of Control from the provided model.");
return failedModelChecks;
}
if (isContinuous == ModelActionType.Continuous &&
brainParameters.VectorActionSpaceType != SpaceType.Continuous)
{
failedModelChecks.Add(
"Model has been trained using Continuous Control but the Brain Parameters " +
"suggest Discrete Control.");
return failedModelChecks;
}
if (isContinuous == ModelActionType.Discrete &&
brainParameters.VectorActionSpaceType != SpaceType.Discrete)
{
failedModelChecks.Add(
"Model has been trained using Discrete Control but the Brain Parameters " +
"suggest Continuous Control.");
return failedModelChecks;
}
// This will need to change a bit for hybrid action spaces.
if (isContinuous == ModelActionType.Continuous)
if (model.HasContinuousOutputs())
tensorTester[TensorNames.ActionOutput] = CheckContinuousActionOutputShape;
tensorTester[model.ContinuousOutputName()] = CheckContinuousActionOutputShape;
else
if (model.HasDiscreteOutputs())
tensorTester[TensorNames.ActionOutput] = CheckDiscreteActionOutputShape;
tensorTester[model.DiscreteOutputName()] = CheckDiscreteActionOutputShape;
var modelContinuousActionSize = model.ContinuousOutputSize();
var modelSumDiscreteBranchSizes = model.DiscreteOutputSize();
foreach (var name in model.outputs)
{
if (tensorTester.ContainsKey(name))

12
com.unity.ml-agents/Runtime/Inference/ModelRunner.cs


internal class ModelRunner
{
List<AgentInfoSensorsPair> m_Infos = new List<AgentInfoSensorsPair>();
Dictionary<int, float[]> m_LastActionsReceived = new Dictionary<int, float[]>();
Dictionary<int, ActionBuffers> m_LastActionsReceived = new Dictionary<int, ActionBuffers>();
List<int> m_OrderedAgentsRequestingDecisions = new List<int>();
ITensorAllocator m_TensorAllocator;

m_Engine = null;
}
m_InferenceInputs = BarracudaModelParamLoader.GetInputTensors(barracudaModel);
m_OutputNames = BarracudaModelParamLoader.GetOutputNames(barracudaModel);
m_InferenceInputs = barracudaModel.GetInputTensors();
m_OutputNames = barracudaModel.GetOutputNames();
m_TensorGenerator = new TensorGenerator(
seed, m_TensorAllocator, m_Memories, barracudaModel);
m_TensorApplier = new TensorApplier(

if (!m_LastActionsReceived.ContainsKey(info.episodeId))
{
m_LastActionsReceived[info.episodeId] = null;
m_LastActionsReceived[info.episodeId] = ActionBuffers.Empty;
}
if (info.done)
{

return m_Model == other && m_InferenceDevice == otherInferenceDevice;
}
public float[] GetAction(int agentId)
public ActionBuffers GetAction(int agentId)
return null;
return ActionBuffers.Empty;
}
}
}

35
com.unity.ml-agents/Runtime/Inference/TensorApplier.cs


/// </param>
/// <param name="actionIds"> List of Agents Ids that will be updated using the tensor's data</param>
/// <param name="lastActions"> Dictionary of AgentId to Actions to be updated</param>
void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, float[]> lastActions);
void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, ActionBuffers> lastActions);
}
readonly Dictionary<string, IApplier> m_Dict = new Dictionary<string, IApplier>();

Dictionary<int, List<float>> memories,
object barracudaModel = null)
{
actionSpec.CheckNotHybrid();
// If model is null, no inference to run and exception is thrown before reaching here.
if (barracudaModel == null)
{
return;
}
var model = (Model)barracudaModel;
if (model.UseDeprecated())
{
actionSpec.CheckNotHybrid();
}
m_Dict[TensorNames.ActionOutput] = new ContinuousActionOutputApplier();
var tensorName = model.ContinuousOutputName();
m_Dict[tensorName] = new ContinuousActionOutputApplier(actionSpec);
else
if (actionSpec.NumDiscreteActions > 0)
m_Dict[TensorNames.ActionOutput] =
new DiscreteActionOutputApplier(actionSpec.BranchSizes, seed, allocator);
var tensorName = model.DiscreteOutputName();
m_Dict[tensorName] = new DiscreteActionOutputApplier(actionSpec, seed, allocator);
if (barracudaModel != null)
for (var i = 0; i < model?.memories.Count; i++)
var model = (Model)barracudaModel;
for (var i = 0; i < model?.memories.Count; i++)
{
m_Dict[model.memories[i].output] =
new BarracudaMemoryOutputApplier(model.memories.Count, i, memories);
}
m_Dict[model.memories[i].output] =
new BarracudaMemoryOutputApplier(model.memories.Count, i, memories);
}
}

/// <exception cref="UnityAgentsException"> One of the tensors does not have an
/// associated applier.</exception>
public void ApplyTensors(
IEnumerable<TensorProxy> tensors, IEnumerable<int> actionIds, Dictionary<int, float[]> lastActions)
IEnumerable<TensorProxy> tensors, IEnumerable<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
{
foreach (var tensor in tensors)
{

26
com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs


Dictionary<int, List<float>> memories,
object barracudaModel = null)