浏览代码

pass sensor name through to ObservationSpec (#5036)

/develop/gail-srl-hack
GitHub 3 年前
当前提交
46461986
共有 12 个文件被更改,包括 114 次插入32 次删除
  1. 1
      com.unity.ml-agents/CHANGELOG.md
  2. 5
      com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
  3. 48
      com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs
  4. 4
      ml-agents-envs/mlagents_envs/base_env.py
  5. 23
      ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.py
  6. 7
      ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.pyi
  7. 9
      ml-agents-envs/mlagents_envs/rpc_utils.py
  8. 9
      ml-agents/mlagents/trainers/tests/dummy_config.py
  9. 16
      ml-agents/mlagents/trainers/tests/simple_test_envs.py
  10. 16
      ml-agents/mlagents/trainers/tests/torch/test_hybrid.py
  11. 4
      ml-agents/tests/yamato/scripts/run_llapi.py
  12. 4
      protobuf-definitions/proto/mlagents_envs/communicator_objects/observation.proto

1
com.unity.ml-agents/CHANGELOG.md


#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- The `encoding_size` setting for RewardSignals has been deprecated. Please use `network_settings` instead. (#4982)
- Sensor names are now passed through to `ObservationSpec.name`. (#5036)
### Bug Fixes
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)

5
com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs


}
}
observationProto.Shape.AddRange(shape);
var sensorName = sensor.GetName();
if (!string.IsNullOrEmpty(sensorName))
{
observationProto.Name = sensorName;
}
// Add the observation type, if any, to the observationProto
var typeSensor = sensor as ITypedSensor;

48
com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs


byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjRtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0",
"aW9uLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cyKBAwoQT2JzZXJ2YXRp",
"aW9uLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cyKPAwoQT2JzZXJ2YXRp",
"b25Qcm90bxINCgVzaGFwZRgBIAMoBRJEChBjb21wcmVzc2lvbl90eXBlGAIg",
"ASgOMiouY29tbXVuaWNhdG9yX29iamVjdHMuQ29tcHJlc3Npb25UeXBlUHJv",
"dG8SGQoPY29tcHJlc3NlZF9kYXRhGAMgASgMSAASRgoKZmxvYXRfZGF0YRgE",

"b25fdHlwZRgHIAEoDjIqLmNvbW11bmljYXRvcl9vYmplY3RzLk9ic2VydmF0",
"aW9uVHlwZVByb3RvGhkKCUZsb2F0RGF0YRIMCgRkYXRhGAEgAygCQhIKEG9i",
"c2VydmF0aW9uX2RhdGEqKQoUQ29tcHJlc3Npb25UeXBlUHJvdG8SCAoETk9O",
"RRAAEgcKA1BORxABKkYKFE9ic2VydmF0aW9uVHlwZVByb3RvEgsKB0RFRkFV",
"TFQQABIICgRHT0FMEAESCgoGUkVXQVJEEAISCwoHTUVTU0FHRRADQiWqAiJV",
"bml0eS5NTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
"aW9uVHlwZVByb3RvEgwKBG5hbWUYCCABKAkaGQoJRmxvYXREYXRhEgwKBGRh",
"dGEYASADKAJCEgoQb2JzZXJ2YXRpb25fZGF0YSopChRDb21wcmVzc2lvblR5",
"cGVQcm90bxIICgROT05FEAASBwoDUE5HEAEqRgoUT2JzZXJ2YXRpb25UeXBl",
"UHJvdG8SCwoHREVGQVVMVBAAEggKBEdPQUwQARIKCgZSRVdBUkQQAhILCgdN",
"RVNTQUdFEANCJaoCIlVuaXR5Lk1MQWdlbnRzLkNvbW11bmljYXRvck9iamVj",
"dHNiBnByb3RvMw=="));
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Parser, new[]{ "Shape", "CompressionType", "CompressedData", "FloatData", "CompressedChannelMapping", "DimensionProperties", "ObservationType" }, new[]{ "ObservationData" }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData.Parser, new[]{ "Data" }, null, null, null)})
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Parser, new[]{ "Shape", "CompressionType", "CompressedData", "FloatData", "CompressedChannelMapping", "DimensionProperties", "ObservationType", "Name" }, new[]{ "ObservationData" }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData.Parser, new[]{ "Data" }, null, null, null)})
}));
}
#endregion

compressedChannelMapping_ = other.compressedChannelMapping_.Clone();
dimensionProperties_ = other.dimensionProperties_.Clone();
observationType_ = other.observationType_;
name_ = other.name_;
switch (other.ObservationDataCase) {
case ObservationDataOneofCase.CompressedData:
CompressedData = other.CompressedData;

}
}
/// <summary>Field number for the "name" field.</summary>
public const int NameFieldNumber = 8;
private string name_ = "";
/// <summary>
/// Optional name of the observation.
/// This will be set to the ISensor name when writing,
/// and read into the ObservationSpec in the low-level API
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public string Name {
get { return name_; }
set {
name_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}
private object observationData_;
/// <summary>Enum of possible cases for the "observation_data" oneof.</summary>
public enum ObservationDataOneofCase {

if(!compressedChannelMapping_.Equals(other.compressedChannelMapping_)) return false;
if(!dimensionProperties_.Equals(other.dimensionProperties_)) return false;
if (ObservationType != other.ObservationType) return false;
if (Name != other.Name) return false;
if (ObservationDataCase != other.ObservationDataCase) return false;
return Equals(_unknownFields, other._unknownFields);
}

hash ^= compressedChannelMapping_.GetHashCode();
hash ^= dimensionProperties_.GetHashCode();
if (ObservationType != 0) hash ^= ObservationType.GetHashCode();
if (Name.Length != 0) hash ^= Name.GetHashCode();
hash ^= (int) observationDataCase_;
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();

output.WriteRawTag(56);
output.WriteEnum((int) ObservationType);
}
if (Name.Length != 0) {
output.WriteRawTag(66);
output.WriteString(Name);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}

if (ObservationType != 0) {
size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) ObservationType);
}
if (Name.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(Name);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}

dimensionProperties_.Add(other.dimensionProperties_);
if (other.ObservationType != 0) {
ObservationType = other.ObservationType;
}
if (other.Name.Length != 0) {
Name = other.Name;
}
switch (other.ObservationDataCase) {
case ObservationDataOneofCase.CompressedData:

}
case 56: {
observationType_ = (global::Unity.MLAgents.CommunicatorObjects.ObservationTypeProto) input.ReadEnum();
break;
}
case 66: {
Name = input.ReadString();
break;
}
}

4
ml-agents-envs/mlagents_envs/base_env.py


dimension_property: Tuple[DimensionProperty, ...]
observation_type: ObservationType
# Optional name. For observations coming from com.unity.ml-agents, this
# will be the ISensor name.
name: str
class BehaviorSpec(NamedTuple):
"""

23
ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.py


name='mlagents_envs/communicator_objects/observation.proto',
package='communicator_objects',
syntax='proto3',
serialized_pb=_b('\n4mlagents_envs/communicator_objects/observation.proto\x12\x14\x63ommunicator_objects\"\x81\x03\n\x10ObservationProto\x12\r\n\x05shape\x18\x01 \x03(\x05\x12\x44\n\x10\x63ompression_type\x18\x02 \x01(\x0e\x32*.communicator_objects.CompressionTypeProto\x12\x19\n\x0f\x63ompressed_data\x18\x03 \x01(\x0cH\x00\x12\x46\n\nfloat_data\x18\x04 \x01(\x0b\x32\x30.communicator_objects.ObservationProto.FloatDataH\x00\x12\"\n\x1a\x63ompressed_channel_mapping\x18\x05 \x03(\x05\x12\x1c\n\x14\x64imension_properties\x18\x06 \x03(\x05\x12\x44\n\x10observation_type\x18\x07 \x01(\x0e\x32*.communicator_objects.ObservationTypeProto\x1a\x19\n\tFloatData\x12\x0c\n\x04\x64\x61ta\x18\x01 \x03(\x02\x42\x12\n\x10observation_data*)\n\x14\x43ompressionTypeProto\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03PNG\x10\x01*F\n\x14ObservationTypeProto\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\x08\n\x04GOAL\x10\x01\x12\n\n\x06REWARD\x10\x02\x12\x0b\n\x07MESSAGE\x10\x03\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
serialized_pb=_b('\n4mlagents_envs/communicator_objects/observation.proto\x12\x14\x63ommunicator_objects\"\x8f\x03\n\x10ObservationProto\x12\r\n\x05shape\x18\x01 \x03(\x05\x12\x44\n\x10\x63ompression_type\x18\x02 \x01(\x0e\x32*.communicator_objects.CompressionTypeProto\x12\x19\n\x0f\x63ompressed_data\x18\x03 \x01(\x0cH\x00\x12\x46\n\nfloat_data\x18\x04 \x01(\x0b\x32\x30.communicator_objects.ObservationProto.FloatDataH\x00\x12\"\n\x1a\x63ompressed_channel_mapping\x18\x05 \x03(\x05\x12\x1c\n\x14\x64imension_properties\x18\x06 \x03(\x05\x12\x44\n\x10observation_type\x18\x07 \x01(\x0e\x32*.communicator_objects.ObservationTypeProto\x12\x0c\n\x04name\x18\x08 \x01(\t\x1a\x19\n\tFloatData\x12\x0c\n\x04\x64\x61ta\x18\x01 \x03(\x02\x42\x12\n\x10observation_data*)\n\x14\x43ompressionTypeProto\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03PNG\x10\x01*F\n\x14ObservationTypeProto\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\x08\n\x04GOAL\x10\x01\x12\n\n\x06REWARD\x10\x02\x12\x0b\n\x07MESSAGE\x10\x03\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
)
_COMPRESSIONTYPEPROTO = _descriptor.EnumDescriptor(

],
containing_type=None,
options=None,
serialized_start=466,
serialized_end=507,
serialized_start=480,
serialized_end=521,
)
_sym_db.RegisterEnumDescriptor(_COMPRESSIONTYPEPROTO)

],
containing_type=None,
options=None,
serialized_start=509,
serialized_end=579,
serialized_start=523,
serialized_end=593,
)
_sym_db.RegisterEnumDescriptor(_OBSERVATIONTYPEPROTO)

extension_ranges=[],
oneofs=[
],
serialized_start=419,
serialized_end=444,
serialized_start=433,
serialized_end=458,
)
_OBSERVATIONPROTO = _descriptor.Descriptor(

message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='name', full_name='communicator_objects.ObservationProto.name', index=7,
number=8, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
],
extensions=[
],

index=0, containing_type=None, fields=[]),
],
serialized_start=79,
serialized_end=464,
serialized_end=478,
)
_OBSERVATIONPROTO_FLOATDATA.containing_type = _OBSERVATIONPROTO

7
ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.pyi


Iterable as typing___Iterable,
List as typing___List,
Optional as typing___Optional,
Text as typing___Text,
Tuple as typing___Tuple,
cast as typing___cast,
)

compressed_channel_mapping = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___int]
dimension_properties = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___int]
observation_type = ... # type: ObservationTypeProto
name = ... # type: typing___Text
@property
def float_data(self) -> ObservationProto.FloatData: ...

compressed_channel_mapping : typing___Optional[typing___Iterable[builtin___int]] = None,
dimension_properties : typing___Optional[typing___Iterable[builtin___int]] = None,
observation_type : typing___Optional[ObservationTypeProto] = None,
name : typing___Optional[typing___Text] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> ObservationProto: ...

def HasField(self, field_name: typing_extensions___Literal[u"compressed_data",u"float_data",u"observation_data"]) -> builtin___bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"compressed_channel_mapping",u"compressed_data",u"compression_type",u"dimension_properties",u"float_data",u"observation_data",u"observation_type",u"shape"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"compressed_channel_mapping",u"compressed_data",u"compression_type",u"dimension_properties",u"float_data",u"name",u"observation_data",u"observation_type",u"shape"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"compressed_channel_mapping",b"compressed_channel_mapping",u"compressed_data",b"compressed_data",u"compression_type",b"compression_type",u"dimension_properties",b"dimension_properties",u"float_data",b"float_data",u"observation_data",b"observation_data",u"observation_type",b"observation_type",u"shape",b"shape"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"compressed_channel_mapping",b"compressed_channel_mapping",u"compressed_data",b"compressed_data",u"compression_type",b"compression_type",u"dimension_properties",b"dimension_properties",u"float_data",b"float_data",u"name",b"name",u"observation_data",b"observation_data",u"observation_type",b"observation_type",u"shape",b"shape"]) -> None: ...
def WhichOneof(self, oneof_group: typing_extensions___Literal[u"observation_data",b"observation_data"]) -> typing_extensions___Literal["compressed_data","float_data"]: ...

9
ml-agents-envs/mlagents_envs/rpc_utils.py


for obs in agent_info.observations:
observation_specs.append(
ObservationSpec(
tuple(obs.shape),
tuple(DimensionProperty(dim) for dim in obs.dimension_properties)
name=obs.name,
shape=tuple(obs.shape),
observation_type=ObservationType(obs.observation_type),
dimension_property=tuple(
DimensionProperty(dim) for dim in obs.dimension_properties
)
ObservationType(obs.observation_type),
)
)

9
ml-agents/mlagents/trainers/tests/dummy_config.py


shapes: List[Tuple[int, ...]]
) -> List[ObservationSpec]:
obs_specs: List[ObservationSpec] = []
for shape in shapes:
for i, shape in enumerate(shapes):
spec = ObservationSpec(shape, dim_prop, ObservationType.DEFAULT)
spec = ObservationSpec(
name=f"observation {i} with shape {shape}",
shape=shape,
dimension_property=dim_prop,
observation_type=ObservationType.DEFAULT,
)
obs_specs.append(spec)
return obs_specs

16
ml-agents/mlagents/trainers/tests/simple_test_envs.py


self.names = brain_names
self.positions: Dict[str, List[float]] = {}
self.step_count: Dict[str, float] = {}
self.random = random.Random(str(self.behavior_spec))
# Concatenate the arguments for a consistent random seed
seed = (
brain_names,
step_size,
num_visual,
num_vector,
num_var_len,
vis_obs_size,
vec_obs_size,
var_len_obs_size,
action_sizes,
)
self.random = random.Random(str(seed))
self.goal: Dict[str, int] = {}
self.action = {}
self.rewards: Dict[str, float] = {}

16
ml-agents/mlagents/trainers/tests/torch/test_hybrid.py


@pytest.mark.check_environment_trains
@pytest.mark.parametrize("num_visual", [1, 2])
def test_hybrid_visual_ppo(num_visual):
@pytest.mark.parametrize("num_visual,training_seed", [(1, 1336), (2, 1338)])
def test_hybrid_visual_ppo(num_visual, training_seed):
env = SimpleEnvironment(
[BRAIN_NAME], num_visual=num_visual, num_vector=0, action_sizes=(1, 1)
)

config = attr.evolve(PPO_TORCH_CONFIG, hyperparameters=new_hyperparams)
check_environment_trains(env, {BRAIN_NAME: config}, training_seed=1336)
check_environment_trains(env, {BRAIN_NAME: config}, training_seed=training_seed)
@pytest.mark.check_environment_trains

config = attr.evolve(
SAC_TORCH_CONFIG, hyperparameters=new_hyperparams, max_steps=2200
)
check_environment_trains(
env, {BRAIN_NAME: config}, success_threshold=0.9, training_seed=1336
)
check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=0.9)
@pytest.mark.parametrize("num_visual", [1, 2])
def test_hybrid_visual_sac(num_visual):
@pytest.mark.parametrize("num_visual,training_seed", [(1, 1337), (2, 1338)])
def test_hybrid_visual_sac(num_visual, training_seed):
env = SimpleEnvironment(
[BRAIN_NAME], num_visual=num_visual, num_vector=0, action_sizes=(1, 1)
)

config = attr.evolve(
SAC_TORCH_CONFIG, hyperparameters=new_hyperparams, max_steps=3000
)
check_environment_trains(env, {BRAIN_NAME: config})
check_environment_trains(env, {BRAIN_NAME: config}, training_seed=training_seed)
@pytest.mark.check_environment_trains

4
ml-agents/tests/yamato/scripts/run_llapi.py


# Examine the number of observations per Agent
print("Number of observations : ", len(group_spec.observation_specs))
for obs_spec in group_spec.observation_specs:
# Make sure the name was set in the ObservationSpec
assert bool(obs_spec.name) is True, f'obs_spec.name="{obs_spec.name}"'
# Is there a visual observation ?
vis_obs = any(
len(obs_spec.shape) == 3 for obs_spec in group_spec.observation_specs

4
protobuf-definitions/proto/mlagents_envs/communicator_objects/observation.proto


repeated int32 compressed_channel_mapping = 5;
repeated int32 dimension_properties = 6;
ObservationTypeProto observation_type = 7;
// Optional name of the observation.
// This will be set to the ISensor name when writing,
// and read into the ObservationSpec in the low-level API
string name = 8;
}
正在加载...
取消
保存