浏览代码

Add capabilities checks bewteen C# and Python codebases. (#3831)

/develop/dockerfile
GitHub 5 年前
当前提交
85789ded
共有 23 个文件被更改,包括 604 次插入28 次删除
  1. 1
      com.unity.ml-agents/CHANGELOG.md
  2. 9
      com.unity.ml-agents/Runtime/Academy.cs
  3. 17
      com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
  4. 10
      com.unity.ml-agents/Runtime/Communicator/ICommunicator.cs
  5. 3
      com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs
  6. 52
      com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlInitializationInput.cs
  7. 58
      com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlInitializationOutput.cs
  8. 18
      ml-agents-envs/mlagents_envs/communicator_objects/unity_rl_initialization_input_pb2.py
  9. 14
      ml-agents-envs/mlagents_envs/communicator_objects/unity_rl_initialization_input_pb2.pyi
  10. 17
      ml-agents-envs/mlagents_envs/communicator_objects/unity_rl_initialization_output_pb2.py
  11. 14
      ml-agents-envs/mlagents_envs/communicator_objects/unity_rl_initialization_output_pb2.pyi
  12. 28
      ml-agents-envs/mlagents_envs/environment.py
  13. 4
      protobuf-definitions/proto/mlagents_envs/communicator_objects/unity_rl_initialization_input.proto
  14. 4
      protobuf-definitions/proto/mlagents_envs/communicator_objects/unity_rl_initialization_output.proto
  15. 36
      com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs
  16. 3
      com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs.meta
  17. 182
      com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs
  18. 11
      com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs.meta
  19. 23
      com.unity.ml-agents/Tests/Editor/Communicator/UnityRLCapabilitiesTests.cs
  20. 3
      com.unity.ml-agents/Tests/Editor/Communicator/UnityRLCapabilitiesTests.cs.meta
  21. 71
      ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py
  22. 41
      ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi
  23. 13
      protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto

1
com.unity.ml-agents/CHANGELOG.md


- `num_updates` and `train_interval` for SAC were replaced with `steps_per_update`. (#3690)
- `WriteAdapter` was renamed to `ObservationWriter`. If you have a custom `ISensor` implementation,
you will need to change the signature of its `Write()` method. (#3834)
- `UnityRLCapabilities` was added to help inform users when RL features are mismatched between C# and Python packages. (#3831)
### Bug Fixes

9
com.unity.ml-agents/Runtime/Academy.cs


set { m_InferenceSeed = value; }
}
/// <summary>
/// Returns the RLCapabilities of the python client that the unity process is connected to.
/// </summary>
internal UnityRLCapabilities TrainerCapabilities { get; set; }
// The Academy uses a series of events to communicate with agents
// to facilitate synchronization. More specifically, it ensures
// that all the agents perform their steps in a consistent order (i.e. no

unityCommunicationVersion = k_ApiVersion,
unityPackageVersion = k_PackageVersion,
name = "AcademySingleton",
CSharpCapabilities = new UnityRLCapabilities()
TrainerCapabilities = unityRlInitParameters.TrainerCapabilities;
TrainerCapabilities.WarnOnPythonMissingBaseRLCapabilities();
}
catch
{

17
com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs


seed = inputProto.Seed,
pythonLibraryVersion = inputProto.PackageVersion,
pythonCommunicationVersion = inputProto.CommunicationVersion,
TrainerCapabilities = inputProto.Capabilities.ToRLCapabilities()
};
}

return observationProto;
}
#endregion
public static UnityRLCapabilities ToRLCapabilities(this UnityRLCapabilitiesProto proto)
{
return new UnityRLCapabilities
{
m_BaseRLCapabilities = proto.BaseRLCapabilities
};
}
public static UnityRLCapabilitiesProto ToProto(this UnityRLCapabilities rlCaps)
{
return new UnityRLCapabilitiesProto
{
BaseRLCapabilities = rlCaps.m_BaseRLCapabilities
};
}
}
}

10
com.unity.ml-agents/Runtime/Communicator/ICommunicator.cs


/// The version of the communication API.
/// </summary>
public string unityCommunicationVersion;
/// <summary>
/// The RL capabilities of the C# codebase.
/// </summary>
public UnityRLCapabilities CSharpCapabilities;
}
internal struct UnityRLInitParameters
{

/// The version of the communication API that python is using.
/// </summary>
public string pythonCommunicationVersion;
/// <summary>
/// The RL capabilities of the Trainer codebase.
/// </summary>
public UnityRLCapabilities TrainerCapabilities;
}
internal struct UnityRLInputParameters
{

3
com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs


{
Name = initParameters.name,
PackageVersion = initParameters.unityPackageVersion,
CommunicationVersion = initParameters.unityCommunicationVersion
CommunicationVersion = initParameters.unityCommunicationVersion,
Capabilities = initParameters.CSharpCapabilities.ToProto()
};
UnityInputProto input;

52
com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlInitializationInput.cs


string.Concat(
"CkZtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL3VuaXR5X3Js",
"X2luaXRpYWxpemF0aW9uX2lucHV0LnByb3RvEhRjb21tdW5pY2F0b3Jfb2Jq",
"ZWN0cyJnCh9Vbml0eVJMSW5pdGlhbGl6YXRpb25JbnB1dFByb3RvEgwKBHNl",
"ZWQYASABKAUSHQoVY29tbXVuaWNhdGlvbl92ZXJzaW9uGAIgASgJEhcKD3Bh",
"Y2thZ2VfdmVyc2lvbhgDIAEoCUIfqgIcTUxBZ2VudHMuQ29tbXVuaWNhdG9y",
"T2JqZWN0c2IGcHJvdG8z"));
"ZWN0cxo1bWxhZ2VudHNfZW52cy9jb21tdW5pY2F0b3Jfb2JqZWN0cy9jYXBh",
"YmlsaXRpZXMucHJvdG8irQEKH1VuaXR5UkxJbml0aWFsaXphdGlvbklucHV0",
"UHJvdG8SDAoEc2VlZBgBIAEoBRIdChVjb21tdW5pY2F0aW9uX3ZlcnNpb24Y",
"AiABKAkSFwoPcGFja2FnZV92ZXJzaW9uGAMgASgJEkQKDGNhcGFiaWxpdGll",
"cxgEIAEoCzIuLmNvbW11bmljYXRvcl9vYmplY3RzLlVuaXR5UkxDYXBhYmls",
"aXRpZXNQcm90b0IfqgIcTUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0c2IG",
"cHJvdG8z"));
new pbr::FileDescriptor[] { },
new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.CapabilitiesReflection.Descriptor, },
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLInitializationInputProto), global::MLAgents.CommunicatorObjects.UnityRLInitializationInputProto.Parser, new[]{ "Seed", "CommunicationVersion", "PackageVersion" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLInitializationInputProto), global::MLAgents.CommunicatorObjects.UnityRLInitializationInputProto.Parser, new[]{ "Seed", "CommunicationVersion", "PackageVersion", "Capabilities" }, null, null, null)
}));
}
#endregion

seed_ = other.seed_;
communicationVersion_ = other.communicationVersion_;
packageVersion_ = other.packageVersion_;
Capabilities = other.capabilities_ != null ? other.Capabilities.Clone() : null;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

}
}
/// <summary>Field number for the "capabilities" field.</summary>
public const int CapabilitiesFieldNumber = 4;
private global::MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto capabilities_;
/// <summary>
/// The RL Capabilities of the Python trainer.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public global::MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto Capabilities {
get { return capabilities_; }
set {
capabilities_ = value;
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as UnityRLInitializationInputProto);

if (Seed != other.Seed) return false;
if (CommunicationVersion != other.CommunicationVersion) return false;
if (PackageVersion != other.PackageVersion) return false;
if (!object.Equals(Capabilities, other.Capabilities)) return false;
return Equals(_unknownFields, other._unknownFields);
}

if (Seed != 0) hash ^= Seed.GetHashCode();
if (CommunicationVersion.Length != 0) hash ^= CommunicationVersion.GetHashCode();
if (PackageVersion.Length != 0) hash ^= PackageVersion.GetHashCode();
if (capabilities_ != null) hash ^= Capabilities.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}

output.WriteRawTag(26);
output.WriteString(PackageVersion);
}
if (capabilities_ != null) {
output.WriteRawTag(34);
output.WriteMessage(Capabilities);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}

}
if (PackageVersion.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(PackageVersion);
}
if (capabilities_ != null) {
size += 1 + pb::CodedOutputStream.ComputeMessageSize(Capabilities);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();

if (other.PackageVersion.Length != 0) {
PackageVersion = other.PackageVersion;
}
if (other.capabilities_ != null) {
if (capabilities_ == null) {
capabilities_ = new global::MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto();
}
Capabilities.MergeFrom(other.Capabilities);
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

}
case 26: {
PackageVersion = input.ReadString();
break;
}
case 34: {
if (capabilities_ == null) {
capabilities_ = new global::MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto();
}
input.ReadMessage(capabilities_);
break;
}
}

58
com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlInitializationOutput.cs


string.Concat(
"CkdtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL3VuaXR5X3Js",
"X2luaXRpYWxpemF0aW9uX291dHB1dC5wcm90bxIUY29tbXVuaWNhdG9yX29i",
"amVjdHMaOW1sYWdlbnRzX2VudnMvY29tbXVuaWNhdG9yX29iamVjdHMvYnJh",
"aW5fcGFyYW1ldGVycy5wcm90byLGAQogVW5pdHlSTEluaXRpYWxpemF0aW9u",
"T3V0cHV0UHJvdG8SDAoEbmFtZRgBIAEoCRIdChVjb21tdW5pY2F0aW9uX3Zl",
"cnNpb24YAiABKAkSEAoIbG9nX3BhdGgYAyABKAkSRAoQYnJhaW5fcGFyYW1l",
"dGVycxgFIAMoCzIqLmNvbW11bmljYXRvcl9vYmplY3RzLkJyYWluUGFyYW1l",
"dGVyc1Byb3RvEhcKD3BhY2thZ2VfdmVyc2lvbhgHIAEoCUoECAYQB0IfqgIc",
"TUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0c2IGcHJvdG8z"));
"amVjdHMaNW1sYWdlbnRzX2VudnMvY29tbXVuaWNhdG9yX29iamVjdHMvY2Fw",
"YWJpbGl0aWVzLnByb3RvGjltbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9v",
"YmplY3RzL2JyYWluX3BhcmFtZXRlcnMucHJvdG8ijAIKIFVuaXR5UkxJbml0",
"aWFsaXphdGlvbk91dHB1dFByb3RvEgwKBG5hbWUYASABKAkSHQoVY29tbXVu",
"aWNhdGlvbl92ZXJzaW9uGAIgASgJEhAKCGxvZ19wYXRoGAMgASgJEkQKEGJy",
"YWluX3BhcmFtZXRlcnMYBSADKAsyKi5jb21tdW5pY2F0b3Jfb2JqZWN0cy5C",
"cmFpblBhcmFtZXRlcnNQcm90bxIXCg9wYWNrYWdlX3ZlcnNpb24YByABKAkS",
"RAoMY2FwYWJpbGl0aWVzGAggASgLMi4uY29tbXVuaWNhdG9yX29iamVjdHMu",
"VW5pdHlSTENhcGFiaWxpdGllc1Byb3RvSgQIBhAHQh+qAhxNTEFnZW50cy5D",
"b21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.BrainParametersReflection.Descriptor, },
new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.CapabilitiesReflection.Descriptor, global::MLAgents.CommunicatorObjects.BrainParametersReflection.Descriptor, },
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLInitializationOutputProto), global::MLAgents.CommunicatorObjects.UnityRLInitializationOutputProto.Parser, new[]{ "Name", "CommunicationVersion", "LogPath", "BrainParameters", "PackageVersion" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLInitializationOutputProto), global::MLAgents.CommunicatorObjects.UnityRLInitializationOutputProto.Parser, new[]{ "Name", "CommunicationVersion", "LogPath", "BrainParameters", "PackageVersion", "Capabilities" }, null, null, null)
}));
}
#endregion

logPath_ = other.logPath_;
brainParameters_ = other.brainParameters_.Clone();
packageVersion_ = other.packageVersion_;
Capabilities = other.capabilities_ != null ? other.Capabilities.Clone() : null;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

}
}
/// <summary>Field number for the "capabilities" field.</summary>
public const int CapabilitiesFieldNumber = 8;
private global::MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto capabilities_;
/// <summary>
/// The RL Capabilities of the C# package.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public global::MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto Capabilities {
get { return capabilities_; }
set {
capabilities_ = value;
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as UnityRLInitializationOutputProto);

if (LogPath != other.LogPath) return false;
if(!brainParameters_.Equals(other.brainParameters_)) return false;
if (PackageVersion != other.PackageVersion) return false;
if (!object.Equals(Capabilities, other.Capabilities)) return false;
return Equals(_unknownFields, other._unknownFields);
}

if (LogPath.Length != 0) hash ^= LogPath.GetHashCode();
hash ^= brainParameters_.GetHashCode();
if (PackageVersion.Length != 0) hash ^= PackageVersion.GetHashCode();
if (capabilities_ != null) hash ^= Capabilities.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}

output.WriteRawTag(58);
output.WriteString(PackageVersion);
}
if (capabilities_ != null) {
output.WriteRawTag(66);
output.WriteMessage(Capabilities);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}

if (PackageVersion.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(PackageVersion);
}
if (capabilities_ != null) {
size += 1 + pb::CodedOutputStream.ComputeMessageSize(Capabilities);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}

brainParameters_.Add(other.brainParameters_);
if (other.PackageVersion.Length != 0) {
PackageVersion = other.PackageVersion;
}
if (other.capabilities_ != null) {
if (capabilities_ == null) {
capabilities_ = new global::MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto();
}
Capabilities.MergeFrom(other.Capabilities);
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

}
case 58: {
PackageVersion = input.ReadString();
break;
}
case 66: {
if (capabilities_ == null) {
capabilities_ = new global::MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto();
}
input.ReadMessage(capabilities_);
break;
}
}

18
ml-agents-envs/mlagents_envs/communicator_objects/unity_rl_initialization_input_pb2.py


_sym_db = _symbol_database.Default()
from mlagents_envs.communicator_objects import capabilities_pb2 as mlagents__envs_dot_communicator__objects_dot_capabilities__pb2
DESCRIPTOR = _descriptor.FileDescriptor(

serialized_pb=_b('\nFmlagents_envs/communicator_objects/unity_rl_initialization_input.proto\x12\x14\x63ommunicator_objects\"g\n\x1fUnityRLInitializationInputProto\x12\x0c\n\x04seed\x18\x01 \x01(\x05\x12\x1d\n\x15\x63ommunication_version\x18\x02 \x01(\t\x12\x17\n\x0fpackage_version\x18\x03 \x01(\tB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
)
serialized_pb=_b('\nFmlagents_envs/communicator_objects/unity_rl_initialization_input.proto\x12\x14\x63ommunicator_objects\x1a\x35mlagents_envs/communicator_objects/capabilities.proto\"\xad\x01\n\x1fUnityRLInitializationInputProto\x12\x0c\n\x04seed\x18\x01 \x01(\x05\x12\x1d\n\x15\x63ommunication_version\x18\x02 \x01(\t\x12\x17\n\x0fpackage_version\x18\x03 \x01(\t\x12\x44\n\x0c\x63\x61pabilities\x18\x04 \x01(\x0b\x32..communicator_objects.UnityRLCapabilitiesProtoB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
,
dependencies=[mlagents__envs_dot_communicator__objects_dot_capabilities__pb2.DESCRIPTOR,])

message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='capabilities', full_name='communicator_objects.UnityRLInitializationInputProto.capabilities', index=3,
number=4, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
],
extensions=[
],

extension_ranges=[],
oneofs=[
],
serialized_start=96,
serialized_end=199,
serialized_start=152,
serialized_end=325,
_UNITYRLINITIALIZATIONINPUTPROTO.fields_by_name['capabilities'].message_type = mlagents__envs_dot_communicator__objects_dot_capabilities__pb2._UNITYRLCAPABILITIESPROTO
DESCRIPTOR.message_types_by_name['UnityRLInitializationInputProto'] = _UNITYRLINITIALIZATIONINPUTPROTO
_sym_db.RegisterFileDescriptor(DESCRIPTOR)

14
ml-agents-envs/mlagents_envs/communicator_objects/unity_rl_initialization_input_pb2.pyi


Message as google___protobuf___message___Message,
)
from mlagents_envs.communicator_objects.capabilities_pb2 import (
UnityRLCapabilitiesProto as mlagents_envs___communicator_objects___capabilities_pb2___UnityRLCapabilitiesProto,
)
from typing import (
Optional as typing___Optional,
Text as typing___Text,

seed = ... # type: builtin___int
communication_version = ... # type: typing___Text
package_version = ... # type: typing___Text
@property
def capabilities(self) -> mlagents_envs___communicator_objects___capabilities_pb2___UnityRLCapabilitiesProto: ...
def __init__(self,
*,

capabilities : typing___Optional[mlagents_envs___communicator_objects___capabilities_pb2___UnityRLCapabilitiesProto] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> UnityRLInitializationInputProto: ...

def ClearField(self, field_name: typing_extensions___Literal[u"communication_version",u"package_version",u"seed"]) -> None: ...
def HasField(self, field_name: typing_extensions___Literal[u"capabilities"]) -> builtin___bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"capabilities",u"communication_version",u"package_version",u"seed"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"communication_version",b"communication_version",u"package_version",b"package_version",u"seed",b"seed"]) -> None: ...
def HasField(self, field_name: typing_extensions___Literal[u"capabilities",b"capabilities"]) -> builtin___bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"capabilities",b"capabilities",u"communication_version",b"communication_version",u"package_version",b"package_version",u"seed",b"seed"]) -> None: ...

17
ml-agents-envs/mlagents_envs/communicator_objects/unity_rl_initialization_output_pb2.py


_sym_db = _symbol_database.Default()
from mlagents_envs.communicator_objects import capabilities_pb2 as mlagents__envs_dot_communicator__objects_dot_capabilities__pb2
from mlagents_envs.communicator_objects import brain_parameters_pb2 as mlagents__envs_dot_communicator__objects_dot_brain__parameters__pb2

syntax='proto3',
serialized_pb=_b('\nGmlagents_envs/communicator_objects/unity_rl_initialization_output.proto\x12\x14\x63ommunicator_objects\x1a\x39mlagents_envs/communicator_objects/brain_parameters.proto\"\xc6\x01\n UnityRLInitializationOutputProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1d\n\x15\x63ommunication_version\x18\x02 \x01(\t\x12\x10\n\x08log_path\x18\x03 \x01(\t\x12\x44\n\x10\x62rain_parameters\x18\x05 \x03(\x0b\x32*.communicator_objects.BrainParametersProto\x12\x17\n\x0fpackage_version\x18\x07 \x01(\tJ\x04\x08\x06\x10\x07\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
serialized_pb=_b('\nGmlagents_envs/communicator_objects/unity_rl_initialization_output.proto\x12\x14\x63ommunicator_objects\x1a\x35mlagents_envs/communicator_objects/capabilities.proto\x1a\x39mlagents_envs/communicator_objects/brain_parameters.proto\"\x8c\x02\n UnityRLInitializationOutputProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1d\n\x15\x63ommunication_version\x18\x02 \x01(\t\x12\x10\n\x08log_path\x18\x03 \x01(\t\x12\x44\n\x10\x62rain_parameters\x18\x05 \x03(\x0b\x32*.communicator_objects.BrainParametersProto\x12\x17\n\x0fpackage_version\x18\x07 \x01(\t\x12\x44\n\x0c\x63\x61pabilities\x18\x08 \x01(\x0b\x32..communicator_objects.UnityRLCapabilitiesProtoJ\x04\x08\x06\x10\x07\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
dependencies=[mlagents__envs_dot_communicator__objects_dot_brain__parameters__pb2.DESCRIPTOR,])
dependencies=[mlagents__envs_dot_communicator__objects_dot_capabilities__pb2.DESCRIPTOR,mlagents__envs_dot_communicator__objects_dot_brain__parameters__pb2.DESCRIPTOR,])

message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='capabilities', full_name='communicator_objects.UnityRLInitializationOutputProto.capabilities', index=5,
number=8, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
],
extensions=[
],

extension_ranges=[],
oneofs=[
],
serialized_start=157,
serialized_end=355,
serialized_start=212,
serialized_end=480,
_UNITYRLINITIALIZATIONOUTPUTPROTO.fields_by_name['capabilities'].message_type = mlagents__envs_dot_communicator__objects_dot_capabilities__pb2._UNITYRLCAPABILITIESPROTO
DESCRIPTOR.message_types_by_name['UnityRLInitializationOutputProto'] = _UNITYRLINITIALIZATIONOUTPUTPROTO
_sym_db.RegisterFileDescriptor(DESCRIPTOR)

14
ml-agents-envs/mlagents_envs/communicator_objects/unity_rl_initialization_output_pb2.pyi


BrainParametersProto as mlagents_envs___communicator_objects___brain_parameters_pb2___BrainParametersProto,
)
from mlagents_envs.communicator_objects.capabilities_pb2 import (
UnityRLCapabilitiesProto as mlagents_envs___communicator_objects___capabilities_pb2___UnityRLCapabilitiesProto,
)
from typing import (
Iterable as typing___Iterable,
Optional as typing___Optional,

@property
def brain_parameters(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[mlagents_envs___communicator_objects___brain_parameters_pb2___BrainParametersProto]: ...
@property
def capabilities(self) -> mlagents_envs___communicator_objects___capabilities_pb2___UnityRLCapabilitiesProto: ...
def __init__(self,
*,
name : typing___Optional[typing___Text] = None,

package_version : typing___Optional[typing___Text] = None,
capabilities : typing___Optional[mlagents_envs___communicator_objects___capabilities_pb2___UnityRLCapabilitiesProto] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> UnityRLInitializationOutputProto: ...

def ClearField(self, field_name: typing_extensions___Literal[u"brain_parameters",u"communication_version",u"log_path",u"name",u"package_version"]) -> None: ...
def HasField(self, field_name: typing_extensions___Literal[u"capabilities"]) -> builtin___bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"brain_parameters",u"capabilities",u"communication_version",u"log_path",u"name",u"package_version"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"brain_parameters",b"brain_parameters",u"communication_version",b"communication_version",u"log_path",b"log_path",u"name",b"name",u"package_version",b"package_version"]) -> None: ...
def HasField(self, field_name: typing_extensions___Literal[u"capabilities",b"capabilities"]) -> builtin___bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"brain_parameters",b"brain_parameters",u"capabilities",b"capabilities",u"communication_version",b"communication_version",u"log_path",b"log_path",u"name",b"name",u"package_version",b"package_version"]) -> None: ...

28
ml-agents-envs/mlagents_envs/environment.py


from mlagents_envs.communicator_objects.unity_rl_output_pb2 import UnityRLOutputProto
from mlagents_envs.communicator_objects.agent_action_pb2 import AgentActionProto
from mlagents_envs.communicator_objects.unity_output_pb2 import UnityOutputProto
from mlagents_envs.communicator_objects.capabilities_pb2 import UnityRLCapabilitiesProto
from mlagents_envs.communicator_objects.unity_rl_initialization_input_pb2 import (
UnityRLInitializationInputProto,
)

)
return True
@staticmethod
def get_capabilities_proto() -> UnityRLCapabilitiesProto:
capabilities = UnityRLCapabilitiesProto()
capabilities.baseRLCapabilities = True
return capabilities
@staticmethod
def warn_csharp_base_capabitlities(
caps: UnityRLCapabilitiesProto, unity_package_ver: str, python_package_ver: str
) -> None:
if not caps.baseRLCapabilities:
logger.warning(
"WARNING: The Unity process is not running with the expected base Reinforcement Learning"
" capabilities. Please be sure upgrade the Unity Package to a version that is compatible with this "
"python package.\n"
f"Python package version: {python_package_ver}, C# package version: {unity_package_ver}"
f"Please find the versions that work best together from our release page.\n"
"https://github.com/Unity-Technologies/ml-agents/releases"
)
def __init__(
self,
file_name: Optional[str] = None,

seed=seed,
communication_version=self.API_VERSION,
package_version=mlagents_envs.__version__,
capabilities=UnityEnvironment.get_capabilities_proto(),
)
try:
aca_output = self.send_academy_parameters(rl_init_parameters_in)

):
self._close(0)
UnityEnvironment._raise_version_exception(aca_params.communication_version)
UnityEnvironment.warn_csharp_base_capabitlities(
aca_params.capabilities,
aca_params.package_version,
UnityEnvironment.API_VERSION,
)
self._env_state: Dict[str, Tuple[DecisionSteps, TerminalSteps]] = {}
self._env_specs: Dict[str, BehaviorSpec] = {}

4
protobuf-definitions/proto/mlagents_envs/communicator_objects/unity_rl_initialization_input.proto


syntax = "proto3";
import "mlagents_envs/communicator_objects/capabilities.proto";
option csharp_namespace = "MLAgents.CommunicatorObjects";
package communicator_objects;

// Package/library version that the initiating side (typically the Python trainer) is using.
string package_version = 3;
// The RL Capabilities of the Python trainer.
UnityRLCapabilitiesProto capabilities = 4;
}

4
protobuf-definitions/proto/mlagents_envs/communicator_objects/unity_rl_initialization_output.proto


syntax = "proto3";
import "mlagents_envs/communicator_objects/capabilities.proto";
import "mlagents_envs/communicator_objects/brain_parameters.proto";
option csharp_namespace = "MLAgents.CommunicatorObjects";

// Package/library version that the responding side (typically the C# environment) is using.
string package_version = 7;
// The RL Capabilities of the C# package.
UnityRLCapabilitiesProto capabilities = 8;
}

36
com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs


using UnityEngine;
namespace MLAgents
{
internal class UnityRLCapabilities
{
internal bool m_BaseRLCapabilities;
/// <summary>
/// A class holding the capabilities flags for Reinforcement Learning across C# and the Trainer codebase. This
/// struct will be used to inform users if and when they are using C# / Trainer features that are mismatched.
/// </summary>
public UnityRLCapabilities(bool baseRlCapabilities=true)
{
m_BaseRLCapabilities = baseRlCapabilities;
}
/// <summary>
/// Will print a warning to the console if Python does not support base capabilities and will
/// return <value>true</value> if the warning was printed.
/// </summary>
/// <returns></returns>
public bool WarnOnPythonMissingBaseRLCapabilities()
{
if (m_BaseRLCapabilities)
{
return false;
}
Debug.LogWarning("Unity has connected to a Training process that does not support" +
"Base Reinforcement Learning Capabilities. Please make sure you have the" +
" latest training codebase installed for this version of the ML-Agents package.");
return true;
}
}
}

3
com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs.meta


fileFormatVersion: 2
guid: f95d271af72d4b75aa94d308222f79d8
timeCreated: 1587670989

182
com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs


// <auto-generated>
// Generated by the protocol buffer compiler. DO NOT EDIT!
// source: mlagents_envs/communicator_objects/capabilities.proto
// </auto-generated>
#pragma warning disable 1591, 0612, 3021
#region Designer generated code
using pb = global::Google.Protobuf;
using pbc = global::Google.Protobuf.Collections;
using pbr = global::Google.Protobuf.Reflection;
using scg = global::System.Collections.Generic;
namespace MLAgents.CommunicatorObjects {
/// <summary>Holder for reflection information generated from mlagents_envs/communicator_objects/capabilities.proto</summary>
internal static partial class CapabilitiesReflection {
#region Descriptor
/// <summary>File descriptor for mlagents_envs/communicator_objects/capabilities.proto</summary>
public static pbr::FileDescriptor Descriptor {
get { return descriptor; }
}
private static pbr::FileDescriptor descriptor;
static CapabilitiesReflection() {
byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjVtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2NhcGFiaWxp",
"dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMiNgoYVW5pdHlSTENh",
"cGFiaWxpdGllc1Byb3RvEhoKEmJhc2VSTENhcGFiaWxpdGllcxgBIAEoCEIf",
"qgIcTUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0c2IGcHJvdG8z"));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities" }, null, null, null)
}));
}
#endregion
}
#region Messages
/// <summary>
///
/// A Capabilities message that will communicate both C# and Python
/// what features are available to both.
/// </summary>
internal sealed partial class UnityRLCapabilitiesProto : pb::IMessage<UnityRLCapabilitiesProto> {
private static readonly pb::MessageParser<UnityRLCapabilitiesProto> _parser = new pb::MessageParser<UnityRLCapabilitiesProto>(() => new UnityRLCapabilitiesProto());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pb::MessageParser<UnityRLCapabilitiesProto> Parser { get { return _parser; } }
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pbr::MessageDescriptor Descriptor {
get { return global::MLAgents.CommunicatorObjects.CapabilitiesReflection.Descriptor.MessageTypes[0]; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public UnityRLCapabilitiesProto() {
OnConstruction();
}
partial void OnConstruction();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public UnityRLCapabilitiesProto(UnityRLCapabilitiesProto other) : this() {
baseRLCapabilities_ = other.baseRLCapabilities_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public UnityRLCapabilitiesProto Clone() {
return new UnityRLCapabilitiesProto(this);
}
/// <summary>Field number for the "baseRLCapabilities" field.</summary>
public const int BaseRLCapabilitiesFieldNumber = 1;
private bool baseRLCapabilities_;
/// <summary>
/// These are the 1.0 capabilities.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool BaseRLCapabilities {
get { return baseRLCapabilities_; }
set {
baseRLCapabilities_ = value;
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as UnityRLCapabilitiesProto);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool Equals(UnityRLCapabilitiesProto other) {
if (ReferenceEquals(other, null)) {
return false;
}
if (ReferenceEquals(other, this)) {
return true;
}
if (BaseRLCapabilities != other.BaseRLCapabilities) return false;
return Equals(_unknownFields, other._unknownFields);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override int GetHashCode() {
int hash = 1;
if (BaseRLCapabilities != false) hash ^= BaseRLCapabilities.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
return hash;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
if (BaseRLCapabilities != false) {
output.WriteRawTag(8);
output.WriteBool(BaseRLCapabilities);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int CalculateSize() {
int size = 0;
if (BaseRLCapabilities != false) {
size += 1 + 1;
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
return size;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(UnityRLCapabilitiesProto other) {
if (other == null) {
return;
}
if (other.BaseRLCapabilities != false) {
BaseRLCapabilities = other.BaseRLCapabilities;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(pb::CodedInputStream input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
break;
case 8: {
BaseRLCapabilities = input.ReadBool();
break;
}
}
}
}
}
#endregion
}
#endregion Designer generated code

11
com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs.meta


fileFormatVersion: 2
guid: e8388443b440343299cab2e88988e14e
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

23
com.unity.ml-agents/Tests/Editor/Communicator/UnityRLCapabilitiesTests.cs


using System.Text.RegularExpressions;
using NUnit.Framework;
using UnityEngine;
using UnityEngine.TestTools;
namespace MLAgents.Tests.Communicator
{
[TestFixture]
public class UnityRLCapabilitiesTests
{
[Test]
public void TestWarnOnPythonMissingBaseRLCapabilities()
{
var caps = new UnityRLCapabilities();
Assert.False(caps.WarnOnPythonMissingBaseRLCapabilities());
LogAssert.NoUnexpectedReceived();
caps = new UnityRLCapabilities(false);
Assert.True(caps.WarnOnPythonMissingBaseRLCapabilities());
LogAssert.Expect(LogType.Warning, new Regex(".+"));
}
}
}

3
com.unity.ml-agents/Tests/Editor/Communicator/UnityRLCapabilitiesTests.cs.meta


fileFormatVersion: 2
guid: e6a3e82911b84029a446dcfd2d8af520
timeCreated: 1587695055

71
ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py


# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: mlagents_envs/communicator_objects/capabilities.proto
import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='mlagents_envs/communicator_objects/capabilities.proto',
package='communicator_objects',
syntax='proto3',
serialized_pb=_b('\n5mlagents_envs/communicator_objects/capabilities.proto\x12\x14\x63ommunicator_objects\"6\n\x18UnityRLCapabilitiesProto\x12\x1a\n\x12\x62\x61seRLCapabilities\x18\x01 \x01(\x08\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
)
_UNITYRLCAPABILITIESPROTO = _descriptor.Descriptor(
name='UnityRLCapabilitiesProto',
full_name='communicator_objects.UnityRLCapabilitiesProto',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='baseRLCapabilities', full_name='communicator_objects.UnityRLCapabilitiesProto.baseRLCapabilities', index=0,
number=1, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=79,
serialized_end=133,
)
DESCRIPTOR.message_types_by_name['UnityRLCapabilitiesProto'] = _UNITYRLCAPABILITIESPROTO
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
UnityRLCapabilitiesProto = _reflection.GeneratedProtocolMessageType('UnityRLCapabilitiesProto', (_message.Message,), dict(
DESCRIPTOR = _UNITYRLCAPABILITIESPROTO,
__module__ = 'mlagents_envs.communicator_objects.capabilities_pb2'
# @@protoc_insertion_point(class_scope:communicator_objects.UnityRLCapabilitiesProto)
))
_sym_db.RegisterMessage(UnityRLCapabilitiesProto)
DESCRIPTOR.has_options = True
DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects'))
# @@protoc_insertion_point(module_scope)

41
ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi


# @generated by generate_proto_mypy_stubs.py. Do not edit!
import sys
from google.protobuf.descriptor import (
Descriptor as google___protobuf___descriptor___Descriptor,
)
from google.protobuf.message import (
Message as google___protobuf___message___Message,
)
from typing import (
Optional as typing___Optional,
)
from typing_extensions import (
Literal as typing_extensions___Literal,
)
builtin___bool = bool
builtin___bytes = bytes
builtin___float = float
builtin___int = int
class UnityRLCapabilitiesProto(google___protobuf___message___Message):
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
baseRLCapabilities = ... # type: builtin___bool
def __init__(self,
*,
baseRLCapabilities : typing___Optional[builtin___bool] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> UnityRLCapabilitiesProto: ...
def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
if sys.version_info >= (3,):
def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities"]) -> None: ...
else:
def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",b"baseRLCapabilities"]) -> None: ...

13
protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto


syntax = "proto3";
option csharp_namespace="MLAgents.CommunicatorObjects";
package communicator_objects;
/*
* A Capabilities message that will communicate both C# and Python
* what features are available to both.
*/
message UnityRLCapabilitiesProto {
// These are the 1.0 capabilities.
bool baseRLCapabilities = 1;
}
正在加载...
取消
保存