浏览代码

Add SensorType field to SensorSpec

/MLA-1734-demo-provider
Arthur Juliani 4 年前
当前提交
0a22af55
共有 15 个文件被更改,包括 259 次插入64 次删除
  1. 3
      com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
  2. 52
      com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs
  3. 9
      com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs
  4. 7
      docs/Python-API.md
  5. 11
      ml-agents-envs/mlagents_envs/base_env.py
  6. 56
      ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.py
  7. 27
      ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.pyi
  8. 19
      ml-agents-envs/mlagents_envs/communicator_objects/unity_to_external_pb2.py
  9. 81
      ml-agents-envs/mlagents_envs/communicator_objects/unity_to_external_pb2_grpc.py
  10. 5
      ml-agents-envs/mlagents_envs/rpc_utils.py
  11. 4
      ml-agents/mlagents/trainers/tests/dummy_config.py
  12. 2
      ml-agents/mlagents/trainers/tests/simple_test_envs.py
  13. 8
      protobuf-definitions/proto/mlagents_envs/communicator_objects/observation.proto
  14. 28
      com.unity.ml-agents/Runtime/Sensors/ITypedSensor.cs
  15. 11
      com.unity.ml-agents/Runtime/Sensors/ITypedSensor.cs.meta

3
com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs


}
}
observationProto.Shape.AddRange(shape);
var typeSensor = sensor as ITypedSensor;
observationProto.SensorType = (SensorTypeProto)typeSensor.GetSensorType();
return observationProto;
}

52
com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs


byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjRtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0",
"aW9uLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cyK7AgoQT2JzZXJ2YXRp",
"aW9uLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cyL3AgoQT2JzZXJ2YXRp",
"KAUSHAoUZGltZW5zaW9uX3Byb3BlcnRpZXMYBiADKAUaGQoJRmxvYXREYXRh",
"EgwKBGRhdGEYASADKAJCEgoQb2JzZXJ2YXRpb25fZGF0YSopChRDb21wcmVz",
"c2lvblR5cGVQcm90bxIICgROT05FEAASBwoDUE5HEAFCJaoCIlVuaXR5Lk1M",
"QWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw=="));
"KAUSHAoUZGltZW5zaW9uX3Byb3BlcnRpZXMYBiADKAUSOgoLc2Vuc29yX3R5",
"cGUYByABKA4yJS5jb21tdW5pY2F0b3Jfb2JqZWN0cy5TZW5zb3JUeXBlUHJv",
"dG8aGQoJRmxvYXREYXRhEgwKBGRhdGEYASADKAJCEgoQb2JzZXJ2YXRpb25f",
"ZGF0YSopChRDb21wcmVzc2lvblR5cGVQcm90bxIICgROT05FEAASBwoDUE5H",
"EAEqRQoPU2Vuc29yVHlwZVByb3RvEg8KC09CU0VSVkFUSU9OEAASCAoER09B",
"TBABEgoKBlJFV0FSRBACEgsKB01FU1NBR0UQA0IlqgIiVW5pdHkuTUxBZ2Vu",
"dHMuQ29tbXVuaWNhdG9yT2JqZWN0c2IGcHJvdG8z"));
new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Unity.MLAgents.CommunicatorObjects.CompressionTypeProto), }, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Parser, new[]{ "Shape", "CompressionType", "CompressedData", "FloatData", "CompressedChannelMapping", "DimensionProperties" }, new[]{ "ObservationData" }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData.Parser, new[]{ "Data" }, null, null, null)})
new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Unity.MLAgents.CommunicatorObjects.CompressionTypeProto), typeof(global::Unity.MLAgents.CommunicatorObjects.SensorTypeProto), }, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Parser, new[]{ "Shape", "CompressionType", "CompressedData", "FloatData", "CompressedChannelMapping", "DimensionProperties", "SensorType" }, new[]{ "ObservationData" }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData.Parser, new[]{ "Data" }, null, null, null)})
}));
}
#endregion

internal enum CompressionTypeProto {
[pbr::OriginalName("NONE")] None = 0,
[pbr::OriginalName("PNG")] Png = 1,
}
internal enum SensorTypeProto {
[pbr::OriginalName("OBSERVATION")] Observation = 0,
[pbr::OriginalName("GOAL")] Goal = 1,
[pbr::OriginalName("REWARD")] Reward = 2,
[pbr::OriginalName("MESSAGE")] Message = 3,
}
#endregion

compressionType_ = other.compressionType_;
compressedChannelMapping_ = other.compressedChannelMapping_.Clone();
dimensionProperties_ = other.dimensionProperties_.Clone();
sensorType_ = other.sensorType_;
switch (other.ObservationDataCase) {
case ObservationDataOneofCase.CompressedData:
CompressedData = other.CompressedData;

get { return dimensionProperties_; }
}
/// <summary>Field number for the "sensor_type" field.</summary>
public const int SensorTypeFieldNumber = 7;
private global::Unity.MLAgents.CommunicatorObjects.SensorTypeProto sensorType_ = 0;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public global::Unity.MLAgents.CommunicatorObjects.SensorTypeProto SensorType {
get { return sensorType_; }
set {
sensorType_ = value;
}
}
private object observationData_;
/// <summary>Enum of possible cases for the "observation_data" oneof.</summary>
public enum ObservationDataOneofCase {

if (!object.Equals(FloatData, other.FloatData)) return false;
if(!compressedChannelMapping_.Equals(other.compressedChannelMapping_)) return false;
if(!dimensionProperties_.Equals(other.dimensionProperties_)) return false;
if (SensorType != other.SensorType) return false;
if (ObservationDataCase != other.ObservationDataCase) return false;
return Equals(_unknownFields, other._unknownFields);
}

if (observationDataCase_ == ObservationDataOneofCase.FloatData) hash ^= FloatData.GetHashCode();
hash ^= compressedChannelMapping_.GetHashCode();
hash ^= dimensionProperties_.GetHashCode();
if (SensorType != 0) hash ^= SensorType.GetHashCode();
hash ^= (int) observationDataCase_;
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();

}
compressedChannelMapping_.WriteTo(output, _repeated_compressedChannelMapping_codec);
dimensionProperties_.WriteTo(output, _repeated_dimensionProperties_codec);
if (SensorType != 0) {
output.WriteRawTag(56);
output.WriteEnum((int) SensorType);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}

}
size += compressedChannelMapping_.CalculateSize(_repeated_compressedChannelMapping_codec);
size += dimensionProperties_.CalculateSize(_repeated_dimensionProperties_codec);
if (SensorType != 0) {
size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) SensorType);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}

}
compressedChannelMapping_.Add(other.compressedChannelMapping_);
dimensionProperties_.Add(other.dimensionProperties_);
if (other.SensorType != 0) {
SensorType = other.SensorType;
}
switch (other.ObservationDataCase) {
case ObservationDataOneofCase.CompressedData:
CompressedData = other.CompressedData;

case 50:
case 48: {
dimensionProperties_.AddEntriesFrom(input, _repeated_dimensionProperties_codec);
break;
}
case 56: {
sensorType_ = (global::Unity.MLAgents.CommunicatorObjects.SensorTypeProto) input.ReadEnum();
break;
}
}

9
com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs


/// <summary>
/// A sensor implementation for vector observations.
/// </summary>
public class VectorSensor : ISensor
public class VectorSensor : ISensor, ITypedSensor
{
// TODO use float[] instead
// TODO allow setting float[]

SensorType m_sensorType;
/// <summary>
/// Initializes the sensor.

m_Observations = new List<float>(observationSize);
m_Name = name;
m_sensorType = SensorType.Observation;
m_Shape = new[] { observationSize };
}

public int[] GetObservationShape()
{
return m_Shape;
}
public SensorType GetSensorType()
{
return m_sensorType;
}
/// <inheritdoc/>

7
docs/Python-API.md


corresponds to an observation's properties: `shape` is a tuple of ints that
corresponds to the shape of the observation (without the number of agents dimension).
`dimension_property` is a tuple of flags containing extra information about how the
data should be processed in the corresponding dimension. Note that the `SensorSpec`
have the same ordering as the ordering of observations in the DecisionSteps,
DecisionStep, TerminalSteps and TerminalStep.
data should be processed in the corresponding dimension. `type` is an enum
corresponding to what type of sensor is generating the data (i.e., observation, goal,
etc). Note that the `SensorSpec` have the same ordering as the ordering of observations
in the DecisionSteps, DecisionStep, TerminalSteps and TerminalStep.
- `action_spec` is an `ActionSpec` namedtuple that defines the number and types
of actions for the Agent.

11
ml-agents-envs/mlagents_envs/base_env.py


Any,
Mapping as MappingType,
)
from enum import IntFlag
from enum import IntFlag, Enum
import numpy as np
from mlagents_envs.exception import UnityActionException

VARIABLE_SIZE = 4
class SensorType(Enum):
OBSERVATION = 0
GOAL = 1
REWARD = 2
MESSAGE = 3
class SensorSpec(NamedTuple):
"""
A NamedTuple containing information about the observation of Agents.

dimension.
- type is an enum of SensorType.
type: SensorType
class BehaviorSpec(NamedTuple):

56
ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.py


name='mlagents_envs/communicator_objects/observation.proto',
package='communicator_objects',
syntax='proto3',
serialized_pb=_b('\n4mlagents_envs/communicator_objects/observation.proto\x12\x14\x63ommunicator_objects\"\xbb\x02\n\x10ObservationProto\x12\r\n\x05shape\x18\x01 \x03(\x05\x12\x44\n\x10\x63ompression_type\x18\x02 \x01(\x0e\x32*.communicator_objects.CompressionTypeProto\x12\x19\n\x0f\x63ompressed_data\x18\x03 \x01(\x0cH\x00\x12\x46\n\nfloat_data\x18\x04 \x01(\x0b\x32\x30.communicator_objects.ObservationProto.FloatDataH\x00\x12\"\n\x1a\x63ompressed_channel_mapping\x18\x05 \x03(\x05\x12\x1c\n\x14\x64imension_properties\x18\x06 \x03(\x05\x1a\x19\n\tFloatData\x12\x0c\n\x04\x64\x61ta\x18\x01 \x03(\x02\x42\x12\n\x10observation_data*)\n\x14\x43ompressionTypeProto\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03PNG\x10\x01\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
serialized_pb=_b('\n4mlagents_envs/communicator_objects/observation.proto\x12\x14\x63ommunicator_objects\"\xf7\x02\n\x10ObservationProto\x12\r\n\x05shape\x18\x01 \x03(\x05\x12\x44\n\x10\x63ompression_type\x18\x02 \x01(\x0e\x32*.communicator_objects.CompressionTypeProto\x12\x19\n\x0f\x63ompressed_data\x18\x03 \x01(\x0cH\x00\x12\x46\n\nfloat_data\x18\x04 \x01(\x0b\x32\x30.communicator_objects.ObservationProto.FloatDataH\x00\x12\"\n\x1a\x63ompressed_channel_mapping\x18\x05 \x03(\x05\x12\x1c\n\x14\x64imension_properties\x18\x06 \x03(\x05\x12:\n\x0bsensor_type\x18\x07 \x01(\x0e\x32%.communicator_objects.SensorTypeProto\x1a\x19\n\tFloatData\x12\x0c\n\x04\x64\x61ta\x18\x01 \x03(\x02\x42\x12\n\x10observation_data*)\n\x14\x43ompressionTypeProto\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03PNG\x10\x01*E\n\x0fSensorTypeProto\x12\x0f\n\x0bOBSERVATION\x10\x00\x12\x08\n\x04GOAL\x10\x01\x12\n\n\x06REWARD\x10\x02\x12\x0b\n\x07MESSAGE\x10\x03\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
)
_COMPRESSIONTYPEPROTO = _descriptor.EnumDescriptor(

],
containing_type=None,
options=None,
serialized_start=396,
serialized_end=437,
serialized_start=456,
serialized_end=497,
_SENSORTYPEPROTO = _descriptor.EnumDescriptor(
name='SensorTypeProto',
full_name='communicator_objects.SensorTypeProto',
filename=None,
file=DESCRIPTOR,
values=[
_descriptor.EnumValueDescriptor(
name='OBSERVATION', index=0, number=0,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='GOAL', index=1, number=1,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='REWARD', index=2, number=2,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='MESSAGE', index=3, number=3,
options=None,
type=None),
],
containing_type=None,
options=None,
serialized_start=499,
serialized_end=568,
)
_sym_db.RegisterEnumDescriptor(_SENSORTYPEPROTO)
SensorTypeProto = enum_type_wrapper.EnumTypeWrapper(_SENSORTYPEPROTO)
OBSERVATION = 0
GOAL = 1
REWARD = 2
MESSAGE = 3

extension_ranges=[],
oneofs=[
],
serialized_start=349,
serialized_end=374,
serialized_start=409,
serialized_end=434,
)
_OBSERVATIONPROTO = _descriptor.Descriptor(

message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='sensor_type', full_name='communicator_objects.ObservationProto.sensor_type', index=6,
number=7, type=14, cpp_type=8, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
],
extensions=[
],

index=0, containing_type=None, fields=[]),
],
serialized_start=79,
serialized_end=394,
serialized_end=454,
_OBSERVATIONPROTO.fields_by_name['sensor_type'].enum_type = _SENSORTYPEPROTO
_OBSERVATIONPROTO.oneofs_by_name['observation_data'].fields.append(
_OBSERVATIONPROTO.fields_by_name['compressed_data'])
_OBSERVATIONPROTO.fields_by_name['compressed_data'].containing_oneof = _OBSERVATIONPROTO.oneofs_by_name['observation_data']

DESCRIPTOR.message_types_by_name['ObservationProto'] = _OBSERVATIONPROTO
DESCRIPTOR.enum_types_by_name['CompressionTypeProto'] = _COMPRESSIONTYPEPROTO
DESCRIPTOR.enum_types_by_name['SensorTypeProto'] = _SENSORTYPEPROTO
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
ObservationProto = _reflection.GeneratedProtocolMessageType('ObservationProto', (_message.Message,), dict(

27
ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.pyi


NONE = typing___cast('CompressionTypeProto', 0)
PNG = typing___cast('CompressionTypeProto', 1)
class SensorTypeProto(builtin___int):
DESCRIPTOR: google___protobuf___descriptor___EnumDescriptor = ...
@classmethod
def Name(cls, number: builtin___int) -> builtin___str: ...
@classmethod
def Value(cls, name: builtin___str) -> 'SensorTypeProto': ...
@classmethod
def keys(cls) -> typing___List[builtin___str]: ...
@classmethod
def values(cls) -> typing___List['SensorTypeProto']: ...
@classmethod
def items(cls) -> typing___List[typing___Tuple[builtin___str, 'SensorTypeProto']]: ...
OBSERVATION = typing___cast('SensorTypeProto', 0)
GOAL = typing___cast('SensorTypeProto', 1)
REWARD = typing___cast('SensorTypeProto', 2)
MESSAGE = typing___cast('SensorTypeProto', 3)
OBSERVATION = typing___cast('SensorTypeProto', 0)
GOAL = typing___cast('SensorTypeProto', 1)
REWARD = typing___cast('SensorTypeProto', 2)
MESSAGE = typing___cast('SensorTypeProto', 3)
class ObservationProto(google___protobuf___message___Message):
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
class FloatData(google___protobuf___message___Message):

compressed_data = ... # type: builtin___bytes
compressed_channel_mapping = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___int]
dimension_properties = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___int]
sensor_type = ... # type: SensorTypeProto
@property
def float_data(self) -> ObservationProto.FloatData: ...

float_data : typing___Optional[ObservationProto.FloatData] = None,
compressed_channel_mapping : typing___Optional[typing___Iterable[builtin___int]] = None,
dimension_properties : typing___Optional[typing___Iterable[builtin___int]] = None,
sensor_type : typing___Optional[SensorTypeProto] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> ObservationProto: ...

def HasField(self, field_name: typing_extensions___Literal[u"compressed_data",u"float_data",u"observation_data"]) -> builtin___bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"compressed_channel_mapping",u"compressed_data",u"compression_type",u"dimension_properties",u"float_data",u"observation_data",u"shape"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"compressed_channel_mapping",u"compressed_data",u"compression_type",u"dimension_properties",u"float_data",u"observation_data",u"sensor_type",u"shape"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"compressed_channel_mapping",b"compressed_channel_mapping",u"compressed_data",b"compressed_data",u"compression_type",b"compression_type",u"dimension_properties",b"dimension_properties",u"float_data",b"float_data",u"observation_data",b"observation_data",u"shape",b"shape"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"compressed_channel_mapping",b"compressed_channel_mapping",u"compressed_data",b"compressed_data",u"compression_type",b"compression_type",u"dimension_properties",b"dimension_properties",u"float_data",b"float_data",u"observation_data",b"observation_data",u"sensor_type",b"sensor_type",u"shape",b"shape"]) -> None: ...
def WhichOneof(self, oneof_group: typing_extensions___Literal[u"observation_data",b"observation_data"]) -> typing_extensions___Literal["compressed_data","float_data"]: ...

19
ml-agents-envs/mlagents_envs/communicator_objects/unity_to_external_pb2.py


# -*- coding: utf-8 -*-
import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
"""Generated protocol buffer code."""
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()

name='mlagents_envs/communicator_objects/unity_to_external.proto',
package='communicator_objects',
syntax='proto3',
serialized_pb=_b('\n:mlagents_envs/communicator_objects/unity_to_external.proto\x12\x14\x63ommunicator_objects\x1a\x36mlagents_envs/communicator_objects/unity_message.proto2v\n\x14UnityToExternalProto\x12^\n\x08\x45xchange\x12\'.communicator_objects.UnityMessageProto\x1a\'.communicator_objects.UnityMessageProto\"\x00\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
serialized_options=b'\252\002\"Unity.MLAgents.CommunicatorObjects',
create_key=_descriptor._internal_create_key,
serialized_pb=b'\n:mlagents_envs/communicator_objects/unity_to_external.proto\x12\x14\x63ommunicator_objects\x1a\x36mlagents_envs/communicator_objects/unity_message.proto2v\n\x14UnityToExternalProto\x12^\n\x08\x45xchange\x12\'.communicator_objects.UnityMessageProto\x1a\'.communicator_objects.UnityMessageProto\"\x00\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3'
,
dependencies=[mlagents__envs_dot_communicator__objects_dot_unity__message__pb2.DESCRIPTOR,])

DESCRIPTOR.has_options = True
DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\"Unity.MLAgents.CommunicatorObjects'))
DESCRIPTOR._options = None
_UNITYTOEXTERNALPROTO = _descriptor.ServiceDescriptor(
name='UnityToExternalProto',

options=None,
serialized_options=None,
create_key=_descriptor._internal_create_key,
serialized_start=140,
serialized_end=258,
methods=[

containing_service=None,
input_type=mlagents__envs_dot_communicator__objects_dot_unity__message__pb2._UNITYMESSAGEPROTO,
output_type=mlagents__envs_dot_communicator__objects_dot_unity__message__pb2._UNITYMESSAGEPROTO,
options=None,
serialized_options=None,
create_key=_descriptor._internal_create_key,
),
])
_sym_db.RegisterServiceDescriptor(_UNITYTOEXTERNALPROTO)

81
ml-agents-envs/mlagents_envs/communicator_objects/unity_to_external_pb2_grpc.py


# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
from mlagents_envs.communicator_objects import unity_message_pb2 as mlagents__envs_dot_communicator__objects_dot_unity__message__pb2

# missing associated documentation comment in .proto file
pass
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.Exchange = channel.unary_unary(
'/communicator_objects.UnityToExternalProto/Exchange',
request_serializer=mlagents__envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessageProto.SerializeToString,
response_deserializer=mlagents__envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessageProto.FromString,
)
Args:
channel: A grpc.Channel.
"""
self.Exchange = channel.unary_unary(
'/communicator_objects.UnityToExternalProto/Exchange',
request_serializer=mlagents__envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessageProto.SerializeToString,
response_deserializer=mlagents__envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessageProto.FromString,
)
# missing associated documentation comment in .proto file
pass
"""Missing associated documentation comment in .proto file."""
def Exchange(self, request, context):
"""Sends the academy parameters
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Exchange(self, request, context):
"""Sends the academy parameters
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
rpc_method_handlers = {
'Exchange': grpc.unary_unary_rpc_method_handler(
servicer.Exchange,
request_deserializer=mlagents__envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessageProto.FromString,
response_serializer=mlagents__envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessageProto.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'communicator_objects.UnityToExternalProto', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
rpc_method_handlers = {
'Exchange': grpc.unary_unary_rpc_method_handler(
servicer.Exchange,
request_deserializer=mlagents__envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessageProto.FromString,
response_serializer=mlagents__envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessageProto.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'communicator_objects.UnityToExternalProto', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API.
class UnityToExternalProto(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def Exchange(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/communicator_objects.UnityToExternalProto/Exchange',
mlagents__envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessageProto.SerializeToString,
mlagents__envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessageProto.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

5
ml-agents-envs/mlagents_envs/rpc_utils.py


tuple(DimensionProperty(dim) for dim in obs.dimension_properties)
for obs in agent_info.observations
]
sensor_types = [obs.sensor_type for obs in agent_info.observations]
SensorSpec(obs_shape, dim_p)
for obs_shape, dim_p in zip(observation_shape, dim_props)
SensorSpec(obs_shape, dim_p, sensor_type)
for obs_shape, dim_p, sensor_type in zip(observation_shape, dim_props, sensor_types)
]
# proto from communicator < v1.3 does not set action spec, use deprecated fields instead
if (

4
ml-agents/mlagents/trainers/tests/dummy_config.py


from typing import List, Tuple
from mlagents_envs.base_env import SensorSpec, DimensionProperty
from mlagents_envs.base_env import SensorSpec, DimensionProperty, SensorType
import pytest
import copy
import os

sen_spec: List[SensorSpec] = []
for shape in shapes:
dim_prop = (DimensionProperty.UNSPECIFIED,) * len(shape)
spec = SensorSpec(shape, dim_prop)
spec = SensorSpec(shape, dim_prop, SensorType.OBSERVATION)
sen_spec.append(spec)
return sen_spec

2
ml-agents/mlagents/trainers/tests/simple_test_envs.py


self.action[name] = None
self.step_result[name] = None
def _make_sensor_specs(self) -> SensorSpec:
def _make_sensor_specs(self) -> List[SensorSpec]:
obs_shape: List[Any] = []
for _ in range(self.num_vector):
obs_shape.append((self.vec_obs_size,))

8
protobuf-definitions/proto/mlagents_envs/communicator_objects/observation.proto


PNG = 1;
}
enum SensorTypeProto {
OBSERVATION = 0;
GOAL = 1;
REWARD = 2;
MESSAGE = 3;
}
message ObservationProto {
message FloatData {
repeated float data = 1;

}
repeated int32 compressed_channel_mapping = 5;
repeated int32 dimension_properties = 6;
SensorTypeProto sensor_type = 7;
}

28
com.unity.ml-agents/Runtime/Sensors/ITypedSensor.cs


namespace Unity.MLAgents.Sensors
{
/// <summary>
/// The SensorType flag of the observation
/// </summary>
[System.Flags]
public enum SensorType
{
Observation = 0,
Goal = 1,
Reward = 2,
Message = 3,
}
/// <summary>
/// Sensor interface for sensors with variable types.
/// </summary>
public interface ITypedSensor
{
/// <summary>
/// Returns the SensorType enum corresponding to the type of the sensor.
/// </summary>
/// <returns>The SensorType enum</returns>
SensorType GetSensorType();
}
}

11
com.unity.ml-agents/Runtime/Sensors/ITypedSensor.cs.meta


fileFormatVersion: 2
guid: 3751edac8122c411dbaef8f1b7043b82
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:
正在加载...
取消
保存