浏览代码

C# changes for hybrid action spaces (#4587)

* Add hybrid action capability flag (#4576)

* Change BrainParametersProto to support ActionSpec (#4579)

* Assign new BrainParametersProto fields based on capabilities (#4581)

* ActionBuffer with hybrid actions for RemotePolicy (#4592)

* Barracuda inference for hybrid actions (#4611)

* Refactor BarracudaModel loader checks (#4629)

* Export separate nodes for continuous/discrete actions (#4655)

* Separate continuous/discrete actions in AgentActionProto (#4698)

* Force different nodes for new and deprecated action output (#4705)
/fix-conflict-base-env
GitHub 4 年前
当前提交
94c59e31
共有 63 个文件被更改,包括 4049 次插入655 次删除
  1. 6
      com.unity.ml-agents/Runtime/Academy.cs
  2. 8
      com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs
  3. 3
      com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs
  4. 4
      com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs
  5. 97
      com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs
  6. 33
      com.unity.ml-agents/Runtime/Agent.cs
  7. 10
      com.unity.ml-agents/Runtime/Agent.deprecated.cs
  8. 68
      com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
  9. 2
      com.unity.ml-agents/Runtime/Communicator/ICommunicator.cs
  10. 12
      com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs
  11. 8
      com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs
  12. 82
      com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentAction.cs
  13. 348
      com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/BrainParameters.cs
  14. 44
      com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs
  15. 44
      com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs
  16. 237
      com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
  17. 4
      com.unity.ml-agents/Runtime/Inference/GeneratorImpl.cs
  18. 12
      com.unity.ml-agents/Runtime/Inference/ModelRunner.cs
  19. 35
      com.unity.ml-agents/Runtime/Inference/TensorApplier.cs
  20. 26
      com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs
  21. 15
      com.unity.ml-agents/Runtime/Inference/TensorNames.cs
  22. 20
      com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs
  23. 14
      com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs
  24. 12
      com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs
  25. 3
      com.unity.ml-agents/Tests/Editor/DemonstrationTests.cs
  26. 74
      com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorApplier.cs
  27. 7
      com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorGenerator.cs
  28. 62
      com.unity.ml-agents/Tests/Editor/ModelRunnerTest.cs
  29. 212
      com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs
  30. 2
      com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated.nn.meta
  31. 2
      com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated.nn.meta
  32. 22
      ml-agents-envs/mlagents_envs/communicator_objects/agent_action_pb2.py
  33. 12
      ml-agents-envs/mlagents_envs/communicator_objects/agent_action_pb2.pyi
  34. 82
      ml-agents-envs/mlagents_envs/communicator_objects/brain_parameters_pb2.py
  35. 45
      ml-agents-envs/mlagents_envs/communicator_objects/brain_parameters_pb2.pyi
  36. 13
      ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py
  37. 6
      ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi
  38. 21
      ml-agents-envs/mlagents_envs/environment.py
  39. 18
      ml-agents-envs/mlagents_envs/mock_communicator.py
  40. 21
      ml-agents-envs/mlagents_envs/rpc_utils.py
  41. 33
      ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py
  42. 14
      ml-agents/mlagents/trainers/demo_loader.py
  43. 19
      ml-agents/mlagents/trainers/tests/simple_test_envs.py
  44. 8
      ml-agents/mlagents/trainers/tests/tensorflow/test_simple_rl.py
  45. 18
      ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py
  46. 28
      ml-agents/mlagents/trainers/torch/action_model.py
  47. 31
      ml-agents/mlagents/trainers/torch/model_serialization.py
  48. 43
      ml-agents/mlagents/trainers/torch/networks.py
  49. 4
      protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_action.proto
  50. 14
      protobuf-definitions/proto/mlagents_envs/communicator_objects/brain_parameters.proto
  51. 3
      protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto
  52. 360
      com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs
  53. 11
      com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs.meta
  54. 1001
      com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action.onnx
  55. 14
      com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action.onnx.meta
  56. 867
      com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr.onnx
  57. 14
      com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr.onnx.meta
  58. 462
      com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis53vec_3c_2daction.onnx
  59. 14
      com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis53vec_3c_2daction.onnx.meta
  60. 0
      /com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated.nn
  61. 0
      /com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated.nn
  62. 0
      /com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated.nn.meta
  63. 0
      /com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated.nn.meta

6
com.unity.ml-agents/Runtime/Academy.cs


/// <term>1.2.0</term>
/// <description>Support compression mapping for stacked compressed observations.</description>
/// </item>
/// <item>
/// <term>1.3.0</term>
/// <description>Support hybrid action spaces.</description>
/// </item>
const string k_ApiVersion = "1.2.0";
const string k_ApiVersion = "1.3.0";
/// <summary>
/// Unity package version of com.unity.ml-agents.

8
com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs


System.Array.Clear(Array, Offset, Length);
}
/// <summary>
/// Check if the segment is empty.
/// </summary>
public bool IsEmpty()
{
return Array == null || Array.Length == 0;
}
/// <inheritdoc/>
IEnumerator<T> IEnumerable<T>.GetEnumerator()
{

3
com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs


{
if (NumContinuousActions > 0 && NumDiscreteActions > 0)
{
throw new UnityAgentsException("ActionSpecs must be all continuous or all discrete.");
throw new UnityAgentsException("Hybrid action spaces not supported by the trainer. " +
"ActionSpecs must be all continuous or all discrete.");
}
}
}

4
com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs


Debug.Assert(
!m_Actuators[i].Name.Equals(m_Actuators[i + 1].Name),
"Actuator names must be unique.");
var first = m_Actuators[i].ActionSpec;
var second = m_Actuators[i + 1].ActionSpec;
Debug.Assert(first.NumContinuousActions > 0 == second.NumContinuousActions > 0,
"Actuators on the same Agent must have the same action SpaceType.");
}
}

97
com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs


}
/// <summary>
/// Construct an <see cref="ActionBuffers"/> instance with <see cref="ActionSpec"/>. All values are initialized to zeros.
/// /// </summary>
/// <param name="actionSpec">The <see cref="ActionSpec"/> to send to an <see cref="IActionReceiver"/>.</param>
public ActionBuffers(ActionSpec actionSpec)
: this(new ActionSegment<float>(new float[actionSpec.NumContinuousActions]),
new ActionSegment<int>(new int[actionSpec.NumDiscreteActions]))
{ }
/// <summary>
/// Create an <see cref="ActionBuffers"/> instance with ActionSpec and all actions stored as a float array.
/// </summary>
/// <param name="actionSpec"><see cref="ActionSpec"/> of the <see cref="ActionBuffers"/></param>
/// <param name="actions">The float array of all actions, including discrete and continuous actions.</param>
/// <returns>An <see cref="ActionBuffers"/> instance initialized with a <see cref="ActionSpec"/> and a float array.
internal static ActionBuffers FromActionSpec(ActionSpec actionSpec, float[] actions)
{
if (actions == null)
{
return ActionBuffers.Empty;
}
Debug.Assert(actions.Length == actionSpec.NumContinuousActions + actionSpec.NumDiscreteActions,
$"The length of '{nameof(actions)}' does not match the total size of ActionSpec.\n" +
$"{nameof(actions)}.Length: {actions.Length}\n" +
$"{nameof(actionSpec)}: {actionSpec.NumContinuousActions + actionSpec.NumDiscreteActions}");
ActionSegment<float> continuousActionSegment = ActionSegment<float>.Empty;
ActionSegment<int> discreteActionSegment = ActionSegment<int>.Empty;
int offset = 0;
if (actionSpec.NumContinuousActions > 0)
{
continuousActionSegment = new ActionSegment<float>(actions, 0, actionSpec.NumContinuousActions);
offset += actionSpec.NumContinuousActions;
}
if (actionSpec.NumDiscreteActions > 0)
{
int[] discreteActions = new int[actionSpec.NumDiscreteActions];
for (var i = 0; i < actionSpec.NumDiscreteActions; i++)
{
discreteActions[i] = (int)actions[i + offset];
}
discreteActionSegment = new ActionSegment<int>(discreteActions);
}
return new ActionBuffers(continuousActionSegment, discreteActionSegment);
}
/// <summary>
/// Clear the <see cref="ContinuousActions"/> and <see cref="DiscreteActions"/> segments to be all zeros.
/// </summary>
public void Clear()

}
/// <summary>
/// Check if the <see cref="ActionBuffers"/> is empty.
/// </summary>
public bool IsEmpty()
{
return ContinuousActions.IsEmpty() && DiscreteActions.IsEmpty();
}
/// <inheritdoc/>
public override bool Equals(object obj)
{

unchecked
{
return (ContinuousActions.GetHashCode() * 397) ^ DiscreteActions.GetHashCode();
}
}
/// <summary>
/// Packs the continuous and discrete actions into one float array. The array passed into this method
/// must have a Length that is greater than or equal to the sum of the Lengths of
/// <see cref="ContinuousActions"/> and <see cref="DiscreteActions"/>.
/// </summary>
/// <param name="destination">A float array to pack actions into whose length is greater than or
/// equal to the addition of the Lengths of this objects <see cref="ContinuousActions"/> and
/// <see cref="DiscreteActions"/> segments.</param>
public void PackActions(in float[] destination)
{
Debug.Assert(destination.Length >= ContinuousActions.Length + DiscreteActions.Length,
$"argument '{nameof(destination)}' is not large enough to pack the actions into.\n" +
$"{nameof(destination)}.Length: {destination.Length}\n" +
$"{nameof(ContinuousActions)}.Length + {nameof(DiscreteActions)}.Length: {ContinuousActions.Length + DiscreteActions.Length}");
var start = 0;
if (ContinuousActions.Length > 0)
{
Array.Copy(ContinuousActions.Array,
ContinuousActions.Offset,
destination,
start,
ContinuousActions.Length);
start = ContinuousActions.Length;
}
if (start >= destination.Length)
{
return;
}
if (DiscreteActions.Length > 0)
{
Array.Copy(DiscreteActions.Array,
DiscreteActions.Offset,
destination,
start,
DiscreteActions.Length);
}
}
}

33
com.unity.ml-agents/Runtime/Agent.cs


/// <summary>
/// Keeps track of the last vector action taken by the Brain.
/// </summary>
public float[] storedVectorActions;
public ActionBuffers storedVectorActions;
/// <summary>
/// For discrete control, specifies the actions that the agent cannot take.

public void ClearActions()
{
Array.Clear(storedVectorActions, 0, storedVectorActions.Length);
storedVectorActions.Clear();
actionBuffers.PackActions(storedVectorActions);
var continuousActions = storedVectorActions.ContinuousActions;
for (var i = 0; i < actionBuffers.ContinuousActions.Length; i++)
{
continuousActions[i] = actionBuffers.ContinuousActions[i];
}
var discreteActions = storedVectorActions.DiscreteActions;
for (var i = 0; i < actionBuffers.DiscreteActions.Length; i++)
{
discreteActions[i] = actionBuffers.DiscreteActions[i];
}
}
}

InitializeSensors();
}
m_Info.storedVectorActions = new float[m_ActuatorManager.TotalNumberOfActions];
m_Info.storedVectorActions = new ActionBuffers(
new float[m_ActuatorManager.NumContinuousActions],
new int[m_ActuatorManager.NumDiscreteActions]
);
// The first time the Academy resets, all Agents in the scene will be
// forced to reset through the <see cref="AgentForceReset"/> event.

m_CumulativeReward = 0f;
m_RequestAction = false;
m_RequestDecision = false;
Array.Clear(m_Info.storedVectorActions, 0, m_Info.storedVectorActions.Length);
m_Info.storedVectorActions.Clear();
}
/// <summary>

}
else
{
m_ActuatorManager.StoredActions.PackActions(m_Info.storedVectorActions);
m_Info.CopyActions(m_ActuatorManager.StoredActions);
}
UpdateSensors();

/// </param>
public virtual void OnActionReceived(ActionBuffers actions)
{
actions.PackActions(m_LegacyActionCache);
if (!actions.ContinuousActions.IsEmpty())
{
m_LegacyActionCache = actions.ContinuousActions.Array;
}
else
{
m_LegacyActionCache = Array.ConvertAll(actions.DiscreteActions.Array, x => (float)x);
}
OnActionReceived(m_LegacyActionCache);
}

10
com.unity.ml-agents/Runtime/Agent.deprecated.cs


// [Obsolete("GetAction has been deprecated, please use GetStoredActionBuffers, Or GetStoredDiscreteActions.")]
public float[] GetAction()
{
return m_Info.storedVectorActions;
var storedAction = m_Info.storedVectorActions;
if (!storedAction.ContinuousActions.IsEmpty())
{
return storedAction.ContinuousActions.Array;
}
else
{
return Array.ConvertAll(storedAction.DiscreteActions.Array, x => (float)x);
}
}
}
}

68
com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs


var agentInfoProto = ai.ToAgentInfoProto();
var agentActionProto = new AgentActionProto();
if (ai.storedVectorActions != null)
if (!ai.storedVectorActions.IsEmpty())
agentActionProto.VectorActions.AddRange(ai.storedVectorActions);
if (!ai.storedVectorActions.ContinuousActions.IsEmpty())
{
agentActionProto.ContinuousActions.AddRange(ai.storedVectorActions.ContinuousActions.Array);
}
if (!ai.storedVectorActions.DiscreteActions.IsEmpty())
{
agentActionProto.DiscreteActions.AddRange(ai.storedVectorActions.DiscreteActions.Array);
}
}
return new AgentInfoActionPairProto

{
var brainParametersProto = new BrainParametersProto
{
VectorActionSize = { bp.VectorActionSize },
VectorActionSpaceType = (SpaceTypeProto)bp.VectorActionSpaceType,
VectorActionSizeDeprecated = { bp.VectorActionSize },
VectorActionSpaceTypeDeprecated = (SpaceTypeProto)bp.VectorActionSpaceType,
brainParametersProto.VectorActionDescriptions.AddRange(bp.VectorActionDescriptions);
brainParametersProto.VectorActionDescriptionsDeprecated.AddRange(bp.VectorActionDescriptions);
}
return brainParametersProto;
}

/// <param name="isTraining">Whether or not the Brain is training.</param>
public static BrainParametersProto ToBrainParametersProto(this ActionSpec actionSpec, string name, bool isTraining)
{
actionSpec.CheckNotHybrid();
if (actionSpec.NumContinuousActions > 0)
var actionSpecProto = new ActionSpecProto
brainParametersProto.VectorActionSize.Add(actionSpec.NumContinuousActions);
brainParametersProto.VectorActionSpaceType = SpaceTypeProto.Continuous;
NumContinuousActions = actionSpec.NumContinuousActions,
NumDiscreteActions = actionSpec.NumDiscreteActions,
};
if (actionSpec.BranchSizes != null)
{
actionSpecProto.DiscreteBranchSizes.AddRange(actionSpec.BranchSizes);
else if (actionSpec.NumDiscreteActions > 0)
brainParametersProto.ActionSpec = actionSpecProto;
var supportHybrid = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.HybridActions;
if (!supportHybrid)
brainParametersProto.VectorActionSize.AddRange(actionSpec.BranchSizes);
brainParametersProto.VectorActionSpaceType = SpaceTypeProto.Discrete;
actionSpec.CheckNotHybrid();
if (actionSpec.NumContinuousActions > 0)
{
brainParametersProto.VectorActionSizeDeprecated.Add(actionSpec.NumContinuousActions);
brainParametersProto.VectorActionSpaceTypeDeprecated = SpaceTypeProto.Continuous;
}
else if (actionSpec.NumDiscreteActions > 0)
{
brainParametersProto.VectorActionSizeDeprecated.AddRange(actionSpec.BranchSizes);
brainParametersProto.VectorActionSpaceTypeDeprecated = SpaceTypeProto.Discrete;
}
}
// TODO handle ActionDescriptions?

{
var bp = new BrainParameters
{
VectorActionSize = bpp.VectorActionSize.ToArray(),
VectorActionDescriptions = bpp.VectorActionDescriptions.ToArray(),
VectorActionSpaceType = (SpaceType)bpp.VectorActionSpaceType
VectorActionSize = bpp.VectorActionSizeDeprecated.ToArray(),
VectorActionDescriptions = bpp.VectorActionDescriptionsDeprecated.ToArray(),
VectorActionSpaceType = (SpaceType)bpp.VectorActionSpaceTypeDeprecated
};
return bp;
}

}
#region AgentAction
public static List<float[]> ToAgentActionList(this UnityRLInputProto.Types.ListAgentActionProto proto)
public static List<ActionBuffers> ToAgentActionList(this UnityRLInputProto.Types.ListAgentActionProto proto)
var agentActions = new List<float[]>(proto.Value.Count);
var agentActions = new List<ActionBuffers>(proto.Value.Count);
agentActions.Add(ap.VectorActions.ToArray());
agentActions.Add(ap.ToActionBuffers());
public static ActionBuffers ToActionBuffers(this AgentActionProto proto)
{
return new ActionBuffers(proto.ContinuousActions.ToArray(), proto.DiscreteActions.ToArray());
}
#endregion
#region Observations

BaseRLCapabilities = proto.BaseRLCapabilities,
ConcatenatedPngObservations = proto.ConcatenatedPngObservations,
CompressedChannelMapping = proto.CompressedChannelMapping,
HybridActions = proto.HybridActions,
};
}

BaseRLCapabilities = rlCaps.BaseRLCapabilities,
ConcatenatedPngObservations = rlCaps.ConcatenatedPngObservations,
CompressedChannelMapping = rlCaps.CompressedChannelMapping,
HybridActions = rlCaps.HybridActions,
};
}

2
com.unity.ml-agents/Runtime/Communicator/ICommunicator.cs


/// <param name="key">A key to identify which behavior actions to get.</param>
/// <param name="agentId">A key to identify which Agent actions to get.</param>
/// <returns></returns>
float[] GetActions(string key, int agentId);
ActionBuffers GetActions(string key, int agentId);
}
}

12
com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs


UnityRLOutputProto m_CurrentUnityRlOutput =
new UnityRLOutputProto();
Dictionary<string, Dictionary<int, float[]>> m_LastActionsReceived =
new Dictionary<string, Dictionary<int, float[]>>();
Dictionary<string, Dictionary<int, ActionBuffers>> m_LastActionsReceived =
new Dictionary<string, Dictionary<int, ActionBuffers>>();
// Brains that we have sent over the communicator with agents.
HashSet<string> m_SentBrainKeys = new HashSet<string>();

}
if (!m_LastActionsReceived.ContainsKey(behaviorName))
{
m_LastActionsReceived[behaviorName] = new Dictionary<int, float[]>();
m_LastActionsReceived[behaviorName] = new Dictionary<int, ActionBuffers>();
m_LastActionsReceived[behaviorName][info.episodeId] = null;
m_LastActionsReceived[behaviorName][info.episodeId] = ActionBuffers.Empty;
if (info.done)
{
m_LastActionsReceived[behaviorName].Remove(info.episodeId);

}
}
public float[] GetActions(string behaviorName, int agentId)
public ActionBuffers GetActions(string behaviorName, int agentId)
{
if (m_LastActionsReceived.ContainsKey(behaviorName))
{

}
}
return null;
return ActionBuffers.Empty;
}
/// <summary>

8
com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs


public bool BaseRLCapabilities;
public bool ConcatenatedPngObservations;
public bool CompressedChannelMapping;
public bool HybridActions;
public UnityRLCapabilities(bool baseRlCapabilities = true, bool concatenatedPngObservations = true, bool compressedChannelMapping = true)
public UnityRLCapabilities(
bool baseRlCapabilities = true,
bool concatenatedPngObservations = true,
bool compressedChannelMapping = true,
bool hybridActions = true)
HybridActions = hybridActions;
}
/// <summary>

82
com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentAction.cs


byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjVtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2Fj",
"dGlvbi5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMiSwoQQWdlbnRBY3Rp",
"b25Qcm90bxIWCg52ZWN0b3JfYWN0aW9ucxgBIAMoAhINCgV2YWx1ZRgEIAEo",
"AkoECAIQA0oECAMQBEoECAUQBkIlqgIiVW5pdHkuTUxBZ2VudHMuQ29tbXVu",
"aWNhdG9yT2JqZWN0c2IGcHJvdG8z"));
"dGlvbi5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMijAEKEEFnZW50QWN0",
"aW9uUHJvdG8SIQoZdmVjdG9yX2FjdGlvbnNfZGVwcmVjYXRlZBgBIAMoAhIN",
"CgV2YWx1ZRgEIAEoAhIaChJjb250aW51b3VzX2FjdGlvbnMYBiADKAISGAoQ",
"ZGlzY3JldGVfYWN0aW9ucxgHIAMoBUoECAIQA0oECAMQBEoECAUQBkIlqgIi",
"VW5pdHkuTUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0c2IGcHJvdG8z"));
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentActionProto), global::Unity.MLAgents.CommunicatorObjects.AgentActionProto.Parser, new[]{ "VectorActions", "Value" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentActionProto), global::Unity.MLAgents.CommunicatorObjects.AgentActionProto.Parser, new[]{ "VectorActionsDeprecated", "Value", "ContinuousActions", "DiscreteActions" }, null, null, null)
}));
}
#endregion

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public AgentActionProto(AgentActionProto other) : this() {
vectorActions_ = other.vectorActions_.Clone();
vectorActionsDeprecated_ = other.vectorActionsDeprecated_.Clone();
continuousActions_ = other.continuousActions_.Clone();
discreteActions_ = other.discreteActions_.Clone();
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

}
/// <summary>Field number for the "vector_actions" field.</summary>
public const int VectorActionsFieldNumber = 1;
private static readonly pb::FieldCodec<float> _repeated_vectorActions_codec
/// <summary>Field number for the "vector_actions_deprecated" field.</summary>
public const int VectorActionsDeprecatedFieldNumber = 1;
private static readonly pb::FieldCodec<float> _repeated_vectorActionsDeprecated_codec
private readonly pbc::RepeatedField<float> vectorActions_ = new pbc::RepeatedField<float>();
private readonly pbc::RepeatedField<float> vectorActionsDeprecated_ = new pbc::RepeatedField<float>();
/// <summary>
/// mark as deprecated in communicator v1.3.0
/// </summary>
public pbc::RepeatedField<float> VectorActions {
get { return vectorActions_; }
public pbc::RepeatedField<float> VectorActionsDeprecated {
get { return vectorActionsDeprecated_; }
}
/// <summary>Field number for the "value" field.</summary>

}
}
/// <summary>Field number for the "continuous_actions" field.</summary>
public const int ContinuousActionsFieldNumber = 6;
private static readonly pb::FieldCodec<float> _repeated_continuousActions_codec
= pb::FieldCodec.ForFloat(50);
private readonly pbc::RepeatedField<float> continuousActions_ = new pbc::RepeatedField<float>();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<float> ContinuousActions {
get { return continuousActions_; }
}
/// <summary>Field number for the "discrete_actions" field.</summary>
public const int DiscreteActionsFieldNumber = 7;
private static readonly pb::FieldCodec<int> _repeated_discreteActions_codec
= pb::FieldCodec.ForInt32(58);
private readonly pbc::RepeatedField<int> discreteActions_ = new pbc::RepeatedField<int>();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<int> DiscreteActions {
get { return discreteActions_; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as AgentActionProto);

if (ReferenceEquals(other, this)) {
return true;
}
if(!vectorActions_.Equals(other.vectorActions_)) return false;
if(!vectorActionsDeprecated_.Equals(other.vectorActionsDeprecated_)) return false;
if(!continuousActions_.Equals(other.continuousActions_)) return false;
if(!discreteActions_.Equals(other.discreteActions_)) return false;
return Equals(_unknownFields, other._unknownFields);
}

hash ^= vectorActions_.GetHashCode();
hash ^= vectorActionsDeprecated_.GetHashCode();
hash ^= continuousActions_.GetHashCode();
hash ^= discreteActions_.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
vectorActions_.WriteTo(output, _repeated_vectorActions_codec);
vectorActionsDeprecated_.WriteTo(output, _repeated_vectorActionsDeprecated_codec);
continuousActions_.WriteTo(output, _repeated_continuousActions_codec);
discreteActions_.WriteTo(output, _repeated_discreteActions_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}

public int CalculateSize() {
int size = 0;
size += vectorActions_.CalculateSize(_repeated_vectorActions_codec);
size += vectorActionsDeprecated_.CalculateSize(_repeated_vectorActionsDeprecated_codec);
size += continuousActions_.CalculateSize(_repeated_continuousActions_codec);
size += discreteActions_.CalculateSize(_repeated_discreteActions_codec);
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}

if (other == null) {
return;
}
vectorActions_.Add(other.vectorActions_);
vectorActionsDeprecated_.Add(other.vectorActionsDeprecated_);
continuousActions_.Add(other.continuousActions_);
discreteActions_.Add(other.discreteActions_);
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

break;
case 10:
case 13: {
vectorActions_.AddEntriesFrom(input, _repeated_vectorActions_codec);
vectorActionsDeprecated_.AddEntriesFrom(input, _repeated_vectorActionsDeprecated_codec);
break;
}
case 50:
case 53: {
continuousActions_.AddEntriesFrom(input, _repeated_continuousActions_codec);
break;
}
case 58:
case 56: {
discreteActions_.AddEntriesFrom(input, _repeated_discreteActions_codec);
break;
}
}

348
com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/BrainParameters.cs


"CjltbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2JyYWluX3Bh",
"cmFtZXRlcnMucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGjNtbGFnZW50",
"c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL3NwYWNlX3R5cGUucHJvdG8i",
"2QEKFEJyYWluUGFyYW1ldGVyc1Byb3RvEhoKEnZlY3Rvcl9hY3Rpb25fc2l6",
"ZRgDIAMoBRIiChp2ZWN0b3JfYWN0aW9uX2Rlc2NyaXB0aW9ucxgFIAMoCRJG",
"Chh2ZWN0b3JfYWN0aW9uX3NwYWNlX3R5cGUYBiABKA4yJC5jb21tdW5pY2F0",
"b3Jfb2JqZWN0cy5TcGFjZVR5cGVQcm90bxISCgpicmFpbl9uYW1lGAcgASgJ",
"EhMKC2lzX3RyYWluaW5nGAggASgISgQIARACSgQIAhADSgQIBBAFQiWqAiJV",
"bml0eS5NTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
"iwEKD0FjdGlvblNwZWNQcm90bxIeChZudW1fY29udGludW91c19hY3Rpb25z",
"GAEgASgFEhwKFG51bV9kaXNjcmV0ZV9hY3Rpb25zGAIgASgFEh0KFWRpc2Ny",
"ZXRlX2JyYW5jaF9zaXplcxgDIAMoBRIbChNhY3Rpb25fZGVzY3JpcHRpb25z",
"GAQgAygJIrYCChRCcmFpblBhcmFtZXRlcnNQcm90bxIlCh12ZWN0b3JfYWN0",
"aW9uX3NpemVfZGVwcmVjYXRlZBgDIAMoBRItCiV2ZWN0b3JfYWN0aW9uX2Rl",
"c2NyaXB0aW9uc19kZXByZWNhdGVkGAUgAygJElEKI3ZlY3Rvcl9hY3Rpb25f",
"c3BhY2VfdHlwZV9kZXByZWNhdGVkGAYgASgOMiQuY29tbXVuaWNhdG9yX29i",
"amVjdHMuU3BhY2VUeXBlUHJvdG8SEgoKYnJhaW5fbmFtZRgHIAEoCRITCgtp",
"c190cmFpbmluZxgIIAEoCBI6CgthY3Rpb25fc3BlYxgJIAEoCzIlLmNvbW11",
"bmljYXRvcl9vYmplY3RzLkFjdGlvblNwZWNQcm90b0oECAEQAkoECAIQA0oE",
"CAQQBUIlqgIiVW5pdHkuTUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0c2IG",
"cHJvdG8z"));
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.BrainParametersProto), global::Unity.MLAgents.CommunicatorObjects.BrainParametersProto.Parser, new[]{ "VectorActionSize", "VectorActionDescriptions", "VectorActionSpaceType", "BrainName", "IsTraining" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ActionSpecProto), global::Unity.MLAgents.CommunicatorObjects.ActionSpecProto.Parser, new[]{ "NumContinuousActions", "NumDiscreteActions", "DiscreteBranchSizes", "ActionDescriptions" }, null, null, null),
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.BrainParametersProto), global::Unity.MLAgents.CommunicatorObjects.BrainParametersProto.Parser, new[]{ "VectorActionSizeDeprecated", "VectorActionDescriptionsDeprecated", "VectorActionSpaceTypeDeprecated", "BrainName", "IsTraining", "ActionSpec" }, null, null, null)
}));
}
#endregion

internal sealed partial class ActionSpecProto : pb::IMessage<ActionSpecProto> {
private static readonly pb::MessageParser<ActionSpecProto> _parser = new pb::MessageParser<ActionSpecProto>(() => new ActionSpecProto());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pb::MessageParser<ActionSpecProto> Parser { get { return _parser; } }
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pbr::MessageDescriptor Descriptor {
get { return global::Unity.MLAgents.CommunicatorObjects.BrainParametersReflection.Descriptor.MessageTypes[0]; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public ActionSpecProto() {
OnConstruction();
}
partial void OnConstruction();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public ActionSpecProto(ActionSpecProto other) : this() {
numContinuousActions_ = other.numContinuousActions_;
numDiscreteActions_ = other.numDiscreteActions_;
discreteBranchSizes_ = other.discreteBranchSizes_.Clone();
actionDescriptions_ = other.actionDescriptions_.Clone();
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public ActionSpecProto Clone() {
return new ActionSpecProto(this);
}
/// <summary>Field number for the "num_continuous_actions" field.</summary>
public const int NumContinuousActionsFieldNumber = 1;
private int numContinuousActions_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int NumContinuousActions {
get { return numContinuousActions_; }
set {
numContinuousActions_ = value;
}
}
/// <summary>Field number for the "num_discrete_actions" field.</summary>
public const int NumDiscreteActionsFieldNumber = 2;
private int numDiscreteActions_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int NumDiscreteActions {
get { return numDiscreteActions_; }
set {
numDiscreteActions_ = value;
}
}
/// <summary>Field number for the "discrete_branch_sizes" field.</summary>
public const int DiscreteBranchSizesFieldNumber = 3;
private static readonly pb::FieldCodec<int> _repeated_discreteBranchSizes_codec
= pb::FieldCodec.ForInt32(26);
private readonly pbc::RepeatedField<int> discreteBranchSizes_ = new pbc::RepeatedField<int>();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<int> DiscreteBranchSizes {
get { return discreteBranchSizes_; }
}
/// <summary>Field number for the "action_descriptions" field.</summary>
public const int ActionDescriptionsFieldNumber = 4;
private static readonly pb::FieldCodec<string> _repeated_actionDescriptions_codec
= pb::FieldCodec.ForString(34);
private readonly pbc::RepeatedField<string> actionDescriptions_ = new pbc::RepeatedField<string>();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<string> ActionDescriptions {
get { return actionDescriptions_; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as ActionSpecProto);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool Equals(ActionSpecProto other) {
if (ReferenceEquals(other, null)) {
return false;
}
if (ReferenceEquals(other, this)) {
return true;
}
if (NumContinuousActions != other.NumContinuousActions) return false;
if (NumDiscreteActions != other.NumDiscreteActions) return false;
if(!discreteBranchSizes_.Equals(other.discreteBranchSizes_)) return false;
if(!actionDescriptions_.Equals(other.actionDescriptions_)) return false;
return Equals(_unknownFields, other._unknownFields);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override int GetHashCode() {
int hash = 1;
if (NumContinuousActions != 0) hash ^= NumContinuousActions.GetHashCode();
if (NumDiscreteActions != 0) hash ^= NumDiscreteActions.GetHashCode();
hash ^= discreteBranchSizes_.GetHashCode();
hash ^= actionDescriptions_.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
return hash;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
if (NumContinuousActions != 0) {
output.WriteRawTag(8);
output.WriteInt32(NumContinuousActions);
}
if (NumDiscreteActions != 0) {
output.WriteRawTag(16);
output.WriteInt32(NumDiscreteActions);
}
discreteBranchSizes_.WriteTo(output, _repeated_discreteBranchSizes_codec);
actionDescriptions_.WriteTo(output, _repeated_actionDescriptions_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int CalculateSize() {
int size = 0;
if (NumContinuousActions != 0) {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumContinuousActions);
}
if (NumDiscreteActions != 0) {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumDiscreteActions);
}
size += discreteBranchSizes_.CalculateSize(_repeated_discreteBranchSizes_codec);
size += actionDescriptions_.CalculateSize(_repeated_actionDescriptions_codec);
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
return size;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(ActionSpecProto other) {
if (other == null) {
return;
}
if (other.NumContinuousActions != 0) {
NumContinuousActions = other.NumContinuousActions;
}
if (other.NumDiscreteActions != 0) {
NumDiscreteActions = other.NumDiscreteActions;
}
discreteBranchSizes_.Add(other.discreteBranchSizes_);
actionDescriptions_.Add(other.actionDescriptions_);
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(pb::CodedInputStream input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
break;
case 8: {
NumContinuousActions = input.ReadInt32();
break;
}
case 16: {
NumDiscreteActions = input.ReadInt32();
break;
}
case 26:
case 24: {
discreteBranchSizes_.AddEntriesFrom(input, _repeated_discreteBranchSizes_codec);
break;
}
case 34: {
actionDescriptions_.AddEntriesFrom(input, _repeated_actionDescriptions_codec);
break;
}
}
}
}
}
internal sealed partial class BrainParametersProto : pb::IMessage<BrainParametersProto> {
private static readonly pb::MessageParser<BrainParametersProto> _parser = new pb::MessageParser<BrainParametersProto>(() => new BrainParametersProto());
private pb::UnknownFieldSet _unknownFields;

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pbr::MessageDescriptor Descriptor {
get { return global::Unity.MLAgents.CommunicatorObjects.BrainParametersReflection.Descriptor.MessageTypes[0]; }
get { return global::Unity.MLAgents.CommunicatorObjects.BrainParametersReflection.Descriptor.MessageTypes[1]; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public BrainParametersProto(BrainParametersProto other) : this() {
vectorActionSize_ = other.vectorActionSize_.Clone();
vectorActionDescriptions_ = other.vectorActionDescriptions_.Clone();
vectorActionSpaceType_ = other.vectorActionSpaceType_;
vectorActionSizeDeprecated_ = other.vectorActionSizeDeprecated_.Clone();
vectorActionDescriptionsDeprecated_ = other.vectorActionDescriptionsDeprecated_.Clone();
vectorActionSpaceTypeDeprecated_ = other.vectorActionSpaceTypeDeprecated_;
ActionSpec = other.actionSpec_ != null ? other.ActionSpec.Clone() : null;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

}
/// <summary>Field number for the "vector_action_size" field.</summary>
public const int VectorActionSizeFieldNumber = 3;
private static readonly pb::FieldCodec<int> _repeated_vectorActionSize_codec
/// <summary>Field number for the "vector_action_size_deprecated" field.</summary>
public const int VectorActionSizeDeprecatedFieldNumber = 3;
private static readonly pb::FieldCodec<int> _repeated_vectorActionSizeDeprecated_codec
private readonly pbc::RepeatedField<int> vectorActionSize_ = new pbc::RepeatedField<int>();
private readonly pbc::RepeatedField<int> vectorActionSizeDeprecated_ = new pbc::RepeatedField<int>();
/// <summary>
/// mark as deprecated in communicator v1.3.0
/// </summary>
public pbc::RepeatedField<int> VectorActionSize {
get { return vectorActionSize_; }
public pbc::RepeatedField<int> VectorActionSizeDeprecated {
get { return vectorActionSizeDeprecated_; }
/// <summary>Field number for the "vector_action_descriptions" field.</summary>
public const int VectorActionDescriptionsFieldNumber = 5;
private static readonly pb::FieldCodec<string> _repeated_vectorActionDescriptions_codec
/// <summary>Field number for the "vector_action_descriptions_deprecated" field.</summary>
public const int VectorActionDescriptionsDeprecatedFieldNumber = 5;
private static readonly pb::FieldCodec<string> _repeated_vectorActionDescriptionsDeprecated_codec
private readonly pbc::RepeatedField<string> vectorActionDescriptions_ = new pbc::RepeatedField<string>();
private readonly pbc::RepeatedField<string> vectorActionDescriptionsDeprecated_ = new pbc::RepeatedField<string>();
/// <summary>
/// mark as deprecated in communicator v1.3.0
/// </summary>
public pbc::RepeatedField<string> VectorActionDescriptions {
get { return vectorActionDescriptions_; }
public pbc::RepeatedField<string> VectorActionDescriptionsDeprecated {
get { return vectorActionDescriptionsDeprecated_; }
/// <summary>Field number for the "vector_action_space_type" field.</summary>
public const int VectorActionSpaceTypeFieldNumber = 6;
private global::Unity.MLAgents.CommunicatorObjects.SpaceTypeProto vectorActionSpaceType_ = 0;
/// <summary>Field number for the "vector_action_space_type_deprecated" field.</summary>
public const int VectorActionSpaceTypeDeprecatedFieldNumber = 6;
private global::Unity.MLAgents.CommunicatorObjects.SpaceTypeProto vectorActionSpaceTypeDeprecated_ = 0;
/// <summary>
/// mark as deprecated in communicator v1.3.0
/// </summary>
public global::Unity.MLAgents.CommunicatorObjects.SpaceTypeProto VectorActionSpaceType {
get { return vectorActionSpaceType_; }
public global::Unity.MLAgents.CommunicatorObjects.SpaceTypeProto VectorActionSpaceTypeDeprecated {
get { return vectorActionSpaceTypeDeprecated_; }
vectorActionSpaceType_ = value;
vectorActionSpaceTypeDeprecated_ = value;
}
}

}
}
/// <summary>Field number for the "action_spec" field.</summary>
public const int ActionSpecFieldNumber = 9;
private global::Unity.MLAgents.CommunicatorObjects.ActionSpecProto actionSpec_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public global::Unity.MLAgents.CommunicatorObjects.ActionSpecProto ActionSpec {
get { return actionSpec_; }
set {
actionSpec_ = value;
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as BrainParametersProto);

if (ReferenceEquals(other, this)) {
return true;
}
if(!vectorActionSize_.Equals(other.vectorActionSize_)) return false;
if(!vectorActionDescriptions_.Equals(other.vectorActionDescriptions_)) return false;
if (VectorActionSpaceType != other.VectorActionSpaceType) return false;
if(!vectorActionSizeDeprecated_.Equals(other.vectorActionSizeDeprecated_)) return false;
if(!vectorActionDescriptionsDeprecated_.Equals(other.vectorActionDescriptionsDeprecated_)) return false;
if (VectorActionSpaceTypeDeprecated != other.VectorActionSpaceTypeDeprecated) return false;
if (!object.Equals(ActionSpec, other.ActionSpec)) return false;
return Equals(_unknownFields, other._unknownFields);
}

hash ^= vectorActionSize_.GetHashCode();
hash ^= vectorActionDescriptions_.GetHashCode();
if (VectorActionSpaceType != 0) hash ^= VectorActionSpaceType.GetHashCode();
hash ^= vectorActionSizeDeprecated_.GetHashCode();
hash ^= vectorActionDescriptionsDeprecated_.GetHashCode();
if (VectorActionSpaceTypeDeprecated != 0) hash ^= VectorActionSpaceTypeDeprecated.GetHashCode();
if (actionSpec_ != null) hash ^= ActionSpec.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
vectorActionSize_.WriteTo(output, _repeated_vectorActionSize_codec);
vectorActionDescriptions_.WriteTo(output, _repeated_vectorActionDescriptions_codec);
if (VectorActionSpaceType != 0) {
vectorActionSizeDeprecated_.WriteTo(output, _repeated_vectorActionSizeDeprecated_codec);
vectorActionDescriptionsDeprecated_.WriteTo(output, _repeated_vectorActionDescriptionsDeprecated_codec);
if (VectorActionSpaceTypeDeprecated != 0) {
output.WriteEnum((int) VectorActionSpaceType);
output.WriteEnum((int) VectorActionSpaceTypeDeprecated);
}
if (BrainName.Length != 0) {
output.WriteRawTag(58);

output.WriteRawTag(64);
output.WriteBool(IsTraining);
}
if (actionSpec_ != null) {
output.WriteRawTag(74);
output.WriteMessage(ActionSpec);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}

public int CalculateSize() {
int size = 0;
size += vectorActionSize_.CalculateSize(_repeated_vectorActionSize_codec);
size += vectorActionDescriptions_.CalculateSize(_repeated_vectorActionDescriptions_codec);
if (VectorActionSpaceType != 0) {
size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) VectorActionSpaceType);
size += vectorActionSizeDeprecated_.CalculateSize(_repeated_vectorActionSizeDeprecated_codec);
size += vectorActionDescriptionsDeprecated_.CalculateSize(_repeated_vectorActionDescriptionsDeprecated_codec);
if (VectorActionSpaceTypeDeprecated != 0) {
size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) VectorActionSpaceTypeDeprecated);
}
if (BrainName.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(BrainName);

}
if (actionSpec_ != null) {
size += 1 + pb::CodedOutputStream.ComputeMessageSize(ActionSpec);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();

if (other == null) {
return;
}
vectorActionSize_.Add(other.vectorActionSize_);
vectorActionDescriptions_.Add(other.vectorActionDescriptions_);
if (other.VectorActionSpaceType != 0) {
VectorActionSpaceType = other.VectorActionSpaceType;
vectorActionSizeDeprecated_.Add(other.vectorActionSizeDeprecated_);
vectorActionDescriptionsDeprecated_.Add(other.vectorActionDescriptionsDeprecated_);
if (other.VectorActionSpaceTypeDeprecated != 0) {
VectorActionSpaceTypeDeprecated = other.VectorActionSpaceTypeDeprecated;
}
if (other.BrainName.Length != 0) {
BrainName = other.BrainName;

}
if (other.actionSpec_ != null) {
if (actionSpec_ == null) {
actionSpec_ = new global::Unity.MLAgents.CommunicatorObjects.ActionSpecProto();
}
ActionSpec.MergeFrom(other.ActionSpec);
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

break;
case 26:
case 24: {
vectorActionSize_.AddEntriesFrom(input, _repeated_vectorActionSize_codec);
vectorActionSizeDeprecated_.AddEntriesFrom(input, _repeated_vectorActionSizeDeprecated_codec);
vectorActionDescriptions_.AddEntriesFrom(input, _repeated_vectorActionDescriptions_codec);
vectorActionDescriptionsDeprecated_.AddEntriesFrom(input, _repeated_vectorActionDescriptionsDeprecated_codec);
vectorActionSpaceType_ = (global::Unity.MLAgents.CommunicatorObjects.SpaceTypeProto) input.ReadEnum();
vectorActionSpaceTypeDeprecated_ = (global::Unity.MLAgents.CommunicatorObjects.SpaceTypeProto) input.ReadEnum();
break;
}
case 58: {

case 64: {
IsTraining = input.ReadBool();
break;
}
case 74: {
if (actionSpec_ == null) {
actionSpec_ = new global::Unity.MLAgents.CommunicatorObjects.ActionSpecProto();
}
input.ReadMessage(actionSpec_);
break;
}
}

44
com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs


byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjVtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2NhcGFiaWxp",
"dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMifQoYVW5pdHlSTENh",
"cGFiaWxpdGllc1Byb3RvEhoKEmJhc2VSTENhcGFiaWxpdGllcxgBIAEoCBIj",
"Chtjb25jYXRlbmF0ZWRQbmdPYnNlcnZhdGlvbnMYAiABKAgSIAoYY29tcHJl",
"c3NlZENoYW5uZWxNYXBwaW5nGAMgASgIQiWqAiJVbml0eS5NTEFnZW50cy5D",
"b21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
"dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMilAEKGFVuaXR5UkxD",
"YXBhYmlsaXRpZXNQcm90bxIaChJiYXNlUkxDYXBhYmlsaXRpZXMYASABKAgS",
"IwobY29uY2F0ZW5hdGVkUG5nT2JzZXJ2YXRpb25zGAIgASgIEiAKGGNvbXBy",
"ZXNzZWRDaGFubmVsTWFwcGluZxgDIAEoCBIVCg1oeWJyaWRBY3Rpb25zGAQg",
"ASgIQiWqAiJVbml0eS5NTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZw",
"cm90bzM="));
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations", "CompressedChannelMapping" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations", "CompressedChannelMapping", "HybridActions" }, null, null, null)
}));
}
#endregion

baseRLCapabilities_ = other.baseRLCapabilities_;
concatenatedPngObservations_ = other.concatenatedPngObservations_;
compressedChannelMapping_ = other.compressedChannelMapping_;
hybridActions_ = other.hybridActions_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

}
}
/// <summary>Field number for the "hybridActions" field.</summary>
public const int HybridActionsFieldNumber = 4;
private bool hybridActions_;
/// <summary>
/// support for hybrid action spaces (discrete + continuous)
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool HybridActions {
get { return hybridActions_; }
set {
hybridActions_ = value;
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as UnityRLCapabilitiesProto);

if (BaseRLCapabilities != other.BaseRLCapabilities) return false;
if (ConcatenatedPngObservations != other.ConcatenatedPngObservations) return false;
if (CompressedChannelMapping != other.CompressedChannelMapping) return false;
if (HybridActions != other.HybridActions) return false;
return Equals(_unknownFields, other._unknownFields);
}

if (BaseRLCapabilities != false) hash ^= BaseRLCapabilities.GetHashCode();
if (ConcatenatedPngObservations != false) hash ^= ConcatenatedPngObservations.GetHashCode();
if (CompressedChannelMapping != false) hash ^= CompressedChannelMapping.GetHashCode();
if (HybridActions != false) hash ^= HybridActions.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}

output.WriteRawTag(24);
output.WriteBool(CompressedChannelMapping);
}
if (HybridActions != false) {
output.WriteRawTag(32);
output.WriteBool(HybridActions);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}

size += 1 + 1;
}
if (CompressedChannelMapping != false) {
size += 1 + 1;
}
if (HybridActions != false) {
size += 1 + 1;
}
if (_unknownFields != null) {

if (other.CompressedChannelMapping != false) {
CompressedChannelMapping = other.CompressedChannelMapping;
}
if (other.HybridActions != false) {
HybridActions = other.HybridActions;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

}
case 24: {
CompressedChannelMapping = input.ReadBool();
break;
}
case 32: {
HybridActions = input.ReadBool();
break;
}
}

44
com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs


using System.Collections.Generic;
using System.Linq;
using Unity.MLAgents.Inference.Utils;
using Unity.MLAgents.Actuators;
using Unity.Barracuda;
using UnityEngine;

/// </summary>
internal class ContinuousActionOutputApplier : TensorApplier.IApplier
{
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, float[]> lastActions)
readonly ActionSpec m_ActionSpec;
public ContinuousActionOutputApplier(ActionSpec actionSpec)
{
m_ActionSpec = actionSpec;
}
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
{
var actionSize = tensorProxy.shape[tensorProxy.shape.Length - 1];
var agentIndex = 0;

{
var actionValue = lastActions[agentId];
if (actionValue == null)
var actionBuffer = lastActions[agentId];
if (actionBuffer.IsEmpty())
actionValue = new float[actionSize];
lastActions[agentId] = actionValue;
actionBuffer = new ActionBuffers(m_ActionSpec);
lastActions[agentId] = actionBuffer;
var continuousBuffer = actionBuffer.ContinuousActions;
actionValue[j] = tensorProxy.data[agentIndex, j];
continuousBuffer[j] = tensorProxy.data[agentIndex, j];
}
}
agentIndex++;

readonly int[] m_ActionSize;
readonly Multinomial m_Multinomial;
readonly ITensorAllocator m_Allocator;
readonly ActionSpec m_ActionSpec;
public DiscreteActionOutputApplier(int[] actionSize, int seed, ITensorAllocator allocator)
public DiscreteActionOutputApplier(ActionSpec actionSpec, int seed, ITensorAllocator allocator)
m_ActionSize = actionSize;
m_ActionSize = actionSpec.BranchSizes;
m_ActionSpec = actionSpec;
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, float[]> lastActions)
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
{
//var tensorDataProbabilities = tensorProxy.Data as float[,];
var idActionPairList = actionIds as List<int> ?? actionIds.ToList();

{
if (lastActions.ContainsKey(agentId))
{
var actionVal = lastActions[agentId];
if (actionVal == null)
var actionBuffer = lastActions[agentId];
if (actionBuffer.IsEmpty())
actionVal = new float[m_ActionSize.Length];
lastActions[agentId] = actionVal;
actionBuffer = new ActionBuffers(m_ActionSpec);
lastActions[agentId] = actionBuffer;
var discreteBuffer = actionBuffer.DiscreteActions;
actionVal[j] = actionValues[agentIndex, j];
discreteBuffer[j] = (int)actionValues[agentIndex, j];
}
}
agentIndex++;

m_Memories = memories;
}
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, float[]> lastActions)
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
{
var agentIndex = 0;
var memorySize = (int)tensorProxy.shape[tensorProxy.shape.Length - 1];

m_Memories = memories;
}
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, float[]> lastActions)
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
{
var agentIndex = 0;
var memorySize = (int)tensorProxy.shape[tensorProxy.shape.Length - 1];

237
com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs


/// </summary>
internal class BarracudaModelParamLoader
{
enum ModelActionType
{
Unknown,
Discrete,
Continuous
}
/// Generates the Tensor inputs that are expected to be present in the Model.
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters.
/// </param>
/// <returns>TensorProxy IEnumerable with the expected Tensor inputs.</returns>
public static IReadOnlyList<TensorProxy> GetInputTensors(Model model)
{
var tensors = new List<TensorProxy>();
if (model == null)
return tensors;
foreach (var input in model.inputs)
{
tensors.Add(new TensorProxy
{
name = input.name,
valueType = TensorProxy.TensorType.FloatingPoint,
data = null,
shape = input.shape.Select(i => (long)i).ToArray()
});
}
foreach (var mem in model.memories)
{
tensors.Add(new TensorProxy
{
name = mem.input,
valueType = TensorProxy.TensorType.FloatingPoint,
data = null,
shape = TensorUtils.TensorShapeFromBarracuda(mem.shape)
});
}
tensors.Sort((el1, el2) => el1.name.CompareTo(el2.name));
return tensors;
}
public static int GetNumVisualInputs(Model model)
{
var count = 0;
if (model == null)
return count;
foreach (var input in model.inputs)
{
if (input.shape.Length == 4)
{
if (input.name.StartsWith(TensorNames.VisualObservationPlaceholderPrefix))
{
count++;
}
}
}
return count;
}
/// <summary>
/// Generates the Tensor outputs that are expected to be present in the Model.
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters
/// </param>
/// <returns>TensorProxy IEnumerable with the expected Tensor outputs</returns>
public static string[] GetOutputNames(Model model)
{
var names = new List<string>();
if (model == null)
{
return names.ToArray();
}
names.Add(TensorNames.ActionOutput);
var memory = (int)model.GetTensorByName(TensorNames.MemorySize)[0];
if (memory > 0)
{
foreach (var mem in model.memories)
{
names.Add(mem.output);
}
}
names.Sort();
return names.ToArray();
}
/// <summary>
/// Factory for the ModelParamLoader : Creates a ModelParamLoader and runs the checks
/// on it.
/// </summary>

return failedModelChecks;
}
foreach (var constantName in TensorNames.RequiredConstants)
var hasExpectedTensors = model.CheckExpectedTensors(failedModelChecks);
if (!hasExpectedTensors)
var tensor = model.GetTensorByName(constantName);
if (tensor == null)
{
failedModelChecks.Add($"Required constant \"{constantName}\" was not found in the model file.");
return failedModelChecks;
}
return failedModelChecks;
var memorySize = (int)model.GetTensorByName(TensorNames.MemorySize)[0];
var isContinuousInt = (int)model.GetTensorByName(TensorNames.IsContinuousControl)[0];
var isContinuous = GetActionType(isContinuousInt);
var actionSize = (int)model.GetTensorByName(TensorNames.ActionOutputShape)[0];
if (modelApiVersion == -1)
{
failedModelChecks.Add(