浏览代码

Merge pull request #4732 from Unity-Technologies/goal-sensors

Adds SensorTypes and GoalSensors
/goal-conditioning
GitHub 3 年前
当前提交
ded1f79b
共有 40 个文件被更改,包括 518 次插入113 次删除
  1. 6
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs
  2. 6
      Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensor.cs
  3. 6
      com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs
  4. 6
      com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs
  5. 6
      com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
  6. 1
      com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
  7. 50
      com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs
  8. 9
      com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs
  9. 29
      com.unity.ml-agents/Runtime/Sensors/ISensor.cs
  10. 6
      com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs
  11. 6
      com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs
  12. 9
      com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs
  13. 6
      com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs
  14. 6
      com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs
  15. 6
      com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs
  16. 5
      com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs
  17. 5
      com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs
  18. 5
      com.unity.ml-agents/Tests/Editor/Sensor/FloatVisualSensorTests.cs
  19. 5
      com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs
  20. 6
      com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs
  21. 4
      gym-unity/gym_unity/tests/test_gym.py
  22. 10
      ml-agents-envs/mlagents_envs/base_env.py
  23. 51
      ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.py
  24. 25
      ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.pyi
  25. 19
      ml-agents-envs/mlagents_envs/communicator_objects/unity_to_external_pb2.py
  26. 81
      ml-agents-envs/mlagents_envs/communicator_objects/unity_to_external_pb2_grpc.py
  27. 4
      ml-agents-envs/mlagents_envs/rpc_utils.py
  28. 26
      ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py
  29. 11
      ml-agents-envs/mlagents_envs/tests/test_steps.py
  30. 8
      ml-agents/mlagents/trainers/tests/mock_brain.py
  31. 8
      ml-agents/mlagents/trainers/tests/simple_test_envs.py
  32. 7
      ml-agents/mlagents/trainers/tests/tensorflow/test_models.py
  33. 4
      ml-agents/mlagents/trainers/tests/tensorflow/test_tf_policy.py
  34. 45
      ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py
  35. 14
      ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_extrinsic.py
  36. 32
      ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_gail.py
  37. 32
      ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_rnd.py
  38. 7
      protobuf-definitions/proto/mlagents_envs/communicator_objects/observation.proto
  39. 48
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/GoalSensorComponent.cs
  40. 11
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/GoalSensorComponent.cs.meta

6
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs


}
/// <inheritdoc/>
public virtual SensorType GetSensorType()
{
return SensorType.Observation;
}
/// <inheritdoc/>
public virtual SensorCompressionType GetCompressionType()
{
return SensorCompressionType.None;

6
Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensor.cs


}
/// <inheritdoc/>
public virtual SensorType GetSensorType()
{
return SensorType.Observation;
}
/// <inheritdoc/>
public byte[] GetCompressedObservation()
{
var compressed = m_Texture.EncodeToPNG();

6
com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs


}
/// <inheritdoc/>
public virtual SensorType GetSensorType()
{
return SensorType.Observation;
}
/// <inheritdoc/>
public int Write(ObservationWriter writer)
{
if (m_Board.Rows != m_Rows || m_Board.Columns != m_Columns || m_Board.NumCellTypes != m_NumCellTypes)

6
com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs


}
/// <inheritdoc/>
public virtual SensorType GetSensorType()
{
return SensorType.Observation;
}
/// <inheritdoc/>
public int Write(ObservationWriter writer)
{
using (TimerStack.Instance.Scoped("GridSensor.WriteToTensor"))

6
com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs


}
/// <inheritdoc/>
public virtual SensorType GetSensorType()
{
return SensorType.Observation;
}
/// <inheritdoc/>
public void Update()
{
if (m_Settings.UseModelSpace)

1
com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs


observationProto.CompressedChannelMapping.AddRange(compressibleSensor.GetCompressedChannelMapping());
}
}
observationProto.SensorType = (SensorTypeProto)sensor.GetSensorType();
observationProto.Shape.AddRange(shape);
return observationProto;
}

50
com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs


byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjRtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0",
"aW9uLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cyKdAgoQT2JzZXJ2YXRp",
"aW9uLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cyLZAgoQT2JzZXJ2YXRp",
"KAUaGQoJRmxvYXREYXRhEgwKBGRhdGEYASADKAJCEgoQb2JzZXJ2YXRpb25f",
"ZGF0YSopChRDb21wcmVzc2lvblR5cGVQcm90bxIICgROT05FEAASBwoDUE5H",
"EAFCJaoCIlVuaXR5Lk1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnBy",
"b3RvMw=="));
"KAUSOgoLc2Vuc29yX3R5cGUYBiABKA4yJS5jb21tdW5pY2F0b3Jfb2JqZWN0",
"cy5TZW5zb3JUeXBlUHJvdG8aGQoJRmxvYXREYXRhEgwKBGRhdGEYASADKAJC",
"EgoQb2JzZXJ2YXRpb25fZGF0YSopChRDb21wcmVzc2lvblR5cGVQcm90bxII",
"CgROT05FEAASBwoDUE5HEAEqOAoPU2Vuc29yVHlwZVByb3RvEg8KC09CU0VS",
"VkFUSU9OEAASCAoER09BTBABEgoKBlJFV0FSRBACQiWqAiJVbml0eS5NTEFn",
"ZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Unity.MLAgents.CommunicatorObjects.CompressionTypeProto), }, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Parser, new[]{ "Shape", "CompressionType", "CompressedData", "FloatData", "CompressedChannelMapping" }, new[]{ "ObservationData" }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData.Parser, new[]{ "Data" }, null, null, null)})
new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Unity.MLAgents.CommunicatorObjects.CompressionTypeProto), typeof(global::Unity.MLAgents.CommunicatorObjects.SensorTypeProto), }, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Parser, new[]{ "Shape", "CompressionType", "CompressedData", "FloatData", "CompressedChannelMapping", "SensorType" }, new[]{ "ObservationData" }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData.Parser, new[]{ "Data" }, null, null, null)})
}));
}
#endregion

internal enum CompressionTypeProto {
[pbr::OriginalName("NONE")] None = 0,
[pbr::OriginalName("PNG")] Png = 1,
}
internal enum SensorTypeProto {
[pbr::OriginalName("OBSERVATION")] Observation = 0,
[pbr::OriginalName("GOAL")] Goal = 1,
[pbr::OriginalName("REWARD")] Reward = 2,
}
#endregion

shape_ = other.shape_.Clone();
compressionType_ = other.compressionType_;
compressedChannelMapping_ = other.compressedChannelMapping_.Clone();
sensorType_ = other.sensorType_;
switch (other.ObservationDataCase) {
case ObservationDataOneofCase.CompressedData:
CompressedData = other.CompressedData;

get { return compressedChannelMapping_; }
}
/// <summary>Field number for the "sensor_type" field.</summary>
public const int SensorTypeFieldNumber = 6;
private global::Unity.MLAgents.CommunicatorObjects.SensorTypeProto sensorType_ = 0;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public global::Unity.MLAgents.CommunicatorObjects.SensorTypeProto SensorType {
get { return sensorType_; }
set {
sensorType_ = value;
}
}
private object observationData_;
/// <summary>Enum of possible cases for the "observation_data" oneof.</summary>
public enum ObservationDataOneofCase {

if (CompressedData != other.CompressedData) return false;
if (!object.Equals(FloatData, other.FloatData)) return false;
if(!compressedChannelMapping_.Equals(other.compressedChannelMapping_)) return false;
if (SensorType != other.SensorType) return false;
if (ObservationDataCase != other.ObservationDataCase) return false;
return Equals(_unknownFields, other._unknownFields);
}

if (observationDataCase_ == ObservationDataOneofCase.CompressedData) hash ^= CompressedData.GetHashCode();
if (observationDataCase_ == ObservationDataOneofCase.FloatData) hash ^= FloatData.GetHashCode();
hash ^= compressedChannelMapping_.GetHashCode();
if (SensorType != 0) hash ^= SensorType.GetHashCode();
hash ^= (int) observationDataCase_;
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();

output.WriteMessage(FloatData);
}
compressedChannelMapping_.WriteTo(output, _repeated_compressedChannelMapping_codec);
if (SensorType != 0) {
output.WriteRawTag(48);
output.WriteEnum((int) SensorType);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}

size += 1 + pb::CodedOutputStream.ComputeMessageSize(FloatData);
}
size += compressedChannelMapping_.CalculateSize(_repeated_compressedChannelMapping_codec);
if (SensorType != 0) {
size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) SensorType);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}

CompressionType = other.CompressionType;
}
compressedChannelMapping_.Add(other.compressedChannelMapping_);
if (other.SensorType != 0) {
SensorType = other.SensorType;
}
switch (other.ObservationDataCase) {
case ObservationDataOneofCase.CompressedData:
CompressedData = other.CompressedData;

case 42:
case 40: {
compressedChannelMapping_.AddEntriesFrom(input, _repeated_compressedChannelMapping_codec);
break;
}
case 48: {
sensorType_ = (global::Unity.MLAgents.CommunicatorObjects.SensorTypeProto) input.ReadEnum();
break;
}
}

9
com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs


}
/// <summary>
/// Camera sensors are always Observations.
/// </summary>
/// <returns>Sensor type of observation.</returns>
public SensorType GetSensorType()
{
return SensorType.Observation;
}
/// <summary>
/// Generates a compressed image. This can be valuable in speeding-up training.
/// </summary>
/// <returns>Compressed image.</returns>

29
com.unity.ml-agents/Runtime/Sensors/ISensor.cs


}
/// <summary>
/// The semantic meaning of the sensor.
/// </summary>
public enum SensorType
{
/// <summary>
/// Sensor represents an agent's observation.
/// </summary>
Observation,
/// <summary>
/// Sensor represents an agent's task/goal parameterization.
/// </summary>
Goal,
/// <summary>
/// Sensor represents one or more reward signals.
/// </summary>
Reward
}
/// <summary>
/// Sensor interface for generating observations.
/// </summary>
public interface ISensor

/// </summary>
/// <returns>The name of the sensor.</returns>
string GetName();
/// <summary>
/// Get the semantic meaning of the sensor, i.e. whether it is an observation or other type
/// of data to be sent to the Agent.
/// </summary>
/// <returns>The type of the sensor.</returns>
SensorType GetSensorType();
}

6
com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs


}
/// <inheritdoc/>
public SensorType GetSensorType()
{
return SensorType.Observation;
}
/// <inheritdoc/>
public string GetName()
{
return m_Name;

6
com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs


}
/// <inheritdoc/>
public SensorType GetSensorType()
{
return SensorType.Observation;
}
/// <inheritdoc/>
public void Update() { }
/// <inheritdoc/>

9
com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs


}
/// <summary>
/// RenderTexture sensors are always Observations.
/// </summary>
/// <returns>Sensor type of observation.</returns>
public SensorType GetSensorType()
{
return SensorType.Observation;
}
/// <summary>
/// Converts a RenderTexture to a 2D texture.
/// </summary>
/// <returns>The 2D texture.</returns>

6
com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs


return m_WrappedSensor.GetCompressionType();
}
/// <inheritdoc/>
public SensorType GetSensorType()
{
return SensorType.Observation;
}
/// <summary>
/// Create Empty PNG for initializing the buffer for stacking.
/// </summary>

6
com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs


}
/// <inheritdoc/>
public virtual SensorType GetSensorType()
{
return SensorType.Observation;
}
/// <inheritdoc/>
public virtual byte[] GetCompressedObservation()
{
return null;

6
com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs


return new byte[] { 13, 37 };
}
/// <inheritdoc/>
public virtual SensorType GetSensorType()
{
return SensorType.Observation;
}
public void Update() { }
public void Reset() { }

5
com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs


return new byte[] { 0 };
}
public SensorType GetSensorType()
{
return SensorType.Observation;
}
public SensorCompressionType GetCompressionType()
{
return compressionType;

5
com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs


return new[] { m_Height, m_Width, m_Channels };
}
public SensorType GetSensorType()
{
return SensorType.Observation;
}
public int Write(ObservationWriter writer)
{
for (int i = 0; i < m_Width * m_Height * m_Channels; i++)

5
com.unity.ml-agents/Tests/Editor/Sensor/FloatVisualSensorTests.cs


m_Shape = new[] { Height, Width, 1 };
}
public SensorType GetSensorType()
{
return SensorType.Observation;
}
public string GetName()
{
return m_Name;

5
com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs


return m_Shape;
}
public SensorType GetSensorType()
{
return SensorType.Observation;
}
public byte[] GetCompressedObservation()
{
return null;

6
com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs


return Shape;
}
/// <inheritdoc/>
public virtual SensorType GetSensorType()
{
return SensorType.Observation;
}
public int Write(ObservationWriter writer)
{
for (var h = 0; h < Shape[0]; h++)

4
gym-unity/gym_unity/tests/test_gym.py


from mlagents_envs.base_env import (
BehaviorSpec,
ActionSpec,
SensorType,
DecisionSteps,
TerminalSteps,
BehaviorMapping,

obs_shapes = [(vector_observation_space_size,)]
for _ in range(number_visual_observations):
obs_shapes += [(8, 8, 3)]
return BehaviorSpec(obs_shapes, action_spec)
sensor_types = [SensorType.OBSERVATION for _ in range(len(obs_shapes))]
return BehaviorSpec(obs_shapes, sensor_types, action_spec)
def create_mock_vector_steps(specs, num_agents=1, number_visual_observations=0):

10
ml-agents-envs/mlagents_envs/base_env.py


from abc import ABC, abstractmethod
from collections.abc import Mapping
from enum import Enum
from typing import (
List,
NamedTuple,

)
class SensorType(Enum):
OBSERVATION = 0
PARAMETERIZATION = 1
REWARD = 2
class ActionSpec(NamedTuple):
"""
A NamedTuple containing utility functions and information about the action spaces

- observation_shapes is a List of Tuples of int : Each Tuple corresponds
to an observation's dimensions. The shape tuples have the same ordering as
the ordering of the DecisionSteps and TerminalSteps.
- sensor_types is a List of SensorTypes, each corresponding to the type of
sensor (i.e. observation, goal, etc).
sensor_types: List[SensorType]
action_spec: ActionSpec

51
ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.py


name='mlagents_envs/communicator_objects/observation.proto',
package='communicator_objects',
syntax='proto3',
serialized_pb=_b('\n4mlagents_envs/communicator_objects/observation.proto\x12\x14\x63ommunicator_objects\"\x9d\x02\n\x10ObservationProto\x12\r\n\x05shape\x18\x01 \x03(\x05\x12\x44\n\x10\x63ompression_type\x18\x02 \x01(\x0e\x32*.communicator_objects.CompressionTypeProto\x12\x19\n\x0f\x63ompressed_data\x18\x03 \x01(\x0cH\x00\x12\x46\n\nfloat_data\x18\x04 \x01(\x0b\x32\x30.communicator_objects.ObservationProto.FloatDataH\x00\x12\"\n\x1a\x63ompressed_channel_mapping\x18\x05 \x03(\x05\x1a\x19\n\tFloatData\x12\x0c\n\x04\x64\x61ta\x18\x01 \x03(\x02\x42\x12\n\x10observation_data*)\n\x14\x43ompressionTypeProto\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03PNG\x10\x01\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
serialized_pb=_b('\n4mlagents_envs/communicator_objects/observation.proto\x12\x14\x63ommunicator_objects\"\xd9\x02\n\x10ObservationProto\x12\r\n\x05shape\x18\x01 \x03(\x05\x12\x44\n\x10\x63ompression_type\x18\x02 \x01(\x0e\x32*.communicator_objects.CompressionTypeProto\x12\x19\n\x0f\x63ompressed_data\x18\x03 \x01(\x0cH\x00\x12\x46\n\nfloat_data\x18\x04 \x01(\x0b\x32\x30.communicator_objects.ObservationProto.FloatDataH\x00\x12\"\n\x1a\x63ompressed_channel_mapping\x18\x05 \x03(\x05\x12:\n\x0bsensor_type\x18\x06 \x01(\x0e\x32%.communicator_objects.SensorTypeProto\x1a\x19\n\tFloatData\x12\x0c\n\x04\x64\x61ta\x18\x01 \x03(\x02\x42\x12\n\x10observation_data*)\n\x14\x43ompressionTypeProto\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03PNG\x10\x01*8\n\x0fSensorTypeProto\x12\x0f\n\x0bOBSERVATION\x10\x00\x12\x08\n\x04GOAL\x10\x01\x12\n\n\x06REWARD\x10\x02\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
)
_COMPRESSIONTYPEPROTO = _descriptor.EnumDescriptor(

],
containing_type=None,
options=None,
serialized_start=366,
serialized_end=407,
serialized_start=426,
serialized_end=467,
_SENSORTYPEPROTO = _descriptor.EnumDescriptor(
name='SensorTypeProto',
full_name='communicator_objects.SensorTypeProto',
filename=None,
file=DESCRIPTOR,
values=[
_descriptor.EnumValueDescriptor(
name='OBSERVATION', index=0, number=0,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='GOAL', index=1, number=1,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='REWARD', index=2, number=2,
options=None,
type=None),
],
containing_type=None,
options=None,
serialized_start=469,
serialized_end=525,
)
_sym_db.RegisterEnumDescriptor(_SENSORTYPEPROTO)
SensorTypeProto = enum_type_wrapper.EnumTypeWrapper(_SENSORTYPEPROTO)
OBSERVATION = 0
GOAL = 1
REWARD = 2

extension_ranges=[],
oneofs=[
],
serialized_start=319,
serialized_end=344,
serialized_start=379,
serialized_end=404,
)
_OBSERVATIONPROTO = _descriptor.Descriptor(

message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='sensor_type', full_name='communicator_objects.ObservationProto.sensor_type', index=5,
number=6, type=14, cpp_type=8, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
],
extensions=[
],

index=0, containing_type=None, fields=[]),
],
serialized_start=79,
serialized_end=364,
serialized_end=424,
_OBSERVATIONPROTO.fields_by_name['sensor_type'].enum_type = _SENSORTYPEPROTO
_OBSERVATIONPROTO.oneofs_by_name['observation_data'].fields.append(
_OBSERVATIONPROTO.fields_by_name['compressed_data'])
_OBSERVATIONPROTO.fields_by_name['compressed_data'].containing_oneof = _OBSERVATIONPROTO.oneofs_by_name['observation_data']

DESCRIPTOR.message_types_by_name['ObservationProto'] = _OBSERVATIONPROTO
DESCRIPTOR.enum_types_by_name['CompressionTypeProto'] = _COMPRESSIONTYPEPROTO
DESCRIPTOR.enum_types_by_name['SensorTypeProto'] = _SENSORTYPEPROTO
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
ObservationProto = _reflection.GeneratedProtocolMessageType('ObservationProto', (_message.Message,), dict(

25
ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.pyi


NONE = typing___cast('CompressionTypeProto', 0)
PNG = typing___cast('CompressionTypeProto', 1)
class SensorTypeProto(builtin___int):
DESCRIPTOR: google___protobuf___descriptor___EnumDescriptor = ...
@classmethod
def Name(cls, number: builtin___int) -> builtin___str: ...
@classmethod
def Value(cls, name: builtin___str) -> 'SensorTypeProto': ...
@classmethod
def keys(cls) -> typing___List[builtin___str]: ...
@classmethod
def values(cls) -> typing___List['SensorTypeProto']: ...
@classmethod
def items(cls) -> typing___List[typing___Tuple[builtin___str, 'SensorTypeProto']]: ...
OBSERVATION = typing___cast('SensorTypeProto', 0)
GOAL = typing___cast('SensorTypeProto', 1)
REWARD = typing___cast('SensorTypeProto', 2)
OBSERVATION = typing___cast('SensorTypeProto', 0)
GOAL = typing___cast('SensorTypeProto', 1)
REWARD = typing___cast('SensorTypeProto', 2)
class ObservationProto(google___protobuf___message___Message):
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
class FloatData(google___protobuf___message___Message):

compression_type = ... # type: CompressionTypeProto
compressed_data = ... # type: builtin___bytes
compressed_channel_mapping = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___int]
sensor_type = ... # type: SensorTypeProto
@property
def float_data(self) -> ObservationProto.FloatData: ...

compressed_data : typing___Optional[builtin___bytes] = None,
float_data : typing___Optional[ObservationProto.FloatData] = None,
compressed_channel_mapping : typing___Optional[typing___Iterable[builtin___int]] = None,
sensor_type : typing___Optional[SensorTypeProto] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> ObservationProto: ...

def HasField(self, field_name: typing_extensions___Literal[u"compressed_data",u"float_data",u"observation_data"]) -> builtin___bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"compressed_channel_mapping",u"compressed_data",u"compression_type",u"float_data",u"observation_data",u"shape"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"compressed_channel_mapping",u"compressed_data",u"compression_type",u"float_data",u"observation_data",u"sensor_type",u"shape"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"compressed_channel_mapping",b"compressed_channel_mapping",u"compressed_data",b"compressed_data",u"compression_type",b"compression_type",u"float_data",b"float_data",u"observation_data",b"observation_data",u"shape",b"shape"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"compressed_channel_mapping",b"compressed_channel_mapping",u"compressed_data",b"compressed_data",u"compression_type",b"compression_type",u"float_data",b"float_data",u"observation_data",b"observation_data",u"sensor_type",b"sensor_type",u"shape",b"shape"]) -> None: ...
def WhichOneof(self, oneof_group: typing_extensions___Literal[u"observation_data",b"observation_data"]) -> typing_extensions___Literal["compressed_data","float_data"]: ...

19
ml-agents-envs/mlagents_envs/communicator_objects/unity_to_external_pb2.py


# -*- coding: utf-8 -*-
import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
"""Generated protocol buffer code."""
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()

name='mlagents_envs/communicator_objects/unity_to_external.proto',
package='communicator_objects',
syntax='proto3',
serialized_pb=_b('\n:mlagents_envs/communicator_objects/unity_to_external.proto\x12\x14\x63ommunicator_objects\x1a\x36mlagents_envs/communicator_objects/unity_message.proto2v\n\x14UnityToExternalProto\x12^\n\x08\x45xchange\x12\'.communicator_objects.UnityMessageProto\x1a\'.communicator_objects.UnityMessageProto\"\x00\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
serialized_options=b'\252\002\"Unity.MLAgents.CommunicatorObjects',
create_key=_descriptor._internal_create_key,
serialized_pb=b'\n:mlagents_envs/communicator_objects/unity_to_external.proto\x12\x14\x63ommunicator_objects\x1a\x36mlagents_envs/communicator_objects/unity_message.proto2v\n\x14UnityToExternalProto\x12^\n\x08\x45xchange\x12\'.communicator_objects.UnityMessageProto\x1a\'.communicator_objects.UnityMessageProto\"\x00\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3'
,
dependencies=[mlagents__envs_dot_communicator__objects_dot_unity__message__pb2.DESCRIPTOR,])

DESCRIPTOR.has_options = True
DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\"Unity.MLAgents.CommunicatorObjects'))
DESCRIPTOR._options = None
_UNITYTOEXTERNALPROTO = _descriptor.ServiceDescriptor(
name='UnityToExternalProto',

options=None,
serialized_options=None,
create_key=_descriptor._internal_create_key,
serialized_start=140,
serialized_end=258,
methods=[

containing_service=None,
input_type=mlagents__envs_dot_communicator__objects_dot_unity__message__pb2._UNITYMESSAGEPROTO,
output_type=mlagents__envs_dot_communicator__objects_dot_unity__message__pb2._UNITYMESSAGEPROTO,
options=None,
serialized_options=None,
create_key=_descriptor._internal_create_key,
),
])
_sym_db.RegisterServiceDescriptor(_UNITYTOEXTERNALPROTO)

81
ml-agents-envs/mlagents_envs/communicator_objects/unity_to_external_pb2_grpc.py


# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
from mlagents_envs.communicator_objects import unity_message_pb2 as mlagents__envs_dot_communicator__objects_dot_unity__message__pb2

# missing associated documentation comment in .proto file
pass
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.Exchange = channel.unary_unary(
'/communicator_objects.UnityToExternalProto/Exchange',
request_serializer=mlagents__envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessageProto.SerializeToString,
response_deserializer=mlagents__envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessageProto.FromString,
)
Args:
channel: A grpc.Channel.
"""
self.Exchange = channel.unary_unary(
'/communicator_objects.UnityToExternalProto/Exchange',
request_serializer=mlagents__envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessageProto.SerializeToString,
response_deserializer=mlagents__envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessageProto.FromString,
)
# missing associated documentation comment in .proto file
pass
"""Missing associated documentation comment in .proto file."""
def Exchange(self, request, context):
"""Sends the academy parameters
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Exchange(self, request, context):
"""Sends the academy parameters
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
rpc_method_handlers = {
'Exchange': grpc.unary_unary_rpc_method_handler(
servicer.Exchange,
request_deserializer=mlagents__envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessageProto.FromString,
response_serializer=mlagents__envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessageProto.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'communicator_objects.UnityToExternalProto', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
rpc_method_handlers = {
'Exchange': grpc.unary_unary_rpc_method_handler(
servicer.Exchange,
request_deserializer=mlagents__envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessageProto.FromString,
response_serializer=mlagents__envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessageProto.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'communicator_objects.UnityToExternalProto', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API.
class UnityToExternalProto(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def Exchange(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/communicator_objects.UnityToExternalProto/Exchange',
mlagents__envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessageProto.SerializeToString,
mlagents__envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessageProto.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

4
ml-agents-envs/mlagents_envs/rpc_utils.py


from mlagents_envs.base_env import (
BehaviorSpec,
ActionSpec,
SensorType,
DecisionSteps,
TerminalSteps,
)

:return: BehaviorSpec object.
"""
observation_shape = [tuple(obs.shape) for obs in agent_info.observations]
sensor_type = [SensorType(obs.sensor_type) for obs in agent_info.observations]
return BehaviorSpec(observation_shape, action_spec)
return BehaviorSpec(observation_shape, sensor_type, action_spec)
class OffsetBytesIO:

26
ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py


from mlagents_envs.base_env import (
BehaviorSpec,
ActionSpec,
SensorType,
DecisionSteps,
TerminalSteps,
)

def test_batched_step_result_from_proto():
n_agents = 10
shapes = [(3,), (4,)]
spec = BehaviorSpec(shapes, ActionSpec.create_continuous(3))
sensor_type = [SensorType.OBSERVATION, SensorType.OBSERVATION]
spec = BehaviorSpec(shapes, sensor_type, ActionSpec.create_continuous(3))
ap_list = generate_list_agent_proto(n_agents, shapes)
decision_steps, terminal_steps = steps_from_proto(ap_list, spec)
for agent_id in range(n_agents):

def test_action_masking_discrete():
n_agents = 10
shapes = [(3,), (4,)]
behavior_spec = BehaviorSpec(shapes, ActionSpec.create_discrete((7, 3)))
sensor_type = [SensorType.OBSERVATION, SensorType.OBSERVATION]
behavior_spec = BehaviorSpec(
shapes, sensor_type, ActionSpec.create_discrete((7, 3))
)
ap_list = generate_list_agent_proto(n_agents, shapes)
decision_steps, terminal_steps = steps_from_proto(ap_list, behavior_spec)
masks = decision_steps.action_mask

def test_action_masking_discrete_1():
n_agents = 10
shapes = [(3,), (4,)]
behavior_spec = BehaviorSpec(shapes, ActionSpec.create_discrete((10,)))
sensor_type = [SensorType.OBSERVATION, SensorType.OBSERVATION]
behavior_spec = BehaviorSpec(shapes, sensor_type, ActionSpec.create_discrete((10,)))
ap_list = generate_list_agent_proto(n_agents, shapes)
decision_steps, terminal_steps = steps_from_proto(ap_list, behavior_spec)
masks = decision_steps.action_mask

def test_action_masking_discrete_2():
n_agents = 10
shapes = [(3,), (4,)]
behavior_spec = BehaviorSpec(shapes, ActionSpec.create_discrete((2, 2, 6)))
sensor_type = [SensorType.OBSERVATION, SensorType.OBSERVATION]
behavior_spec = BehaviorSpec(
shapes, sensor_type, ActionSpec.create_discrete((2, 2, 6))
)
ap_list = generate_list_agent_proto(n_agents, shapes)
decision_steps, terminal_steps = steps_from_proto(ap_list, behavior_spec)
masks = decision_steps.action_mask

def test_action_masking_continuous():
n_agents = 10
shapes = [(3,), (4,)]
behavior_spec = BehaviorSpec(shapes, ActionSpec.create_continuous(10))
sensor_type = [SensorType.OBSERVATION, SensorType.OBSERVATION]
behavior_spec = BehaviorSpec(shapes, sensor_type, ActionSpec.create_continuous(10))
ap_list = generate_list_agent_proto(n_agents, shapes)
decision_steps, terminal_steps = steps_from_proto(ap_list, behavior_spec)
masks = decision_steps.action_mask

def test_batched_step_result_from_proto_raises_on_infinite():
n_agents = 10
shapes = [(3,), (4,)]
behavior_spec = BehaviorSpec(shapes, ActionSpec.create_continuous(3))
sensor_type = [SensorType.OBSERVATION, SensorType.OBSERVATION]
behavior_spec = BehaviorSpec(shapes, sensor_type, ActionSpec.create_continuous(3))
ap_list = generate_list_agent_proto(n_agents, shapes, infinite_rewards=True)
with pytest.raises(RuntimeError):
steps_from_proto(ap_list, behavior_spec)

n_agents = 10
shapes = [(3,), (4,)]
behavior_spec = BehaviorSpec(shapes, ActionSpec.create_continuous(3))
sensor_type = [SensorType.OBSERVATION, SensorType.OBSERVATION]
behavior_spec = BehaviorSpec(shapes, sensor_type, ActionSpec.create_continuous(3))
ap_list = generate_list_agent_proto(n_agents, shapes, nan_observations=True)
with pytest.raises(RuntimeError):
steps_from_proto(ap_list, behavior_spec)

11
ml-agents-envs/mlagents_envs/tests/test_steps.py


TerminalSteps,
ActionSpec,
BehaviorSpec,
SensorType,
)

def test_empty_decision_steps():
sensor_type = [SensorType.OBSERVATION, SensorType.OBSERVATION]
observation_shapes=[(3, 2), (5,)], action_spec=ActionSpec.create_continuous(3)
observation_shapes=[(3, 2), (5,)],
sensor_types=sensor_type,
action_spec=ActionSpec.create_continuous(3),
)
ds = DecisionSteps.empty(specs)
assert len(ds.obs) == 2

def test_empty_terminal_steps():
sensor_type = [SensorType.OBSERVATION, SensorType.OBSERVATION]
observation_shapes=[(3, 2), (5,)], action_spec=ActionSpec.create_continuous(3)
observation_shapes=[(3, 2), (5,)],
sensor_types=sensor_type,
action_spec=ActionSpec.create_continuous(3),
)
ts = TerminalSteps.empty(specs)
assert len(ts.obs) == 2

8
ml-agents/mlagents/trainers/tests/mock_brain.py


TerminalSteps,
BehaviorSpec,
ActionSpec,
SensorType,
)

obs_list = []
for _shape in observation_shapes:
obs_list.append(np.ones((num_agents,) + _shape, dtype=np.float32))
sensor_types = [SensorType.OBSERVATION for i in range(len(obs_list))]
action_mask = None
if action_spec.is_discrete():
action_mask = [

reward = np.array(num_agents * [1.0], dtype=np.float32)
interrupted = np.array(num_agents * [False], dtype=np.bool)
agent_id = np.arange(num_agents, dtype=np.int32)
behavior_spec = BehaviorSpec(observation_shapes, action_spec)
behavior_spec = BehaviorSpec(observation_shapes, sensor_types, action_spec)
if done:
return (
DecisionSteps.empty(behavior_spec),

else:
action_spec = ActionSpec.create_continuous(vector_action_space)
behavior_spec = BehaviorSpec(
[(84, 84, 3)] * int(use_visual) + [(vector_obs_space,)], action_spec
[(84, 84, 3)] * int(use_visual) + [(vector_obs_space,)],
[SensorType.OBSERVATION],
action_spec,
)
return behavior_spec

8
ml-agents/mlagents/trainers/tests/simple_test_envs.py


ActionSpec,
BaseEnv,
BehaviorSpec,
SensorType,
DecisionSteps,
TerminalSteps,
BehaviorMapping,

)
else:
action_spec = ActionSpec.create_continuous(action_size)
self.behavior_spec = BehaviorSpec(self._make_obs_spec(), action_spec)
sensor_type_list = [
SensorType.OBSERVATION for i in range(len(self._make_obs_spec()))
]
self.behavior_spec = BehaviorSpec(
self._make_obs_spec(), sensor_type_list, action_spec
)
self.action_size = action_size
self.names = brain_names
self.positions: Dict[str, List[float]] = {}

7
ml-agents/mlagents/trainers/tests/tensorflow/test_models.py


from mlagents.trainers.tf.models import ModelUtils
from mlagents.tf_utils import tf
from mlagents_envs.base_env import BehaviorSpec, ActionSpec
from mlagents_envs.base_env import BehaviorSpec, ActionSpec, SensorType
obs_shapes = [(84, 84, 3)] * int(num_visual) + [(vector_size,)] * int(num_vector)
sensor_types = [SensorType.OBSERVATION for _ in range(len(obs_shapes))]
[(84, 84, 3)] * int(num_visual) + [(vector_size,)] * int(num_vector),
ActionSpec.create_discrete((1,)),
obs_shapes, sensor_types, ActionSpec.create_discrete((1,))
)
return behavior_spec

4
ml-agents/mlagents/trainers/tests/tensorflow/test_tf_policy.py


from unittest.mock import MagicMock
from mlagents.trainers.settings import TrainerSettings
import numpy as np
from mlagents_envs.base_env import ActionSpec
from mlagents_envs.base_env import ActionSpec, SensorType
dummy_groupspec = BehaviorSpec([(1,)], dummy_actionspec)
dummy_groupspec = BehaviorSpec([(1,)], [SensorType.OBSERVATION], dummy_actionspec)
return dummy_groupspec

45
ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py


CuriosityRewardProvider,
create_reward_provider,
)
from mlagents_envs.base_env import BehaviorSpec, ActionSpec
from mlagents_envs.base_env import BehaviorSpec, ActionSpec, SensorType
from mlagents.trainers.settings import CuriositySettings, RewardSignalType
from mlagents.trainers.tests.torch.test_reward_providers.utils import (
create_agent_buffer,

@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,)], ACTIONSPEC_TWODISCRETE),
BehaviorSpec([(10,)], [SensorType.OBSERVATION], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,)], [SensorType.OBSERVATION], ACTIONSPEC_TWODISCRETE),
],
)
def test_construction(behavior_spec: BehaviorSpec) -> None:

@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,), (64, 66, 3), (84, 86, 1)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,), (64, 66, 1)], ACTIONSPEC_TWODISCRETE),
BehaviorSpec([(10,)], ACTIONSPEC_DISCRETE),
BehaviorSpec([(10,)], [SensorType.OBSERVATION], ACTIONSPEC_CONTINUOUS),
BehaviorSpec(
[(10,), (64, 66, 3), (84, 86, 1)],
[SensorType.OBSERVATION, SensorType.OBSERVATION, SensorType.OBSERVATION],
ACTIONSPEC_CONTINUOUS,
),
BehaviorSpec(
[(10,), (64, 66, 1)],
[SensorType.OBSERVATION, SensorType.OBSERVATION],
ACTIONSPEC_TWODISCRETE,
),
BehaviorSpec([(10,)], [SensorType.OBSERVATION], ACTIONSPEC_DISCRETE),
],
)
def test_factory(behavior_spec: BehaviorSpec) -> None:

@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,), (64, 66, 3), (24, 26, 1)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,)], ACTIONSPEC_TWODISCRETE),
BehaviorSpec([(10,)], ACTIONSPEC_DISCRETE),
BehaviorSpec(
[(10,), (64, 66, 3), (24, 26, 1)],
[SensorType.OBSERVATION, SensorType.OBSERVATION, SensorType.OBSERVATION],
ACTIONSPEC_CONTINUOUS,
),
BehaviorSpec([(10,)], [SensorType.OBSERVATION], ACTIONSPEC_TWODISCRETE),
BehaviorSpec([(10,)], [SensorType.OBSERVATION], ACTIONSPEC_DISCRETE),
],
)
def test_reward_decreases(behavior_spec: BehaviorSpec, seed: int) -> None:

@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize(
"behavior_spec", [BehaviorSpec([(10,)], ACTIONSPEC_CONTINUOUS)]
"behavior_spec",
[BehaviorSpec([(10,)], [SensorType.OBSERVATION], ACTIONSPEC_CONTINUOUS)],
)
def test_continuous_action_prediction(behavior_spec: BehaviorSpec, seed: int) -> None:
np.random.seed(seed)

@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,), (64, 66, 3), (24, 26, 1)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,)], ACTIONSPEC_TWODISCRETE),
BehaviorSpec([(10,)], ACTIONSPEC_DISCRETE),
BehaviorSpec(
[(10,), (64, 66, 3), (24, 26, 1)],
[SensorType.OBSERVATION, SensorType.OBSERVATION, SensorType.OBSERVATION],
ACTIONSPEC_CONTINUOUS,
),
BehaviorSpec([(10,)], [SensorType.OBSERVATION], ACTIONSPEC_TWODISCRETE),
BehaviorSpec([(10,)], [SensorType.OBSERVATION], ACTIONSPEC_DISCRETE),
],
)
def test_next_state_prediction(behavior_spec: BehaviorSpec, seed: int) -> None:

14
ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_extrinsic.py


ExtrinsicRewardProvider,
create_reward_provider,
)
from mlagents_envs.base_env import BehaviorSpec, ActionSpec
from mlagents_envs.base_env import BehaviorSpec, ActionSpec, SensorType
from mlagents.trainers.settings import RewardSignalSettings, RewardSignalType
from mlagents.trainers.tests.torch.test_reward_providers.utils import (
create_agent_buffer,

@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,)], ACTIONSPEC_TWODISCRETE),
BehaviorSpec([(10,)], [SensorType.OBSERVATION], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,)], [SensorType.OBSERVATION], ACTIONSPEC_TWODISCRETE),
],
)
def test_construction(behavior_spec: BehaviorSpec) -> None:

@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,)], ACTIONSPEC_TWODISCRETE),
BehaviorSpec([(10,)], [SensorType.OBSERVATION], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,)], [SensorType.OBSERVATION], ACTIONSPEC_TWODISCRETE),
],
)
def test_factory(behavior_spec: BehaviorSpec) -> None:

@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,)], ACTIONSPEC_TWODISCRETE),
BehaviorSpec([(10,)], [SensorType.OBSERVATION], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,)], [SensorType.OBSERVATION], ACTIONSPEC_TWODISCRETE),
],
)
def test_reward(behavior_spec: BehaviorSpec, reward: float) -> None:

32
ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_gail.py


GAILRewardProvider,
create_reward_provider,
)
from mlagents_envs.base_env import BehaviorSpec, ActionSpec
from mlagents_envs.base_env import BehaviorSpec, ActionSpec, SensorType
from mlagents.trainers.settings import GAILSettings, RewardSignalType
from mlagents.trainers.tests.torch.test_reward_providers.utils import (
create_agent_buffer,

ACTIONSPEC_DISCRETE = ActionSpec.create_discrete((20,))
@pytest.mark.parametrize("behavior_spec", [BehaviorSpec([(8,)], ACTIONSPEC_CONTINUOUS)])
@pytest.mark.parametrize(
"behavior_spec",
[BehaviorSpec([(8,)], [SensorType.OBSERVATION], ACTIONSPEC_CONTINUOUS)],
)
def test_construction(behavior_spec: BehaviorSpec) -> None:
gail_settings = GAILSettings(demo_path=CONTINUOUS_PATH)
gail_rp = GAILRewardProvider(behavior_spec, gail_settings)

@pytest.mark.parametrize("behavior_spec", [BehaviorSpec([(8,)], ACTIONSPEC_CONTINUOUS)])
@pytest.mark.parametrize(
"behavior_spec",
[BehaviorSpec([(8,)], [SensorType.OBSERVATION], ACTIONSPEC_CONTINUOUS)],
)
def test_factory(behavior_spec: BehaviorSpec) -> None:
gail_settings = GAILSettings(demo_path=CONTINUOUS_PATH)
gail_rp = create_reward_provider(

@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(8,), (24, 26, 1)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(50,)], ACTIONSPEC_FOURDISCRETE),
BehaviorSpec([(10,)], ACTIONSPEC_DISCRETE),
BehaviorSpec(
[(8,), (24, 26, 1)],
[SensorType.OBSERVATION, SensorType.OBSERVATION],
ACTIONSPEC_CONTINUOUS,
),
BehaviorSpec([(50,)], [SensorType.OBSERVATION], ACTIONSPEC_FOURDISCRETE),
BehaviorSpec([(10,)], [SensorType.OBSERVATION], ACTIONSPEC_DISCRETE),
],
)
@pytest.mark.parametrize("use_actions", [False, True])

@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(8,), (24, 26, 1)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(50,)], ACTIONSPEC_FOURDISCRETE),
BehaviorSpec([(10,)], ACTIONSPEC_DISCRETE),
BehaviorSpec(
[(8,), (24, 26, 1)],
[SensorType.OBSERVATION, SensorType.OBSERVATION],
ACTIONSPEC_CONTINUOUS,
),
BehaviorSpec([(50,)], [SensorType.OBSERVATION], ACTIONSPEC_FOURDISCRETE),
BehaviorSpec([(10,)], [SensorType.OBSERVATION], ACTIONSPEC_DISCRETE),
],
)
@pytest.mark.parametrize("use_actions", [False, True])

32
ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_rnd.py


RNDRewardProvider,
create_reward_provider,
)
from mlagents_envs.base_env import BehaviorSpec, ActionSpec
from mlagents_envs.base_env import BehaviorSpec, ActionSpec, SensorType
from mlagents.trainers.settings import RNDSettings, RewardSignalType
from mlagents.trainers.tests.torch.test_reward_providers.utils import (
create_agent_buffer,

@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,)], ACTIONSPEC_TWODISCRETE),
BehaviorSpec([(10,)], [SensorType.OBSERVATION], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,)], [SensorType.OBSERVATION], ACTIONSPEC_TWODISCRETE),
],
)
def test_construction(behavior_spec: BehaviorSpec) -> None:

@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,), (64, 66, 3), (84, 86, 1)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,), (64, 66, 1)], ACTIONSPEC_TWODISCRETE),
BehaviorSpec([(10,)], ACTIONSPEC_DISCRETE),
BehaviorSpec([(10,)], [SensorType.OBSERVATION], ACTIONSPEC_CONTINUOUS),
BehaviorSpec(
[(10,), (64, 66, 3), (84, 86, 1)],
[SensorType.OBSERVATION, SensorType.OBSERVATION, SensorType.OBSERVATION],
ACTIONSPEC_CONTINUOUS,
),
BehaviorSpec(
[(10,), (64, 66, 1)],
[SensorType.OBSERVATION, SensorType.OBSERVATION],
ACTIONSPEC_TWODISCRETE,
),
BehaviorSpec([(10,)], [SensorType.OBSERVATION], ACTIONSPEC_DISCRETE),
],
)
def test_factory(behavior_spec: BehaviorSpec) -> None:

@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,), (64, 66, 3), (24, 26, 1)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,)], ACTIONSPEC_TWODISCRETE),
BehaviorSpec([(10,)], ACTIONSPEC_DISCRETE),
BehaviorSpec(
[(10,), (64, 66, 3), (24, 26, 1)],
[SensorType.OBSERVATION, SensorType.OBSERVATION, SensorType.OBSERVATION],
ACTIONSPEC_CONTINUOUS,
),
BehaviorSpec([(10,)], [SensorType.OBSERVATION], ACTIONSPEC_TWODISCRETE),
BehaviorSpec([(10,)], [SensorType.OBSERVATION], ACTIONSPEC_DISCRETE),
],
)
def test_reward_decreases(behavior_spec: BehaviorSpec, seed: int) -> None:

7
protobuf-definitions/proto/mlagents_envs/communicator_objects/observation.proto


PNG = 1;
}
enum SensorTypeProto {
OBSERVATION = 0;
GOAL = 1;
REWARD = 2;
}
message ObservationProto {
message FloatData {
repeated float data = 1;

FloatData float_data = 4;
}
repeated int32 compressed_channel_mapping = 5;
SensorTypeProto sensor_type = 6;
}

48
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/GoalSensorComponent.cs


using Unity.MLAgents.Sensors;
public class GoalSensorComponent : SensorComponent
{
public int observationSize;
public GoalSensor goalSensor;
/// <summary>
/// Creates a GoalSensor.
/// </summary>
/// <returns></returns>
public override ISensor CreateSensor()
{
goalSensor = new GoalSensor(observationSize);
return goalSensor;
}
/// <inheritdoc/>
public override int[] GetObservationShape()
{
return new[] { observationSize };
}
public void AddGoal(float goal)
{
if (goalSensor != null)
{
goalSensor.AddObservation(goal);
}
}
}
public class GoalSensor : VectorSensor
{
public GoalSensor(int observationSize, string name = null) : base(observationSize)
{
if (name == null)
{
name = $"GoalSensor_size{observationSize}";
}
}
public override SensorType GetSensorType()
{
return SensorType.Goal;
}
}

11
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/GoalSensorComponent.cs.meta


fileFormatVersion: 2
guid: 163dac4bcbb2f4d8499db2cdcb22a89e
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:
正在加载...
取消
保存