
Add multiAgentGroup capabilities flag (#5096)

* Add multiAgentGroup capabilities flag

* Add proto

* Fix compiler error

* Add warning for multiagent group

* Add comment

* Fix spelling mistake
Branch: develop/input-actuator-tanks
GitHub · 4 years ago
Commit 85369e5b
8 files changed, 82 insertions(+), 11 deletions(-)
  1. com.unity.ml-agents/Runtime/Academy.cs (2 changes)
  2. com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs (23 changes)
  3. com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs (5 changes)
  4. com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs (40 changes)
  5. ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py (11 changes)
  6. ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi (6 changes)
  7. ml-agents-envs/mlagents_envs/environment.py (3 changes)
  8. protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto (3 changes)

com.unity.ml-agents/Runtime/Academy.cs (2 changes)


/// </item>
/// <item>
/// <term>1.5.0</term>
-/// <description>Support variable length observation training.</description>
+/// <description>Support variable length observation training and multi-agent groups.</description>
/// </item>
/// </list>
/// </remarks>

com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs (23 changes)


{
#region AgentInfo
+/// <summary>
+/// Static flag to make sure that we only fire the warning once.
+/// </summary>
+private static bool s_HaveWarnedTrainerCapabilitiesAgentGroup = false;
/// <summary>
/// Converts a AgentInfo to a protobuf generated AgentInfoActionPairProto
/// </summary>
-/// <returns>The protobuf version of the AgentInfoActionPairProto.</returns>
+/// <returns>The protobuf version of the AgentInfo.</returns>
public static AgentInfoProto ToAgentInfoProto(this AgentInfo ai)
{
+if (ai.groupId > 0)
+{
+var trainerCanHandle = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.MultiAgentGroups;
+if (!trainerCanHandle)
+{
+if (!s_HaveWarnedTrainerCapabilitiesAgentGroup)
+{
+Debug.LogWarning(
+"Attached trainer doesn't support Multi Agent Groups; group rewards will be ignored. " +
+"Please find the versions that work best together from our release page: " +
+"https://github.com/Unity-Technologies/ml-agents/releases"
+);
+s_HaveWarnedTrainerCapabilitiesAgentGroup = true;
+}
+}
+}
var agentInfoProto = new AgentInfoProto
{
Reward = ai.reward,

HybridActions = proto.HybridActions,
TrainingAnalytics = proto.TrainingAnalytics,
VariableLengthObservation = proto.VariableLengthObservation,
+MultiAgentGroups = proto.MultiAgentGroups,
};
}

HybridActions = rlCaps.HybridActions,
TrainingAnalytics = rlCaps.TrainingAnalytics,
VariableLengthObservation = rlCaps.VariableLengthObservation,
+MultiAgentGroups = rlCaps.MultiAgentGroups,
};
}
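
The warning above only fires when an AgentInfo carries a non-zero groupId, so plain single-agent scenes never take this path. For context, a minimal sketch of the Unity-side setup that produces non-zero group IDs; SimpleMultiAgentGroup is the public helper in the same package, but the component below is a hypothetical illustration, not part of this diff:

using UnityEngine;
using Unity.MLAgents;

public class TeamRewards : MonoBehaviour  // hypothetical example component
{
    SimpleMultiAgentGroup m_Group;

    void Awake()
    {
        m_Group = new SimpleMultiAgentGroup();
        // Registration assigns the shared, non-zero groupId that
        // ToAgentInfoProto inspects before warning.
        foreach (var agent in GetComponentsInChildren<Agent>())
        {
            m_Group.RegisterAgent(agent);
        }
    }

    public void OnTeamScored()
    {
        // These group rewards are exactly what an older trainer would drop.
        m_Group.AddGroupReward(1.0f);
        m_Group.EndGroupEpisode();
    }
}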

com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs (5 changes)


public bool HybridActions;
public bool TrainingAnalytics;
public bool VariableLengthObservation;
+public bool MultiAgentGroups;
/// <summary>
/// A class holding the capabilities flags for Reinforcement Learning across C# and the Trainer codebase. This

bool compressedChannelMapping = true,
bool hybridActions = true,
bool trainingAnalytics = true,
-bool variableLengthObservation = true)
+bool variableLengthObservation = true,
+bool multiAgentGroups = true)
{
BaseRLCapabilities = baseRlCapabilities;
ConcatenatedPngObservations = concatenatedPngObservations;

VariableLengthObservation = variableLengthObservation;
+MultiAgentGroups = multiAgentGroups;
}
/// <summary>
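
Because the new constructor parameter defaults to true, existing call sites compile unchanged while the C# package now advertises the capability. A hedged round-trip sketch through the ToProto / ToRLCapabilities extension methods shown in GrpcExtensions.cs above, assuming the earlier constructor parameters keep their defaults:

// Advertise everything except multi-agent groups, e.g. to exercise the warning path.
var caps = new UnityRLCapabilities(multiAgentGroups: false);
var proto = caps.ToProto();                  // UnityRLCapabilitiesProto, MultiAgentGroups == false
var roundTripped = proto.ToRLCapabilities(); // back to UnityRLCapabilities
Debug.Assert(!roundTripped.MultiAgentGroups);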

com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs (40 changes)


byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjVtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2NhcGFiaWxp",
-"dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMi0gEKGFVuaXR5UkxD",
+"dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMi7AEKGFVuaXR5UkxD",
-"Z3RoT2JzZXJ2YXRpb24YBiABKAhCJaoCIlVuaXR5Lk1MQWdlbnRzLkNvbW11",
-"bmljYXRvck9iamVjdHNiBnByb3RvMw=="));
+"Z3RoT2JzZXJ2YXRpb24YBiABKAgSGAoQbXVsdGlBZ2VudEdyb3VwcxgHIAEo",
+"CEIlqgIiVW5pdHkuTUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0c2IGcHJv",
+"dG8z"));
-new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations", "CompressedChannelMapping", "HybridActions", "TrainingAnalytics", "VariableLengthObservation" }, null, null, null)
+new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations", "CompressedChannelMapping", "HybridActions", "TrainingAnalytics", "VariableLengthObservation", "MultiAgentGroups" }, null, null, null)
}));
}
#endregion

hybridActions_ = other.hybridActions_;
trainingAnalytics_ = other.trainingAnalytics_;
variableLengthObservation_ = other.variableLengthObservation_;
+multiAgentGroups_ = other.multiAgentGroups_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

}
}
+/// <summary>Field number for the "multiAgentGroups" field.</summary>
+public const int MultiAgentGroupsFieldNumber = 7;
+private bool multiAgentGroups_;
+/// <summary>
+/// Support for multi agent groups and group rewards
+/// </summary>
+[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+public bool MultiAgentGroups {
+get { return multiAgentGroups_; }
+set {
+multiAgentGroups_ = value;
+}
+}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as UnityRLCapabilitiesProto);

if (HybridActions != other.HybridActions) return false;
if (TrainingAnalytics != other.TrainingAnalytics) return false;
if (VariableLengthObservation != other.VariableLengthObservation) return false;
+if (MultiAgentGroups != other.MultiAgentGroups) return false;
return Equals(_unknownFields, other._unknownFields);
}

if (HybridActions != false) hash ^= HybridActions.GetHashCode();
if (TrainingAnalytics != false) hash ^= TrainingAnalytics.GetHashCode();
if (VariableLengthObservation != false) hash ^= VariableLengthObservation.GetHashCode();
+if (MultiAgentGroups != false) hash ^= MultiAgentGroups.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}

if (VariableLengthObservation != false) {
output.WriteRawTag(48);
output.WriteBool(VariableLengthObservation);
}
+if (MultiAgentGroups != false) {
+output.WriteRawTag(56);
+output.WriteBool(MultiAgentGroups);
+}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);

if (VariableLengthObservation != false) {
size += 1 + 1;
}
+if (MultiAgentGroups != false) {
+size += 1 + 1;
+}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}

if (other.VariableLengthObservation != false) {
VariableLengthObservation = other.VariableLengthObservation;
}
+if (other.MultiAgentGroups != false) {
+MultiAgentGroups = other.MultiAgentGroups;
+}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

}
case 48: {
VariableLengthObservation = input.ReadBool();
break;
}
+case 56: {
+MultiAgentGroups = input.ReadBool();
+break;
+}
}
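
The magic numbers 48 and 56 in the generated reader and writer are protobuf wire tags: tag = (field_number << 3) | wire_type, and bool fields use varint wire type 0, so existing field 6 (variableLengthObservation) gets tag 48 and the new field 7 (multiAgentGroups) gets tag 56. A small self-contained check of that arithmetic:

using System;

class WireTagCheck
{
    const int WireTypeVarint = 0; // bools are varint-encoded on the wire

    static int Tag(int fieldNumber) => (fieldNumber << 3) | WireTypeVarint;

    static void Main()
    {
        Console.WriteLine(Tag(6)); // 48 -> variableLengthObservation
        Console.WriteLine(Tag(7)); // 56 -> multiAgentGroups
    }
}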

ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py (11 changes)


name='mlagents_envs/communicator_objects/capabilities.proto',
package='communicator_objects',
syntax='proto3',
-serialized_pb=_b('\n5mlagents_envs/communicator_objects/capabilities.proto\x12\x14\x63ommunicator_objects\"\xd2\x01\n\x18UnityRLCapabilitiesProto\x12\x1a\n\x12\x62\x61seRLCapabilities\x18\x01 \x01(\x08\x12#\n\x1b\x63oncatenatedPngObservations\x18\x02 \x01(\x08\x12 \n\x18\x63ompressedChannelMapping\x18\x03 \x01(\x08\x12\x15\n\rhybridActions\x18\x04 \x01(\x08\x12\x19\n\x11trainingAnalytics\x18\x05 \x01(\x08\x12!\n\x19variableLengthObservation\x18\x06 \x01(\x08\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
+serialized_pb=_b('\n5mlagents_envs/communicator_objects/capabilities.proto\x12\x14\x63ommunicator_objects\"\xec\x01\n\x18UnityRLCapabilitiesProto\x12\x1a\n\x12\x62\x61seRLCapabilities\x18\x01 \x01(\x08\x12#\n\x1b\x63oncatenatedPngObservations\x18\x02 \x01(\x08\x12 \n\x18\x63ompressedChannelMapping\x18\x03 \x01(\x08\x12\x15\n\rhybridActions\x18\x04 \x01(\x08\x12\x19\n\x11trainingAnalytics\x18\x05 \x01(\x08\x12!\n\x19variableLengthObservation\x18\x06 \x01(\x08\x12\x18\n\x10multiAgentGroups\x18\x07 \x01(\x08\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
)

message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
+_descriptor.FieldDescriptor(
+name='multiAgentGroups', full_name='communicator_objects.UnityRLCapabilitiesProto.multiAgentGroups', index=6,
+number=7, type=8, cpp_type=7, label=1,
+has_default_value=False, default_value=False,
+message_type=None, enum_type=None, containing_type=None,
+is_extension=False, extension_scope=None,
+options=None, file=DESCRIPTOR),
],
extensions=[
],

oneofs=[
],
serialized_start=80,
-serialized_end=290,
+serialized_end=316,
)
DESCRIPTOR.message_types_by_name['UnityRLCapabilitiesProto'] = _UNITYRLCAPABILITIESPROTO
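
The opaque edits to serialized_pb are internally consistent: the bytes after the message name (\xd2\x01 before, \xec\x01 after) are a base-128 varint length prefix for the UnityRLCapabilitiesProto descriptor, decoding to 210 and 236 respectively, and that 26-byte growth matches serialized_end moving from 290 to 316. A short decoder to verify the arithmetic:

using System;

class VarintCheck
{
    // Decode protobuf's little-endian base-128 varint.
    static int DecodeVarint(byte[] bytes)
    {
        int value = 0, shift = 0;
        foreach (var b in bytes)
        {
            value |= (b & 0x7F) << shift;
            if ((b & 0x80) == 0) break;
            shift += 7;
        }
        return value;
    }

    static void Main()
    {
        Console.WriteLine(DecodeVarint(new byte[] { 0xD2, 0x01 })); // 210: old descriptor length
        Console.WriteLine(DecodeVarint(new byte[] { 0xEC, 0x01 })); // 236: new length, +26 bytes
    }
}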

ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi (6 changes)


hybridActions = ... # type: builtin___bool
trainingAnalytics = ... # type: builtin___bool
variableLengthObservation = ... # type: builtin___bool
+multiAgentGroups = ... # type: builtin___bool
def __init__(self,
*,

hybridActions : typing___Optional[builtin___bool] = None,
trainingAnalytics : typing___Optional[builtin___bool] = None,
variableLengthObservation : typing___Optional[builtin___bool] = None,
+multiAgentGroups : typing___Optional[builtin___bool] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> UnityRLCapabilitiesProto: ...

-def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",u"compressedChannelMapping",u"concatenatedPngObservations",u"hybridActions",u"trainingAnalytics",u"variableLengthObservation"]) -> None: ...
+def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",u"compressedChannelMapping",u"concatenatedPngObservations",u"hybridActions",u"multiAgentGroups",u"trainingAnalytics",u"variableLengthObservation"]) -> None: ...
-def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",b"baseRLCapabilities",u"compressedChannelMapping",b"compressedChannelMapping",u"concatenatedPngObservations",b"concatenatedPngObservations",u"hybridActions",b"hybridActions",u"trainingAnalytics",b"trainingAnalytics",u"variableLengthObservation",b"variableLengthObservation"]) -> None: ...
+def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",b"baseRLCapabilities",u"compressedChannelMapping",b"compressedChannelMapping",u"concatenatedPngObservations",b"concatenatedPngObservations",u"hybridActions",b"hybridActions",u"multiAgentGroups",b"multiAgentGroups",u"trainingAnalytics",b"trainingAnalytics",u"variableLengthObservation",b"variableLengthObservation"]) -> None: ...

ml-agents-envs/mlagents_envs/environment.py (3 changes)


# * 1.2.0 - support compression mapping for stacked compressed observations.
# * 1.3.0 - support action spaces with both continuous and discrete actions.
# * 1.4.0 - support training analytics sent from python trainer to the editor.
-# * 1.5.0 - support variable length observation training.
+# * 1.5.0 - support variable length observation training and multi-agent groups.
API_VERSION = "1.5.0"
# Default port that the editor listens on. If an environment executable

capabilities.hybridActions = True
capabilities.trainingAnalytics = True
capabilities.variableLengthObservation = True
+capabilities.multiAgentGroups = True
return capabilities
@staticmethod
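
Both ends now default the flag to true, so a matched C#/Python pair negotiates multi-agent group support automatically; only an older trainer leaves it unset. On the Unity side, the compatibility test is the same expression used in GrpcExtensions.cs above:

// Null means no external trainer is attached (inference/heuristic mode): treat as capable.
var trainerCaps = Academy.Instance.TrainerCapabilities;
bool trainerCanHandle = trainerCaps == null || trainerCaps.MultiAgentGroups;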

protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto (3 changes)


// Support for variable length observations of rank 2
bool variableLengthObservation = 6;
+// Support for multi agent groups and group rewards
+bool multiAgentGroups = 7;
}
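
With field 7 declared, the flag survives a wire round trip, and an older peer that doesn't know the field files it away as unknown instead of failing, which is what makes the addition backward compatible. A hedged sketch using the generated C# class and standard Google.Protobuf helpers (ToByteArray, Parser.ParseFrom):

using Google.Protobuf;
using Unity.MLAgents.CommunicatorObjects;

class RoundTrip
{
    static void Main()
    {
        var caps = new UnityRLCapabilitiesProto { MultiAgentGroups = true };
        byte[] wire = caps.ToByteArray();  // encodes as 0x38 0x01: tag 56, value true
        var parsed = UnityRLCapabilitiesProto.Parser.ParseFrom(wire);
        System.Console.WriteLine(parsed.MultiAgentGroups); // True
        // An older generated class without field 7 parses 0x38 0x01 into its
        // UnknownFieldSet, so nothing breaks; the flag is just invisible to it.
    }
}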