浏览代码

Merge remote-tracking branch 'origin/master' into goal-conditioning

# Conflicts:
#	ml-agents-envs/mlagents_envs/base_env.py
#	ml-agents-envs/mlagents_envs/rpc_utils.py
#	ml-agents/mlagents/trainers/tests/mock_brain.py
#	ml-agents/mlagents/trainers/tests/simple_test_envs.py
/goal-conditioning
Arthur Juliani 4 年前
当前提交
0d2f8887
共有 120 个文件被更改,包括 2441 次插入1641 次删除
  1. 4
      Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs
  2. 1
      Project/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs
  3. 4
      com.unity.ml-agents/CHANGELOG.md
  4. 7
      com.unity.ml-agents/Runtime/Academy.cs
  5. 8
      com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs
  6. 12
      com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs
  7. 4
      com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs
  8. 96
      com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs
  9. 34
      com.unity.ml-agents/Runtime/Agent.cs
  10. 10
      com.unity.ml-agents/Runtime/Agent.deprecated.cs
  11. 77
      com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
  12. 2
      com.unity.ml-agents/Runtime/Communicator/ICommunicator.cs
  13. 13
      com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs
  14. 13
      com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs
  15. 82
      com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentAction.cs
  16. 348
      com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/BrainParameters.cs
  17. 44
      com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs
  18. 44
      com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs
  19. 237
      com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
  20. 4
      com.unity.ml-agents/Runtime/Inference/GeneratorImpl.cs
  21. 12
      com.unity.ml-agents/Runtime/Inference/ModelRunner.cs
  22. 35
      com.unity.ml-agents/Runtime/Inference/TensorApplier.cs
  23. 26
      com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs
  24. 15
      com.unity.ml-agents/Runtime/Inference/TensorNames.cs
  25. 19
      com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs
  26. 14
      com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs
  27. 12
      com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs
  28. 3
      com.unity.ml-agents/Tests/Editor/DemonstrationTests.cs
  29. 74
      com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorApplier.cs
  30. 7
      com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorGenerator.cs
  31. 62
      com.unity.ml-agents/Tests/Editor/ModelRunnerTest.cs
  32. 212
      com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs
  33. 2
      com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated.nn.meta
  34. 2
      com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated.nn.meta
  35. 8
      docs/Getting-Started.md
  36. 15
      docs/Learning-Environment-Create-New.md
  37. 80
      docs/Learning-Environment-Design-Agents.md
  38. 64
      docs/Python-API.md
  39. 4
      docs/Training-Configuration-File.md
  40. 10
      gym-unity/gym_unity/envs/__init__.py
  41. 147
      ml-agents-envs/mlagents_envs/base_env.py
  42. 22
      ml-agents-envs/mlagents_envs/communicator_objects/agent_action_pb2.py
  43. 12
      ml-agents-envs/mlagents_envs/communicator_objects/agent_action_pb2.pyi
  44. 82
      ml-agents-envs/mlagents_envs/communicator_objects/brain_parameters_pb2.py
  45. 45
      ml-agents-envs/mlagents_envs/communicator_objects/brain_parameters_pb2.pyi
  46. 13
      ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py
  47. 6
      ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi
  48. 30
      ml-agents-envs/mlagents_envs/environment.py
  49. 18
      ml-agents-envs/mlagents_envs/mock_communicator.py
  50. 23
      ml-agents-envs/mlagents_envs/rpc_utils.py
  51. 6
      ml-agents-envs/mlagents_envs/tests/test_envs.py
  52. 33
      ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py
  53. 27
      ml-agents-envs/mlagents_envs/tests/test_steps.py
  54. 3
      ml-agents/mlagents/trainers/action_info.py
  55. 23
      ml-agents/mlagents/trainers/agent_processor.py
  56. 2
      ml-agents/mlagents/trainers/buffer.py
  57. 18
      ml-agents/mlagents/trainers/demo_loader.py
  58. 1
      ml-agents/mlagents/trainers/env_manager.py
  59. 4
      ml-agents/mlagents/trainers/optimizer/tf_optimizer.py
  60. 40
      ml-agents/mlagents/trainers/policy/policy.py
  61. 33
      ml-agents/mlagents/trainers/policy/tf_policy.py
  62. 77
      ml-agents/mlagents/trainers/policy/torch_policy.py
  63. 9
      ml-agents/mlagents/trainers/ppo/optimizer_tf.py
  64. 13
      ml-agents/mlagents/trainers/ppo/optimizer_torch.py
  65. 6
      ml-agents/mlagents/trainers/sac/optimizer_tf.py
  66. 284
      ml-agents/mlagents/trainers/sac/optimizer_torch.py
  67. 2
      ml-agents/mlagents/trainers/simple_env_manager.py
  68. 56
      ml-agents/mlagents/trainers/subprocess_env_manager.py
  69. 24
      ml-agents/mlagents/trainers/tests/mock_brain.py
  70. 99
      ml-agents/mlagents/trainers/tests/simple_test_envs.py
  71. 66
      ml-agents/mlagents/trainers/tests/tensorflow/test_ppo.py
  72. 128
      ml-agents/mlagents/trainers/tests/tensorflow/test_simple_rl.py
  73. 12
      ml-agents/mlagents/trainers/tests/tensorflow/test_tf_policy.py
  74. 41
      ml-agents/mlagents/trainers/tests/test_agent_processor.py
  75. 10
      ml-agents/mlagents/trainers/tests/test_demo_loader.py
  76. 33
      ml-agents/mlagents/trainers/tests/test_subprocess_env_manager.py
  77. 7
      ml-agents/mlagents/trainers/tests/test_trajectory.py
  78. 13
      ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py
  79. 2
      ml-agents/mlagents/trainers/tests/torch/test_distributions.py
  80. 90
      ml-agents/mlagents/trainers/tests/torch/test_networks.py
  81. 28
      ml-agents/mlagents/trainers/tests/torch/test_policy.py
  82. 15
      ml-agents/mlagents/trainers/tests/torch/test_ppo.py
  83. 2
      ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py
  84. 11
      ml-agents/mlagents/trainers/tests/torch/test_reward_providers/utils.py
  85. 3
      ml-agents/mlagents/trainers/tests/torch/test_sac.py
  86. 132
      ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py
  87. 44
      ml-agents/mlagents/trainers/tests/torch/test_utils.py
  88. 6
      ml-agents/mlagents/trainers/tf/components/bc/module.py
  89. 10
      ml-agents/mlagents/trainers/tf/components/reward_signals/curiosity/signal.py
  90. 17
      ml-agents/mlagents/trainers/tf/components/reward_signals/gail/signal.py
  91. 52
      ml-agents/mlagents/trainers/torch/components/bc/module.py
  92. 78
      ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py
  93. 8
      ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py
  94. 37
      ml-agents/mlagents/trainers/torch/distributions.py
  95. 31
      ml-agents/mlagents/trainers/torch/model_serialization.py
  96. 226
      ml-agents/mlagents/trainers/torch/networks.py
  97. 48
      ml-agents/mlagents/trainers/torch/utils.py
  98. 28
      ml-agents/mlagents/trainers/trajectory.py
  99. 22
      ml-agents/tests/yamato/scripts/run_llapi.py
  100. 4
      protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_action.proto

4
Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs


if (maskActions)
{
// Prevents the agent from picking an action that would make it collide with a wall
var positionX = (int)transform.position.x;
var positionZ = (int)transform.position.z;
var positionX = (int)transform.localPosition.x;
var positionZ = (int)transform.localPosition.z;
var maxPosition = (int)m_ResetParams.GetWithDefault("gridSize", 5f) - 1;
if (positionX == 0)

1
Project/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs


/// The agent's four actions correspond to torques on each of the two joints.
/// </summary>
public override void OnActionReceived(ActionBuffers actionBuffers)
{
m_GoalDegree += m_GoalSpeed;
UpdateGoalPosition();

4
com.unity.ml-agents/CHANGELOG.md


### Major Changes
#### com.unity.ml-agents (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- PyTorch trainers now support training agents with both continuous and discrete action spaces.
Currently, this can only be done with Actuators. Please see
[here](../Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicActuatorComponent.cs) for an
example of how to use Actuators. (#4702)
### Minor Changes
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)

7
com.unity.ml-agents/Runtime/Academy.cs


/// <term>1.2.0</term>
/// <description>Support compression mapping for stacked compressed observations.</description>
/// </item>
/// <item>
/// <term>1.3.0</term>
/// <description>Support action spaces with both continuous and discrete actions.</description>
/// </item>
const string k_ApiVersion = "1.2.0";
const string k_ApiVersion = "1.3.0";
/// <summary>
/// Unity package version of com.unity.ml-agents.

Dispose();
}
}
#endif
/// <summary>

8
com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs


System.Array.Clear(Array, Offset, Length);
}
/// <summary>
/// Check if the segment is empty.
/// </summary>
public bool IsEmpty()
{
return Array == null || Array.Length == 0;
}
/// <inheritdoc/>
IEnumerator<T> IEnumerable<T>.GetEnumerator()
{

12
com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs


/// </summary>
public readonly struct ActionSpec
{
/// <summary>
/// An array of branch sizes for our action space.
///

}
/// <summary>
/// Temporary check that the ActionSpec uses either all continuous or all discrete actions.
/// This should be removed once the trainer supports them.
/// Check that the ActionSpec uses either all continuous or all discrete actions.
/// This is only used when connecting to old versions of the trainer that don't support this.
internal void CheckNotHybrid()
internal void CheckAllContinuousOrDiscrete()
throw new UnityAgentsException("ActionSpecs must be all continuous or all discrete.");
throw new UnityAgentsException(
"Action spaces with both continuous and discrete actions are not supported by the trainer. " +
"ActionSpecs must be all continuous or all discrete."
);
}
}
}

4
com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs


Debug.Assert(
!m_Actuators[i].Name.Equals(m_Actuators[i + 1].Name),
"Actuator names must be unique.");
var first = m_Actuators[i].ActionSpec;
var second = m_Actuators[i + 1].ActionSpec;
Debug.Assert(first.NumContinuousActions > 0 == second.NumContinuousActions > 0,
"Actuators on the same Agent must have the same action SpaceType.");
}
}

96
com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs


}
/// <summary>
/// Construct an <see cref="ActionBuffers"/> instance with <see cref="ActionSpec"/>. All values are initialized to zeros.
/// /// </summary>
/// <param name="actionSpec">The <see cref="ActionSpec"/> to send to an <see cref="IActionReceiver"/>.</param>
public ActionBuffers(ActionSpec actionSpec)
: this(new ActionSegment<float>(new float[actionSpec.NumContinuousActions]),
new ActionSegment<int>(new int[actionSpec.NumDiscreteActions]))
{ }
/// <summary>
/// Create an <see cref="ActionBuffers"/> instance with ActionSpec and all actions stored as a float array.
/// </summary>
/// <param name="actionSpec"><see cref="ActionSpec"/> of the <see cref="ActionBuffers"/></param>
/// <param name="actions">The float array of all actions, including discrete and continuous actions.</param>
/// <returns>An <see cref="ActionBuffers"/> instance initialized with a <see cref="ActionSpec"/> and a float array.
internal static ActionBuffers FromActionSpec(ActionSpec actionSpec, float[] actions)
{
if (actions == null)
{
return ActionBuffers.Empty;
}
Debug.Assert(actions.Length == actionSpec.NumContinuousActions + actionSpec.NumDiscreteActions,
$"The length of '{nameof(actions)}' does not match the total size of ActionSpec.\n" +
$"{nameof(actions)}.Length: {actions.Length}\n" +
$"{nameof(actionSpec)}: {actionSpec.NumContinuousActions + actionSpec.NumDiscreteActions}");
ActionSegment<float> continuousActionSegment = ActionSegment<float>.Empty;
ActionSegment<int> discreteActionSegment = ActionSegment<int>.Empty;
int offset = 0;
if (actionSpec.NumContinuousActions > 0)
{
continuousActionSegment = new ActionSegment<float>(actions, 0, actionSpec.NumContinuousActions);
offset += actionSpec.NumContinuousActions;
}
if (actionSpec.NumDiscreteActions > 0)
{
int[] discreteActions = new int[actionSpec.NumDiscreteActions];
for (var i = 0; i < actionSpec.NumDiscreteActions; i++)
{
discreteActions[i] = (int)actions[i + offset];
}
discreteActionSegment = new ActionSegment<int>(discreteActions);
}
return new ActionBuffers(continuousActionSegment, discreteActionSegment);
}
/// <summary>
/// Clear the <see cref="ContinuousActions"/> and <see cref="DiscreteActions"/> segments to be all zeros.
/// </summary>
public void Clear()

}
/// <summary>
/// Check if the <see cref="ActionBuffers"/> is empty.
/// </summary>
public bool IsEmpty()
{
return ContinuousActions.IsEmpty() && DiscreteActions.IsEmpty();
}
/// <inheritdoc/>
public override bool Equals(object obj)
{

unchecked
{
return (ContinuousActions.GetHashCode() * 397) ^ DiscreteActions.GetHashCode();
}
}
/// <summary>
/// Packs the continuous and discrete actions into one float array. The array passed into this method
/// must have a Length that is greater than or equal to the sum of the Lengths of
/// <see cref="ContinuousActions"/> and <see cref="DiscreteActions"/>.
/// </summary>
/// <param name="destination">A float array to pack actions into whose length is greater than or
/// equal to the addition of the Lengths of this objects <see cref="ContinuousActions"/> and
/// <see cref="DiscreteActions"/> segments.</param>
public void PackActions(in float[] destination)
{
Debug.Assert(destination.Length >= ContinuousActions.Length + DiscreteActions.Length,
$"argument '{nameof(destination)}' is not large enough to pack the actions into.\n" +
$"{nameof(destination)}.Length: {destination.Length}\n" +
$"{nameof(ContinuousActions)}.Length + {nameof(DiscreteActions)}.Length: {ContinuousActions.Length + DiscreteActions.Length}");
var start = 0;
if (ContinuousActions.Length > 0)
{
Array.Copy(ContinuousActions.Array,
ContinuousActions.Offset,
destination,
start,
ContinuousActions.Length);
start = ContinuousActions.Length;
}
if (start >= destination.Length)
{
return;
}
if (DiscreteActions.Length > 0)
{
Array.Copy(DiscreteActions.Array,
DiscreteActions.Offset,
destination,
start,
DiscreteActions.Length);
}
}
}

34
com.unity.ml-agents/Runtime/Agent.cs


/// <summary>
/// Keeps track of the last vector action taken by the Brain.
/// </summary>
public float[] storedVectorActions;
public ActionBuffers storedVectorActions;
/// <summary>
/// For discrete control, specifies the actions that the agent cannot take.

public void ClearActions()
{
Array.Clear(storedVectorActions, 0, storedVectorActions.Length);
storedVectorActions.Clear();
actionBuffers.PackActions(storedVectorActions);
var continuousActions = storedVectorActions.ContinuousActions;
for (var i = 0; i < actionBuffers.ContinuousActions.Length; i++)
{
continuousActions[i] = actionBuffers.ContinuousActions[i];
}
var discreteActions = storedVectorActions.DiscreteActions;
for (var i = 0; i < actionBuffers.DiscreteActions.Length; i++)
{
discreteActions[i] = actionBuffers.DiscreteActions[i];
}
}
}

InitializeSensors();
}
m_Info.storedVectorActions = new float[m_ActuatorManager.TotalNumberOfActions];
m_Info.storedVectorActions = new ActionBuffers(
new float[m_ActuatorManager.NumContinuousActions],
new int[m_ActuatorManager.NumDiscreteActions]
);
// The first time the Academy resets, all Agents in the scene will be
// forced to reset through the <see cref="AgentForceReset"/> event.

m_CumulativeReward = 0f;
m_RequestAction = false;
m_RequestDecision = false;
Array.Clear(m_Info.storedVectorActions, 0, m_Info.storedVectorActions.Length);
m_Info.storedVectorActions.Clear();
}
/// <summary>

}
else
{
m_ActuatorManager.StoredActions.PackActions(m_Info.storedVectorActions);
m_Info.CopyActions(m_ActuatorManager.StoredActions);
}
UpdateSensors();

/// </param>
public virtual void OnActionReceived(ActionBuffers actions)
{
actions.PackActions(m_LegacyActionCache);
if (!actions.ContinuousActions.IsEmpty())
{
m_LegacyActionCache = actions.ContinuousActions.Array;
}
else
{
m_LegacyActionCache = Array.ConvertAll(actions.DiscreteActions.Array, x => (float)x);
}
OnActionReceived(m_LegacyActionCache);
}

{
OnEpisodeBegin();
}
}
/// <summary>

10
com.unity.ml-agents/Runtime/Agent.deprecated.cs


// [Obsolete("GetAction has been deprecated, please use GetStoredActionBuffers, Or GetStoredDiscreteActions.")]
public float[] GetAction()
{
return m_Info.storedVectorActions;
var storedAction = m_Info.storedVectorActions;
if (!storedAction.ContinuousActions.IsEmpty())
{
return storedAction.ContinuousActions.Array;
}
else
{
return Array.ConvertAll(storedAction.DiscreteActions.Array, x => (float)x);
}
}
}
}

77
com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs


var agentInfoProto = ai.ToAgentInfoProto();
var agentActionProto = new AgentActionProto();
if (ai.storedVectorActions != null)
if (!ai.storedVectorActions.IsEmpty())
agentActionProto.VectorActions.AddRange(ai.storedVectorActions);
if (!ai.storedVectorActions.ContinuousActions.IsEmpty())
{
agentActionProto.ContinuousActions.AddRange(ai.storedVectorActions.ContinuousActions.Array);
}
if (!ai.storedVectorActions.DiscreteActions.IsEmpty())
{
agentActionProto.DiscreteActions.AddRange(ai.storedVectorActions.DiscreteActions.Array);
}
}
return new AgentInfoActionPairProto

return summariesOut;
}
#endregion
#region BrainParameters

{
var brainParametersProto = new BrainParametersProto
{
VectorActionSize = { bp.VectorActionSize },
VectorActionSpaceType = (SpaceTypeProto)bp.VectorActionSpaceType,
VectorActionSizeDeprecated = { bp.VectorActionSize },
VectorActionSpaceTypeDeprecated = (SpaceTypeProto)bp.VectorActionSpaceType,
brainParametersProto.VectorActionDescriptions.AddRange(bp.VectorActionDescriptions);
brainParametersProto.VectorActionDescriptionsDeprecated.AddRange(bp.VectorActionDescriptions);
}
return brainParametersProto;
}

/// <param name="isTraining">Whether or not the Brain is training.</param>
public static BrainParametersProto ToBrainParametersProto(this ActionSpec actionSpec, string name, bool isTraining)
{
actionSpec.CheckNotHybrid();
if (actionSpec.NumContinuousActions > 0)
var actionSpecProto = new ActionSpecProto
brainParametersProto.VectorActionSize.Add(actionSpec.NumContinuousActions);
brainParametersProto.VectorActionSpaceType = SpaceTypeProto.Continuous;
NumContinuousActions = actionSpec.NumContinuousActions,
NumDiscreteActions = actionSpec.NumDiscreteActions,
};
if (actionSpec.BranchSizes != null)
{
actionSpecProto.DiscreteBranchSizes.AddRange(actionSpec.BranchSizes);
else if (actionSpec.NumDiscreteActions > 0)
brainParametersProto.ActionSpec = actionSpecProto;
var supportHybrid = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.HybridActions;
if (!supportHybrid)
brainParametersProto.VectorActionSize.AddRange(actionSpec.BranchSizes);
brainParametersProto.VectorActionSpaceType = SpaceTypeProto.Discrete;
actionSpec.CheckAllContinuousOrDiscrete();
if (actionSpec.NumContinuousActions > 0)
{
brainParametersProto.VectorActionSizeDeprecated.Add(actionSpec.NumContinuousActions);
brainParametersProto.VectorActionSpaceTypeDeprecated = SpaceTypeProto.Continuous;
}
else if (actionSpec.NumDiscreteActions > 0)
{
brainParametersProto.VectorActionSizeDeprecated.AddRange(actionSpec.BranchSizes);
brainParametersProto.VectorActionSpaceTypeDeprecated = SpaceTypeProto.Discrete;
}
}
// TODO handle ActionDescriptions?

{
var bp = new BrainParameters
{
VectorActionSize = bpp.VectorActionSize.ToArray(),
VectorActionDescriptions = bpp.VectorActionDescriptions.ToArray(),
VectorActionSpaceType = (SpaceType)bpp.VectorActionSpaceType
VectorActionSize = bpp.VectorActionSizeDeprecated.ToArray(),
VectorActionDescriptions = bpp.VectorActionDescriptionsDeprecated.ToArray(),
VectorActionSpaceType = (SpaceType)bpp.VectorActionSpaceTypeDeprecated
};
return bp;
}

}
return dm;
}
#endregion
public static UnityRLInitParameters ToUnityRLInitParameters(this UnityRLInitializationInputProto inputProto)

}
#region AgentAction
public static List<float[]> ToAgentActionList(this UnityRLInputProto.Types.ListAgentActionProto proto)
public static List<ActionBuffers> ToAgentActionList(this UnityRLInputProto.Types.ListAgentActionProto proto)
var agentActions = new List<float[]>(proto.Value.Count);
var agentActions = new List<ActionBuffers>(proto.Value.Count);
agentActions.Add(ap.VectorActions.ToArray());
agentActions.Add(ap.ToActionBuffers());
public static ActionBuffers ToActionBuffers(this AgentActionProto proto)
{
return new ActionBuffers(proto.ContinuousActions.ToArray(), proto.DiscreteActions.ToArray());
}
#endregion
#region Observations

if (!s_HaveWarnedTrainerCapabilitiesMapping)
{
Debug.LogWarning($"The sensor {sensor.GetName()} is using non-trivial mapping and " +
"the attached trainer doesn't support compression mapping. " +
"Switching to uncompressed observations.");
"the attached trainer doesn't support compression mapping. " +
"Switching to uncompressed observations.");
s_HaveWarnedTrainerCapabilitiesMapping = true;
}
compressionType = SensorCompressionType.None;

$"GetCompressedObservation() returned null data for sensor named {sensor.GetName()}. " +
"You must return a byte[]. If you don't want to use compressed observations, " +
"return SensorCompressionType.None from GetCompressionType()."
);
);
}
observationProto = new ObservationProto
{

observationProto.Shape.AddRange(shape);
return observationProto;
}
#endregion
public static UnityRLCapabilities ToRLCapabilities(this UnityRLCapabilitiesProto proto)

BaseRLCapabilities = proto.BaseRLCapabilities,
ConcatenatedPngObservations = proto.ConcatenatedPngObservations,
CompressedChannelMapping = proto.CompressedChannelMapping,
HybridActions = proto.HybridActions,
};
}

BaseRLCapabilities = rlCaps.BaseRLCapabilities,
ConcatenatedPngObservations = rlCaps.ConcatenatedPngObservations,
CompressedChannelMapping = rlCaps.CompressedChannelMapping,
HybridActions = rlCaps.HybridActions,
};
}

2
com.unity.ml-agents/Runtime/Communicator/ICommunicator.cs


/// <param name="key">A key to identify which behavior actions to get.</param>
/// <param name="agentId">A key to identify which Agent actions to get.</param>
/// <returns></returns>
float[] GetActions(string key, int agentId);
ActionBuffers GetActions(string key, int agentId);
}
}

13
com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs


UnityRLOutputProto m_CurrentUnityRlOutput =
new UnityRLOutputProto();
Dictionary<string, Dictionary<int, float[]>> m_LastActionsReceived =
new Dictionary<string, Dictionary<int, float[]>>();
Dictionary<string, Dictionary<int, ActionBuffers>> m_LastActionsReceived =
new Dictionary<string, Dictionary<int, ActionBuffers>>();
// Brains that we have sent over the communicator with agents.
HashSet<string> m_SentBrainKeys = new HashSet<string>();

{
return false;
}
}
else if (unityVersion.Major != pythonVersion.Major)
{

}
if (!m_LastActionsReceived.ContainsKey(behaviorName))
{
m_LastActionsReceived[behaviorName] = new Dictionary<int, float[]>();
m_LastActionsReceived[behaviorName] = new Dictionary<int, ActionBuffers>();
m_LastActionsReceived[behaviorName][info.episodeId] = null;
m_LastActionsReceived[behaviorName][info.episodeId] = ActionBuffers.Empty;
if (info.done)
{
m_LastActionsReceived[behaviorName].Remove(info.episodeId);

}
}
public float[] GetActions(string behaviorName, int agentId)
public ActionBuffers GetActions(string behaviorName, int agentId)
{
if (m_LastActionsReceived.ContainsKey(behaviorName))
{

}
}
return null;
return ActionBuffers.Empty;
}
/// <summary>

13
com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs


public bool BaseRLCapabilities;
public bool ConcatenatedPngObservations;
public bool CompressedChannelMapping;
public bool HybridActions;
public UnityRLCapabilities(bool baseRlCapabilities = true, bool concatenatedPngObservations = true, bool compressedChannelMapping = true)
public UnityRLCapabilities(
bool baseRlCapabilities = true,
bool concatenatedPngObservations = true,
bool compressedChannelMapping = true,
bool hybridActions = true)
HybridActions = hybridActions;
}
/// <summary>

return false;
}
Debug.LogWarning("Unity has connected to a Training process that does not support" +
"Base Reinforcement Learning Capabilities. Please make sure you have the" +
" latest training codebase installed for this version of the ML-Agents package.");
"Base Reinforcement Learning Capabilities. Please make sure you have the" +
" latest training codebase installed for this version of the ML-Agents package.");
}
}

82
com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentAction.cs


byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjVtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2Fj",
"dGlvbi5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMiSwoQQWdlbnRBY3Rp",
"b25Qcm90bxIWCg52ZWN0b3JfYWN0aW9ucxgBIAMoAhINCgV2YWx1ZRgEIAEo",
"AkoECAIQA0oECAMQBEoECAUQBkIlqgIiVW5pdHkuTUxBZ2VudHMuQ29tbXVu",
"aWNhdG9yT2JqZWN0c2IGcHJvdG8z"));
"dGlvbi5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMijAEKEEFnZW50QWN0",
"aW9uUHJvdG8SIQoZdmVjdG9yX2FjdGlvbnNfZGVwcmVjYXRlZBgBIAMoAhIN",
"CgV2YWx1ZRgEIAEoAhIaChJjb250aW51b3VzX2FjdGlvbnMYBiADKAISGAoQ",
"ZGlzY3JldGVfYWN0aW9ucxgHIAMoBUoECAIQA0oECAMQBEoECAUQBkIlqgIi",
"VW5pdHkuTUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0c2IGcHJvdG8z"));
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentActionProto), global::Unity.MLAgents.CommunicatorObjects.AgentActionProto.Parser, new[]{ "VectorActions", "Value" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentActionProto), global::Unity.MLAgents.CommunicatorObjects.AgentActionProto.Parser, new[]{ "VectorActionsDeprecated", "Value", "ContinuousActions", "DiscreteActions" }, null, null, null)
}));
}
#endregion

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public AgentActionProto(AgentActionProto other) : this() {
vectorActions_ = other.vectorActions_.Clone();
vectorActionsDeprecated_ = other.vectorActionsDeprecated_.Clone();
continuousActions_ = other.continuousActions_.Clone();
discreteActions_ = other.discreteActions_.Clone();
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

}
/// <summary>Field number for the "vector_actions" field.</summary>
public const int VectorActionsFieldNumber = 1;
private static readonly pb::FieldCodec<float> _repeated_vectorActions_codec
/// <summary>Field number for the "vector_actions_deprecated" field.</summary>
public const int VectorActionsDeprecatedFieldNumber = 1;
private static readonly pb::FieldCodec<float> _repeated_vectorActionsDeprecated_codec
private readonly pbc::RepeatedField<float> vectorActions_ = new pbc::RepeatedField<float>();
private readonly pbc::RepeatedField<float> vectorActionsDeprecated_ = new pbc::RepeatedField<float>();
/// <summary>
/// mark as deprecated in communicator v1.3.0
/// </summary>
public pbc::RepeatedField<float> VectorActions {
get { return vectorActions_; }
public pbc::RepeatedField<float> VectorActionsDeprecated {
get { return vectorActionsDeprecated_; }
}
/// <summary>Field number for the "value" field.</summary>

}
}
/// <summary>Field number for the "continuous_actions" field.</summary>
public const int ContinuousActionsFieldNumber = 6;
private static readonly pb::FieldCodec<float> _repeated_continuousActions_codec
= pb::FieldCodec.ForFloat(50);
private readonly pbc::RepeatedField<float> continuousActions_ = new pbc::RepeatedField<float>();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<float> ContinuousActions {
get { return continuousActions_; }
}
/// <summary>Field number for the "discrete_actions" field.</summary>
public const int DiscreteActionsFieldNumber = 7;
private static readonly pb::FieldCodec<int> _repeated_discreteActions_codec
= pb::FieldCodec.ForInt32(58);
private readonly pbc::RepeatedField<int> discreteActions_ = new pbc::RepeatedField<int>();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<int> DiscreteActions {
get { return discreteActions_; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as AgentActionProto);

if (ReferenceEquals(other, this)) {
return true;
}
if(!vectorActions_.Equals(other.vectorActions_)) return false;
if(!vectorActionsDeprecated_.Equals(other.vectorActionsDeprecated_)) return false;
if(!continuousActions_.Equals(other.continuousActions_)) return false;
if(!discreteActions_.Equals(other.discreteActions_)) return false;
return Equals(_unknownFields, other._unknownFields);
}

hash ^= vectorActions_.GetHashCode();
hash ^= vectorActionsDeprecated_.GetHashCode();
hash ^= continuousActions_.GetHashCode();
hash ^= discreteActions_.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
vectorActions_.WriteTo(output, _repeated_vectorActions_codec);
vectorActionsDeprecated_.WriteTo(output, _repeated_vectorActionsDeprecated_codec);
continuousActions_.WriteTo(output, _repeated_continuousActions_codec);
discreteActions_.WriteTo(output, _repeated_discreteActions_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}

public int CalculateSize() {
int size = 0;
size += vectorActions_.CalculateSize(_repeated_vectorActions_codec);
size += vectorActionsDeprecated_.CalculateSize(_repeated_vectorActionsDeprecated_codec);
size += continuousActions_.CalculateSize(_repeated_continuousActions_codec);
size += discreteActions_.CalculateSize(_repeated_discreteActions_codec);
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}

if (other == null) {
return;
}
vectorActions_.Add(other.vectorActions_);
vectorActionsDeprecated_.Add(other.vectorActionsDeprecated_);
continuousActions_.Add(other.continuousActions_);
discreteActions_.Add(other.discreteActions_);
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

break;
case 10:
case 13: {
vectorActions_.AddEntriesFrom(input, _repeated_vectorActions_codec);
vectorActionsDeprecated_.AddEntriesFrom(input, _repeated_vectorActionsDeprecated_codec);
break;
}
case 50:
case 53: {
continuousActions_.AddEntriesFrom(input, _repeated_continuousActions_codec);
break;
}
case 58:
case 56: {
discreteActions_.AddEntriesFrom(input, _repeated_discreteActions_codec);
break;
}
}

348
com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/BrainParameters.cs


"CjltbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2JyYWluX3Bh",
"cmFtZXRlcnMucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGjNtbGFnZW50",
"c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL3NwYWNlX3R5cGUucHJvdG8i",
"2QEKFEJyYWluUGFyYW1ldGVyc1Byb3RvEhoKEnZlY3Rvcl9hY3Rpb25fc2l6",
"ZRgDIAMoBRIiChp2ZWN0b3JfYWN0aW9uX2Rlc2NyaXB0aW9ucxgFIAMoCRJG",
"Chh2ZWN0b3JfYWN0aW9uX3NwYWNlX3R5cGUYBiABKA4yJC5jb21tdW5pY2F0",
"b3Jfb2JqZWN0cy5TcGFjZVR5cGVQcm90bxISCgpicmFpbl9uYW1lGAcgASgJ",
"EhMKC2lzX3RyYWluaW5nGAggASgISgQIARACSgQIAhADSgQIBBAFQiWqAiJV",
"bml0eS5NTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
"iwEKD0FjdGlvblNwZWNQcm90bxIeChZudW1fY29udGludW91c19hY3Rpb25z",
"GAEgASgFEhwKFG51bV9kaXNjcmV0ZV9hY3Rpb25zGAIgASgFEh0KFWRpc2Ny",
"ZXRlX2JyYW5jaF9zaXplcxgDIAMoBRIbChNhY3Rpb25fZGVzY3JpcHRpb25z",
"GAQgAygJIrYCChRCcmFpblBhcmFtZXRlcnNQcm90bxIlCh12ZWN0b3JfYWN0",
"aW9uX3NpemVfZGVwcmVjYXRlZBgDIAMoBRItCiV2ZWN0b3JfYWN0aW9uX2Rl",
"c2NyaXB0aW9uc19kZXByZWNhdGVkGAUgAygJElEKI3ZlY3Rvcl9hY3Rpb25f",
"c3BhY2VfdHlwZV9kZXByZWNhdGVkGAYgASgOMiQuY29tbXVuaWNhdG9yX29i",
"amVjdHMuU3BhY2VUeXBlUHJvdG8SEgoKYnJhaW5fbmFtZRgHIAEoCRITCgtp",
"c190cmFpbmluZxgIIAEoCBI6CgthY3Rpb25fc3BlYxgJIAEoCzIlLmNvbW11",
"bmljYXRvcl9vYmplY3RzLkFjdGlvblNwZWNQcm90b0oECAEQAkoECAIQA0oE",
"CAQQBUIlqgIiVW5pdHkuTUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0c2IG",
"cHJvdG8z"));
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.BrainParametersProto), global::Unity.MLAgents.CommunicatorObjects.BrainParametersProto.Parser, new[]{ "VectorActionSize", "VectorActionDescriptions", "VectorActionSpaceType", "BrainName", "IsTraining" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ActionSpecProto), global::Unity.MLAgents.CommunicatorObjects.ActionSpecProto.Parser, new[]{ "NumContinuousActions", "NumDiscreteActions", "DiscreteBranchSizes", "ActionDescriptions" }, null, null, null),
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.BrainParametersProto), global::Unity.MLAgents.CommunicatorObjects.BrainParametersProto.Parser, new[]{ "VectorActionSizeDeprecated", "VectorActionDescriptionsDeprecated", "VectorActionSpaceTypeDeprecated", "BrainName", "IsTraining", "ActionSpec" }, null, null, null)
}));
}
#endregion

internal sealed partial class ActionSpecProto : pb::IMessage<ActionSpecProto> {
private static readonly pb::MessageParser<ActionSpecProto> _parser = new pb::MessageParser<ActionSpecProto>(() => new ActionSpecProto());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pb::MessageParser<ActionSpecProto> Parser { get { return _parser; } }
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pbr::MessageDescriptor Descriptor {
get { return global::Unity.MLAgents.CommunicatorObjects.BrainParametersReflection.Descriptor.MessageTypes[0]; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public ActionSpecProto() {
OnConstruction();
}
partial void OnConstruction();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public ActionSpecProto(ActionSpecProto other) : this() {
numContinuousActions_ = other.numContinuousActions_;
numDiscreteActions_ = other.numDiscreteActions_;
discreteBranchSizes_ = other.discreteBranchSizes_.Clone();
actionDescriptions_ = other.actionDescriptions_.Clone();
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public ActionSpecProto Clone() {
return new ActionSpecProto(this);
}
/// <summary>Field number for the "num_continuous_actions" field.</summary>
public const int NumContinuousActionsFieldNumber = 1;
private int numContinuousActions_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int NumContinuousActions {
get { return numContinuousActions_; }
set {
numContinuousActions_ = value;
}
}
/// <summary>Field number for the "num_discrete_actions" field.</summary>
public const int NumDiscreteActionsFieldNumber = 2;
private int numDiscreteActions_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int NumDiscreteActions {
get { return numDiscreteActions_; }
set {
numDiscreteActions_ = value;
}
}
/// <summary>Field number for the "discrete_branch_sizes" field.</summary>
public const int DiscreteBranchSizesFieldNumber = 3;
private static readonly pb::FieldCodec<int> _repeated_discreteBranchSizes_codec
= pb::FieldCodec.ForInt32(26);
private readonly pbc::RepeatedField<int> discreteBranchSizes_ = new pbc::RepeatedField<int>();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<int> DiscreteBranchSizes {
get { return discreteBranchSizes_; }
}
/// <summary>Field number for the "action_descriptions" field.</summary>
public const int ActionDescriptionsFieldNumber = 4;
private static readonly pb::FieldCodec<string> _repeated_actionDescriptions_codec
= pb::FieldCodec.ForString(34);
private readonly pbc::RepeatedField<string> actionDescriptions_ = new pbc::RepeatedField<string>();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<string> ActionDescriptions {
get { return actionDescriptions_; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as ActionSpecProto);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool Equals(ActionSpecProto other) {
if (ReferenceEquals(other, null)) {
return false;
}
if (ReferenceEquals(other, this)) {
return true;
}
if (NumContinuousActions != other.NumContinuousActions) return false;
if (NumDiscreteActions != other.NumDiscreteActions) return false;
if(!discreteBranchSizes_.Equals(other.discreteBranchSizes_)) return false;
if(!actionDescriptions_.Equals(other.actionDescriptions_)) return false;
return Equals(_unknownFields, other._unknownFields);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override int GetHashCode() {
int hash = 1;
if (NumContinuousActions != 0) hash ^= NumContinuousActions.GetHashCode();
if (NumDiscreteActions != 0) hash ^= NumDiscreteActions.GetHashCode();
hash ^= discreteBranchSizes_.GetHashCode();
hash ^= actionDescriptions_.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
return hash;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
if (NumContinuousActions != 0) {
output.WriteRawTag(8);
output.WriteInt32(NumContinuousActions);
}
if (NumDiscreteActions != 0) {
output.WriteRawTag(16);
output.WriteInt32(NumDiscreteActions);
}
discreteBranchSizes_.WriteTo(output, _repeated_discreteBranchSizes_codec);
actionDescriptions_.WriteTo(output, _repeated_actionDescriptions_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int CalculateSize() {
int size = 0;
if (NumContinuousActions != 0) {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumContinuousActions);
}
if (NumDiscreteActions != 0) {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumDiscreteActions);
}
size += discreteBranchSizes_.CalculateSize(_repeated_discreteBranchSizes_codec);
size += actionDescriptions_.CalculateSize(_repeated_actionDescriptions_codec);
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
return size;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(ActionSpecProto other) {
if (other == null) {
return;
}
if (other.NumContinuousActions != 0) {
NumContinuousActions = other.NumContinuousActions;
}
if (other.NumDiscreteActions != 0) {
NumDiscreteActions = other.NumDiscreteActions;
}
discreteBranchSizes_.Add(other.discreteBranchSizes_);
actionDescriptions_.Add(other.actionDescriptions_);
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(pb::CodedInputStream input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
break;
case 8: {
NumContinuousActions = input.ReadInt32();
break;
}
case 16: {
NumDiscreteActions = input.ReadInt32();
break;
}
case 26:
case 24: {
discreteBranchSizes_.AddEntriesFrom(input, _repeated_discreteBranchSizes_codec);
break;
}
case 34: {
actionDescriptions_.AddEntriesFrom(input, _repeated_actionDescriptions_codec);
break;
}
}
}
}
}
internal sealed partial class BrainParametersProto : pb::IMessage<BrainParametersProto> {
private static readonly pb::MessageParser<BrainParametersProto> _parser = new pb::MessageParser<BrainParametersProto>(() => new BrainParametersProto());
private pb::UnknownFieldSet _unknownFields;

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pbr::MessageDescriptor Descriptor {
get { return global::Unity.MLAgents.CommunicatorObjects.BrainParametersReflection.Descriptor.MessageTypes[0]; }
get { return global::Unity.MLAgents.CommunicatorObjects.BrainParametersReflection.Descriptor.MessageTypes[1]; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public BrainParametersProto(BrainParametersProto other) : this() {
vectorActionSize_ = other.vectorActionSize_.Clone();
vectorActionDescriptions_ = other.vectorActionDescriptions_.Clone();
vectorActionSpaceType_ = other.vectorActionSpaceType_;
vectorActionSizeDeprecated_ = other.vectorActionSizeDeprecated_.Clone();
vectorActionDescriptionsDeprecated_ = other.vectorActionDescriptionsDeprecated_.Clone();
vectorActionSpaceTypeDeprecated_ = other.vectorActionSpaceTypeDeprecated_;
ActionSpec = other.actionSpec_ != null ? other.ActionSpec.Clone() : null;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

}
/// <summary>Field number for the "vector_action_size" field.</summary>
public const int VectorActionSizeFieldNumber = 3;
private static readonly pb::FieldCodec<int> _repeated_vectorActionSize_codec
/// <summary>Field number for the "vector_action_size_deprecated" field.</summary>
public const int VectorActionSizeDeprecatedFieldNumber = 3;
private static readonly pb::FieldCodec<int> _repeated_vectorActionSizeDeprecated_codec
private readonly pbc::RepeatedField<int> vectorActionSize_ = new pbc::RepeatedField<int>();
private readonly pbc::RepeatedField<int> vectorActionSizeDeprecated_ = new pbc::RepeatedField<int>();
/// <summary>
/// mark as deprecated in communicator v1.3.0
/// </summary>
public pbc::RepeatedField<int> VectorActionSize {
get { return vectorActionSize_; }
public pbc::RepeatedField<int> VectorActionSizeDeprecated {
get { return vectorActionSizeDeprecated_; }
/// <summary>Field number for the "vector_action_descriptions" field.</summary>
public const int VectorActionDescriptionsFieldNumber = 5;
private static readonly pb::FieldCodec<string> _repeated_vectorActionDescriptions_codec
/// <summary>Field number for the "vector_action_descriptions_deprecated" field.</summary>
public const int VectorActionDescriptionsDeprecatedFieldNumber = 5;
private static readonly pb::FieldCodec<string> _repeated_vectorActionDescriptionsDeprecated_codec
private readonly pbc::RepeatedField<string> vectorActionDescriptions_ = new pbc::RepeatedField<string>();
private readonly pbc::RepeatedField<string> vectorActionDescriptionsDeprecated_ = new pbc::RepeatedField<string>();
/// <summary>
/// mark as deprecated in communicator v1.3.0
/// </summary>
public pbc::RepeatedField<string> VectorActionDescriptions {
get { return vectorActionDescriptions_; }
public pbc::RepeatedField<string> VectorActionDescriptionsDeprecated {
get { return vectorActionDescriptionsDeprecated_; }
/// <summary>Field number for the "vector_action_space_type" field.</summary>
public const int VectorActionSpaceTypeFieldNumber = 6;
private global::Unity.MLAgents.CommunicatorObjects.SpaceTypeProto vectorActionSpaceType_ = 0;
/// <summary>Field number for the "vector_action_space_type_deprecated" field.</summary>
public const int VectorActionSpaceTypeDeprecatedFieldNumber = 6;
private global::Unity.MLAgents.CommunicatorObjects.SpaceTypeProto vectorActionSpaceTypeDeprecated_ = 0;
/// <summary>
/// mark as deprecated in communicator v1.3.0
/// </summary>
public global::Unity.MLAgents.CommunicatorObjects.SpaceTypeProto VectorActionSpaceType {
get { return vectorActionSpaceType_; }
public global::Unity.MLAgents.CommunicatorObjects.SpaceTypeProto VectorActionSpaceTypeDeprecated {
get { return vectorActionSpaceTypeDeprecated_; }
vectorActionSpaceType_ = value;
vectorActionSpaceTypeDeprecated_ = value;
}
}

}
}
/// <summary>Field number for the "action_spec" field.</summary>
public const int ActionSpecFieldNumber = 9;
private global::Unity.MLAgents.CommunicatorObjects.ActionSpecProto actionSpec_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public global::Unity.MLAgents.CommunicatorObjects.ActionSpecProto ActionSpec {
get { return actionSpec_; }
set {
actionSpec_ = value;
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as BrainParametersProto);

if (ReferenceEquals(other, this)) {
return true;
}
if(!vectorActionSize_.Equals(other.vectorActionSize_)) return false;
if(!vectorActionDescriptions_.Equals(other.vectorActionDescriptions_)) return false;
if (VectorActionSpaceType != other.VectorActionSpaceType) return false;
if(!vectorActionSizeDeprecated_.Equals(other.vectorActionSizeDeprecated_)) return false;
if(!vectorActionDescriptionsDeprecated_.Equals(other.vectorActionDescriptionsDeprecated_)) return false;
if (VectorActionSpaceTypeDeprecated != other.VectorActionSpaceTypeDeprecated) return false;
if (!object.Equals(ActionSpec, other.ActionSpec)) return false;
return Equals(_unknownFields, other._unknownFields);
}

hash ^= vectorActionSize_.GetHashCode();
hash ^= vectorActionDescriptions_.GetHashCode();
if (VectorActionSpaceType != 0) hash ^= VectorActionSpaceType.GetHashCode();
hash ^= vectorActionSizeDeprecated_.GetHashCode();
hash ^= vectorActionDescriptionsDeprecated_.GetHashCode();
if (VectorActionSpaceTypeDeprecated != 0) hash ^= VectorActionSpaceTypeDeprecated.GetHashCode();
if (actionSpec_ != null) hash ^= ActionSpec.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
vectorActionSize_.WriteTo(output, _repeated_vectorActionSize_codec);
vectorActionDescriptions_.WriteTo(output, _repeated_vectorActionDescriptions_codec);
if (VectorActionSpaceType != 0) {
vectorActionSizeDeprecated_.WriteTo(output, _repeated_vectorActionSizeDeprecated_codec);
vectorActionDescriptionsDeprecated_.WriteTo(output, _repeated_vectorActionDescriptionsDeprecated_codec);
if (VectorActionSpaceTypeDeprecated != 0) {
output.WriteEnum((int) VectorActionSpaceType);
output.WriteEnum((int) VectorActionSpaceTypeDeprecated);
}
if (BrainName.Length != 0) {
output.WriteRawTag(58);

output.WriteRawTag(64);
output.WriteBool(IsTraining);
}
if (actionSpec_ != null) {
output.WriteRawTag(74);
output.WriteMessage(ActionSpec);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}

public int CalculateSize() {
int size = 0;
size += vectorActionSize_.CalculateSize(_repeated_vectorActionSize_codec);
size += vectorActionDescriptions_.CalculateSize(_repeated_vectorActionDescriptions_codec);
if (VectorActionSpaceType != 0) {
size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) VectorActionSpaceType);
size += vectorActionSizeDeprecated_.CalculateSize(_repeated_vectorActionSizeDeprecated_codec);
size += vectorActionDescriptionsDeprecated_.CalculateSize(_repeated_vectorActionDescriptionsDeprecated_codec);
if (VectorActionSpaceTypeDeprecated != 0) {
size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) VectorActionSpaceTypeDeprecated);
}
if (BrainName.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(BrainName);

}
if (actionSpec_ != null) {
size += 1 + pb::CodedOutputStream.ComputeMessageSize(ActionSpec);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();

if (other == null) {
return;
}
vectorActionSize_.Add(other.vectorActionSize_);
vectorActionDescriptions_.Add(other.vectorActionDescriptions_);
if (other.VectorActionSpaceType != 0) {
VectorActionSpaceType = other.VectorActionSpaceType;
vectorActionSizeDeprecated_.Add(other.vectorActionSizeDeprecated_);
vectorActionDescriptionsDeprecated_.Add(other.vectorActionDescriptionsDeprecated_);
if (other.VectorActionSpaceTypeDeprecated != 0) {
VectorActionSpaceTypeDeprecated = other.VectorActionSpaceTypeDeprecated;
}
if (other.BrainName.Length != 0) {
BrainName = other.BrainName;

}
if (other.actionSpec_ != null) {
if (actionSpec_ == null) {
actionSpec_ = new global::Unity.MLAgents.CommunicatorObjects.ActionSpecProto();
}
ActionSpec.MergeFrom(other.ActionSpec);
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

break;
case 26:
case 24: {
vectorActionSize_.AddEntriesFrom(input, _repeated_vectorActionSize_codec);
vectorActionSizeDeprecated_.AddEntriesFrom(input, _repeated_vectorActionSizeDeprecated_codec);
vectorActionDescriptions_.AddEntriesFrom(input, _repeated_vectorActionDescriptions_codec);
vectorActionDescriptionsDeprecated_.AddEntriesFrom(input, _repeated_vectorActionDescriptionsDeprecated_codec);
vectorActionSpaceType_ = (global::Unity.MLAgents.CommunicatorObjects.SpaceTypeProto) input.ReadEnum();
vectorActionSpaceTypeDeprecated_ = (global::Unity.MLAgents.CommunicatorObjects.SpaceTypeProto) input.ReadEnum();
break;
}
case 58: {

case 64: {
IsTraining = input.ReadBool();
break;
}
case 74: {
if (actionSpec_ == null) {
actionSpec_ = new global::Unity.MLAgents.CommunicatorObjects.ActionSpecProto();
}
input.ReadMessage(actionSpec_);
break;
}
}

44
com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs


byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjVtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2NhcGFiaWxp",
"dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMifQoYVW5pdHlSTENh",
"cGFiaWxpdGllc1Byb3RvEhoKEmJhc2VSTENhcGFiaWxpdGllcxgBIAEoCBIj",
"Chtjb25jYXRlbmF0ZWRQbmdPYnNlcnZhdGlvbnMYAiABKAgSIAoYY29tcHJl",
"c3NlZENoYW5uZWxNYXBwaW5nGAMgASgIQiWqAiJVbml0eS5NTEFnZW50cy5D",
"b21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
"dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMilAEKGFVuaXR5UkxD",
"YXBhYmlsaXRpZXNQcm90bxIaChJiYXNlUkxDYXBhYmlsaXRpZXMYASABKAgS",
"IwobY29uY2F0ZW5hdGVkUG5nT2JzZXJ2YXRpb25zGAIgASgIEiAKGGNvbXBy",
"ZXNzZWRDaGFubmVsTWFwcGluZxgDIAEoCBIVCg1oeWJyaWRBY3Rpb25zGAQg",
"ASgIQiWqAiJVbml0eS5NTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZw",
"cm90bzM="));
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations", "CompressedChannelMapping" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations", "CompressedChannelMapping", "HybridActions" }, null, null, null)
}));
}
#endregion

baseRLCapabilities_ = other.baseRLCapabilities_;
concatenatedPngObservations_ = other.concatenatedPngObservations_;
compressedChannelMapping_ = other.compressedChannelMapping_;
hybridActions_ = other.hybridActions_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

}
}
/// <summary>Field number for the "hybridActions" field.</summary>
public const int HybridActionsFieldNumber = 4;
private bool hybridActions_;
/// <summary>
/// support for hybrid action spaces (discrete + continuous)
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool HybridActions {
get { return hybridActions_; }
set {
hybridActions_ = value;
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as UnityRLCapabilitiesProto);

if (BaseRLCapabilities != other.BaseRLCapabilities) return false;
if (ConcatenatedPngObservations != other.ConcatenatedPngObservations) return false;
if (CompressedChannelMapping != other.CompressedChannelMapping) return false;
if (HybridActions != other.HybridActions) return false;
return Equals(_unknownFields, other._unknownFields);
}

if (BaseRLCapabilities != false) hash ^= BaseRLCapabilities.GetHashCode();
if (ConcatenatedPngObservations != false) hash ^= ConcatenatedPngObservations.GetHashCode();
if (CompressedChannelMapping != false) hash ^= CompressedChannelMapping.GetHashCode();
if (HybridActions != false) hash ^= HybridActions.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}

output.WriteRawTag(24);
output.WriteBool(CompressedChannelMapping);
}
if (HybridActions != false) {
output.WriteRawTag(32);
output.WriteBool(HybridActions);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}

size += 1 + 1;
}
if (CompressedChannelMapping != false) {
size += 1 + 1;
}
if (HybridActions != false) {
size += 1 + 1;
}
if (_unknownFields != null) {

if (other.CompressedChannelMapping != false) {
CompressedChannelMapping = other.CompressedChannelMapping;
}
if (other.HybridActions != false) {
HybridActions = other.HybridActions;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

}
case 24: {
CompressedChannelMapping = input.ReadBool();
break;
}
case 32: {
HybridActions = input.ReadBool();
break;
}
}

44
com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs


using System.Collections.Generic;
using System.Linq;
using Unity.MLAgents.Inference.Utils;
using Unity.MLAgents.Actuators;
using Unity.Barracuda;
using UnityEngine;

/// </summary>
internal class ContinuousActionOutputApplier : TensorApplier.IApplier
{
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, float[]> lastActions)
readonly ActionSpec m_ActionSpec;
public ContinuousActionOutputApplier(ActionSpec actionSpec)
{
m_ActionSpec = actionSpec;
}
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
{
var actionSize = tensorProxy.shape[tensorProxy.shape.Length - 1];
var agentIndex = 0;

{
var actionValue = lastActions[agentId];
if (actionValue == null)
var actionBuffer = lastActions[agentId];
if (actionBuffer.IsEmpty())
actionValue = new float[actionSize];
lastActions[agentId] = actionValue;
actionBuffer = new ActionBuffers(m_ActionSpec);
lastActions[agentId] = actionBuffer;
var continuousBuffer = actionBuffer.ContinuousActions;
actionValue[j] = tensorProxy.data[agentIndex, j];
continuousBuffer[j] = tensorProxy.data[agentIndex, j];
}
}
agentIndex++;

readonly int[] m_ActionSize;
readonly Multinomial m_Multinomial;
readonly ITensorAllocator m_Allocator;
readonly ActionSpec m_ActionSpec;
public DiscreteActionOutputApplier(int[] actionSize, int seed, ITensorAllocator allocator)
public DiscreteActionOutputApplier(ActionSpec actionSpec, int seed, ITensorAllocator allocator)
m_ActionSize = actionSize;
m_ActionSize = actionSpec.BranchSizes;
m_ActionSpec = actionSpec;
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, float[]> lastActions)
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
{
//var tensorDataProbabilities = tensorProxy.Data as float[,];
var idActionPairList = actionIds as List<int> ?? actionIds.ToList();

{
if (lastActions.ContainsKey(agentId))
{
var actionVal = lastActions[agentId];
if (actionVal == null)
var actionBuffer = lastActions