浏览代码

Add hybrid action capability flag (#4576)

/MLA-1734-demo-provider
GitHub 4 年前
当前提交
4cfae52e
共有 8 个文件被更改,包括 69 次插入14 次删除
  1. 6
      com.unity.ml-agents/Runtime/Academy.cs
  2. 2
      com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
  3. 5
      com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs
  4. 44
      com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs
  5. 13
      ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py
  6. 6
      ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi
  7. 4
      ml-agents-envs/mlagents_envs/environment.py
  8. 3
      protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto

6
com.unity.ml-agents/Runtime/Academy.cs


/// <term>1.2.0</term>
/// <description>Support compression mapping for stacked compressed observations.</description>
/// </item>
/// <item>
/// <term>1.3.0</term>
/// <description>Support hybrid action spaces.</description>
/// </item>
const string k_ApiVersion = "1.2.0";
const string k_ApiVersion = "1.3.0";
/// <summary>
/// Unity package version of com.unity.ml-agents.

2
com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs


BaseRLCapabilities = proto.BaseRLCapabilities,
ConcatenatedPngObservations = proto.ConcatenatedPngObservations,
CompressedChannelMapping = proto.CompressedChannelMapping,
HybridActions = proto.HybridActions,
};
}

BaseRLCapabilities = rlCaps.BaseRLCapabilities,
ConcatenatedPngObservations = rlCaps.ConcatenatedPngObservations,
CompressedChannelMapping = rlCaps.CompressedChannelMapping,
HybridActions = rlCaps.HybridActions,
};
}

5
com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs


public bool BaseRLCapabilities;
public bool ConcatenatedPngObservations;
public bool CompressedChannelMapping;
public bool HybridActions;
public UnityRLCapabilities(bool baseRlCapabilities = true, bool concatenatedPngObservations = true, bool compressedChannelMapping = true)
public UnityRLCapabilities(bool baseRlCapabilities = true, bool concatenatedPngObservations = true,
bool compressedChannelMapping = true, bool hybridActions = true)
HybridActions = hybridActions;
}
/// <summary>

44
com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs


byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjVtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2NhcGFiaWxp",
"dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMifQoYVW5pdHlSTENh",
"cGFiaWxpdGllc1Byb3RvEhoKEmJhc2VSTENhcGFiaWxpdGllcxgBIAEoCBIj",
"Chtjb25jYXRlbmF0ZWRQbmdPYnNlcnZhdGlvbnMYAiABKAgSIAoYY29tcHJl",
"c3NlZENoYW5uZWxNYXBwaW5nGAMgASgIQiWqAiJVbml0eS5NTEFnZW50cy5D",
"b21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
"dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMilAEKGFVuaXR5UkxD",
"YXBhYmlsaXRpZXNQcm90bxIaChJiYXNlUkxDYXBhYmlsaXRpZXMYASABKAgS",
"IwobY29uY2F0ZW5hdGVkUG5nT2JzZXJ2YXRpb25zGAIgASgIEiAKGGNvbXBy",
"ZXNzZWRDaGFubmVsTWFwcGluZxgDIAEoCBIVCg1oeWJyaWRBY3Rpb25zGAQg",
"ASgIQiWqAiJVbml0eS5NTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZw",
"cm90bzM="));
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations", "CompressedChannelMapping" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations", "CompressedChannelMapping", "HybridActions" }, null, null, null)
}));
}
#endregion

baseRLCapabilities_ = other.baseRLCapabilities_;
concatenatedPngObservations_ = other.concatenatedPngObservations_;
compressedChannelMapping_ = other.compressedChannelMapping_;
hybridActions_ = other.hybridActions_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

}
}
/// <summary>Field number for the "hybridActions" field.</summary>
public const int HybridActionsFieldNumber = 4;
private bool hybridActions_;
/// <summary>
/// support for mixed (discrete + continuous) actions
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool HybridActions {
get { return hybridActions_; }
set {
hybridActions_ = value;
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as UnityRLCapabilitiesProto);

if (BaseRLCapabilities != other.BaseRLCapabilities) return false;
if (ConcatenatedPngObservations != other.ConcatenatedPngObservations) return false;
if (CompressedChannelMapping != other.CompressedChannelMapping) return false;
if (HybridActions != other.HybridActions) return false;
return Equals(_unknownFields, other._unknownFields);
}

if (BaseRLCapabilities != false) hash ^= BaseRLCapabilities.GetHashCode();
if (ConcatenatedPngObservations != false) hash ^= ConcatenatedPngObservations.GetHashCode();
if (CompressedChannelMapping != false) hash ^= CompressedChannelMapping.GetHashCode();
if (HybridActions != false) hash ^= HybridActions.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}

output.WriteRawTag(24);
output.WriteBool(CompressedChannelMapping);
}
if (HybridActions != false) {
output.WriteRawTag(32);
output.WriteBool(HybridActions);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}

size += 1 + 1;
}
if (CompressedChannelMapping != false) {
size += 1 + 1;
}
if (HybridActions != false) {
size += 1 + 1;
}
if (_unknownFields != null) {

if (other.CompressedChannelMapping != false) {
CompressedChannelMapping = other.CompressedChannelMapping;
}
if (other.HybridActions != false) {
HybridActions = other.HybridActions;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

}
case 24: {
CompressedChannelMapping = input.ReadBool();
break;
}
case 32: {
HybridActions = input.ReadBool();
break;
}
}

13
ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py


name='mlagents_envs/communicator_objects/capabilities.proto',
package='communicator_objects',
syntax='proto3',
serialized_pb=_b('\n5mlagents_envs/communicator_objects/capabilities.proto\x12\x14\x63ommunicator_objects\"}\n\x18UnityRLCapabilitiesProto\x12\x1a\n\x12\x62\x61seRLCapabilities\x18\x01 \x01(\x08\x12#\n\x1b\x63oncatenatedPngObservations\x18\x02 \x01(\x08\x12 \n\x18\x63ompressedChannelMapping\x18\x03 \x01(\x08\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
serialized_pb=_b('\n5mlagents_envs/communicator_objects/capabilities.proto\x12\x14\x63ommunicator_objects\"\x94\x01\n\x18UnityRLCapabilitiesProto\x12\x1a\n\x12\x62\x61seRLCapabilities\x18\x01 \x01(\x08\x12#\n\x1b\x63oncatenatedPngObservations\x18\x02 \x01(\x08\x12 \n\x18\x63ompressedChannelMapping\x18\x03 \x01(\x08\x12\x15\n\rhybridActions\x18\x04 \x01(\x08\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
)

message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='hybridActions', full_name='communicator_objects.UnityRLCapabilitiesProto.hybridActions', index=3,
number=4, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
],
extensions=[
],

extension_ranges=[],
oneofs=[
],
serialized_start=79,
serialized_end=204,
serialized_start=80,
serialized_end=228,
)
DESCRIPTOR.message_types_by_name['UnityRLCapabilitiesProto'] = _UNITYRLCAPABILITIESPROTO

6
ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi


baseRLCapabilities = ... # type: builtin___bool
concatenatedPngObservations = ... # type: builtin___bool
compressedChannelMapping = ... # type: builtin___bool
hybridActions = ... # type: builtin___bool
def __init__(self,
*,

hybridActions : typing___Optional[builtin___bool] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> UnityRLCapabilitiesProto: ...

def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",u"compressedChannelMapping",u"concatenatedPngObservations"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",u"compressedChannelMapping",u"concatenatedPngObservations",u"hybridActions"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",b"baseRLCapabilities",u"compressedChannelMapping",b"compressedChannelMapping",u"concatenatedPngObservations",b"concatenatedPngObservations"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",b"baseRLCapabilities",u"compressedChannelMapping",b"compressedChannelMapping",u"concatenatedPngObservations",b"concatenatedPngObservations",u"hybridActions",b"hybridActions"]) -> None: ...

4
ml-agents-envs/mlagents_envs/environment.py


# * 1.0.0 - initial version
# * 1.1.0 - support concatenated PNGs for compressed observations.
# * 1.2.0 - support compression mapping for stacked compressed observations.
API_VERSION = "1.2.0"
# * 1.3.0 - support hybrid action spaces.
API_VERSION = "1.3.0"
# Default port that the editor listens on. If an environment executable
# isn't specified, this port will be used.

capabilities.baseRLCapabilities = True
capabilities.concatenatedPngObservations = True
capabilities.compressedChannelMapping = True
capabilities.hybridActions = True
return capabilities
@staticmethod

3
protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto


// compression mapping for stacking compressed observations.
bool compressedChannelMapping = 3;
// support for hybrid action spaces (discrete + continuous)
bool hybridActions = 4;
}
正在加载...
取消
保存