浏览代码

Increase communicator version for concatenated PNGs. (#4462)

/MLA-1734-demo-provider
GitHub 4 年前
当前提交
9d840f22
共有 10 个文件被更改,包括 196 次插入17 次删除
  1. 3
      com.unity.ml-agents/CHANGELOG.md
  2. 15
      com.unity.ml-agents/Runtime/Academy.cs
  3. 29
      com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
  4. 10
      com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs
  5. 40
      com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs
  6. 90
      com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs
  7. 11
      ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py
  8. 6
      ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi
  9. 6
      ml-agents-envs/mlagents_envs/environment.py
  10. 3
      protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto

3
com.unity.ml-agents/CHANGELOG.md


- Compressed visual observations with >3 channels are now supported. In
ISensor.GetCompressedObservation(), this can be done by writing 3 channels at a
time to a PNG and concatenating the resulting bytes. (#4399)
- The Communication API was changed to 1.1.0 to indicate support for concatenated PNGs
(see above). Newer versions of the package that wish to make use of this will also need
a compatible version of the trainer.
- A CNN (`vis_encode_type: match3`) for smaller grids, e.g. board games, has been added.
(#4434)
- You can now again specify a default configuration for your behaviors. Specify `default_settings` in

15
com.unity.ml-agents/Runtime/Academy.cs


/// functionality will work as long as the major versions match.
/// This should be changed whenever a change is made to the communication protocol.
/// </summary>
const string k_ApiVersion = "1.0.0";
/// <remarks>
/// History:
/// <list type="bullet">
/// <item>
/// <term>1.0.0</term>
/// <description>Initial version</description>
/// </item>
/// <item>
/// <term>1.1.0</term>
/// <description>Support concatenated PNGs for compressed observations.</description>
/// </item>
/// </list>
/// </remarks>
const string k_ApiVersion = "1.1.0";
/// <summary>
/// Unity package version of com.unity.ml-agents.

29
com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs


}
/// <summary>
/// Static flag to make sure that we only fire the warning once.
/// </summary>
private static bool s_HaveWarnedAboutTrainerCapabilities = false;
/// <summary>
/// Generate an ObservationProto for the sensor using the provided ObservationWriter.
/// This is equivalent to producing an Observation and calling Observation.ToProto(),
/// but avoids some intermediate memory allocations.

{
var shape = sensor.GetObservationShape();
ObservationProto observationProto = null;
if (sensor.GetCompressionType() == SensorCompressionType.None)
var compressionType = sensor.GetCompressionType();
// Check capabilities if we need to concatenate PNGs
if (compressionType == SensorCompressionType.PNG && shape.Length == 3 && shape[2] > 3)
{
var trainerCanHandle = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.ConcatenatedPngObservations;
if (!trainerCanHandle)
{
if (!s_HaveWarnedAboutTrainerCapabilities)
{
Debug.LogWarning($"Attached trainer doesn't support multiple PNGs. Switching to uncompressed observations for sensor {sensor.GetName()}.");
s_HaveWarnedAboutTrainerCapabilities = true;
}
compressionType = SensorCompressionType.None;
}
}
if (compressionType == SensorCompressionType.None)
{
var numFloats = sensor.ObservationSize();
var floatDataProto = new ObservationProto.Types.FloatData();

{
return new UnityRLCapabilities
{
m_BaseRLCapabilities = proto.BaseRLCapabilities
BaseRLCapabilities = proto.BaseRLCapabilities,
ConcatenatedPngObservations = proto.ConcatenatedPngObservations
};
}

{
BaseRLCapabilities = rlCaps.m_BaseRLCapabilities
BaseRLCapabilities = rlCaps.BaseRLCapabilities,
ConcatenatedPngObservations = rlCaps.ConcatenatedPngObservations,
};
}
}

10
com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs


{
internal class UnityRLCapabilities
{
internal bool m_BaseRLCapabilities;
public bool BaseRLCapabilities;
public bool ConcatenatedPngObservations;
public UnityRLCapabilities(bool baseRlCapabilities = true)
public UnityRLCapabilities(bool baseRlCapabilities = true, bool concatenatedPngObservations = true)
m_BaseRLCapabilities = baseRlCapabilities;
BaseRLCapabilities = baseRlCapabilities;
ConcatenatedPngObservations = concatenatedPngObservations;
}
/// <summary>

/// <returns></returns>
public bool WarnOnPythonMissingBaseRLCapabilities()
{
if (m_BaseRLCapabilities)
if (BaseRLCapabilities)
{
return false;
}

40
com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs


byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjVtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2NhcGFiaWxp",
"dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMiNgoYVW5pdHlSTENh",
"cGFiaWxpdGllc1Byb3RvEhoKEmJhc2VSTENhcGFiaWxpdGllcxgBIAEoCEIl",
"qgIiVW5pdHkuTUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0c2IGcHJvdG8z"));
"dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMiWwoYVW5pdHlSTENh",
"cGFiaWxpdGllc1Byb3RvEhoKEmJhc2VSTENhcGFiaWxpdGllcxgBIAEoCBIj",
"Chtjb25jYXRlbmF0ZWRQbmdPYnNlcnZhdGlvbnMYAiABKAhCJaoCIlVuaXR5",
"Lk1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw=="));
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations" }, null, null, null)
}));
}
#endregion

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public UnityRLCapabilitiesProto(UnityRLCapabilitiesProto other) : this() {
baseRLCapabilities_ = other.baseRLCapabilities_;
concatenatedPngObservations_ = other.concatenatedPngObservations_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

}
}
/// <summary>Field number for the "concatenatedPngObservations" field.</summary>
public const int ConcatenatedPngObservationsFieldNumber = 2;
private bool concatenatedPngObservations_;
/// <summary>
/// concatenated PNG files for compressed visual observations with >3 channels.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool ConcatenatedPngObservations {
get { return concatenatedPngObservations_; }
set {
concatenatedPngObservations_ = value;
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as UnityRLCapabilitiesProto);

return true;
}
if (BaseRLCapabilities != other.BaseRLCapabilities) return false;
if (ConcatenatedPngObservations != other.ConcatenatedPngObservations) return false;
return Equals(_unknownFields, other._unknownFields);
}

if (BaseRLCapabilities != false) hash ^= BaseRLCapabilities.GetHashCode();
if (ConcatenatedPngObservations != false) hash ^= ConcatenatedPngObservations.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}

if (BaseRLCapabilities != false) {
output.WriteRawTag(8);
output.WriteBool(BaseRLCapabilities);
}
if (ConcatenatedPngObservations != false) {
output.WriteRawTag(16);
output.WriteBool(ConcatenatedPngObservations);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);

if (BaseRLCapabilities != false) {
size += 1 + 1;
}
if (ConcatenatedPngObservations != false) {
size += 1 + 1;
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}

if (other.BaseRLCapabilities != false) {
BaseRLCapabilities = other.BaseRLCapabilities;
}
if (other.ConcatenatedPngObservations != false) {
ConcatenatedPngObservations = other.ConcatenatedPngObservations;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

break;
case 8: {
BaseRLCapabilities = input.ReadBool();
break;
}
case 16: {
ConcatenatedPngObservations = input.ReadBool();
break;
}
}

90
com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs


using Unity.MLAgents.Policies;
using Unity.MLAgents.Demonstrations;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Tests
{

// Should be able to convert a default instance to proto.
var demoMetaData = new DemonstrationMetaData();
demoMetaData.ToProto();
}
/// <summary>
/// Minimal ISensor stub for tests. The shape and compression type are
/// settable so individual test cases can exercise different code paths;
/// the compressed payload is a fixed two-byte marker.
/// </summary>
class DummySensor : ISensor
{
    // Observation shape reported to the proto-generation code.
    public int[] Shape;
    // Compression type reported to the proto-generation code.
    public SensorCompressionType CompressionType;

    internal DummySensor()
    {
    }

    public int[] GetObservationShape() => Shape;

    // Writes nothing; the tests only care about the proto metadata.
    public int Write(ObservationWriter writer) => 0;

    // Fixed sentinel bytes standing in for a real PNG payload.
    public byte[] GetCompressedObservation() => new byte[] { 13, 37 };

    public void Update() { }

    public void Reset() { }

    public SensorCompressionType GetCompressionType() => CompressionType;

    public string GetName() => "Dummy";
}
// Verifies that GetObservationProto honors the trainer's declared
// ConcatenatedPngObservations capability: PNG sensors with more than
// 3 channels fall back to uncompressed floats when the trainer does
// not support concatenated PNGs.
[Test]
public void TestGetObservationProtoCapabilities()
{
// Shape, compression type, concatenatedPngObservations, expect throw
var variants = new[]
{
// Vector observations
(new[] {3}, SensorCompressionType.None, false, false),
// Uncompressed floats
(new[] {4, 4, 3}, SensorCompressionType.None, false, false),
// Compressed floats, 3 channels
(new[] {4, 4, 3}, SensorCompressionType.PNG, false, true),
// Compressed floats, >3 channels
(new[] {4, 4, 4}, SensorCompressionType.PNG, false, false), // Unsupported - results in uncompressed
(new[] {4, 4, 4}, SensorCompressionType.PNG, true, true), // Supported compressed
};
foreach (var (shape, compressionType, supportsMultiPngObs, expectCompressed) in variants)
{
var dummySensor = new DummySensor();
var obsWriter = new ObservationWriter();
dummySensor.Shape = shape;
dummySensor.CompressionType = compressionType;
obsWriter.SetTarget(new float[128], shape, 0);
// Install the simulated trainer capabilities before generating the proto,
// since GetObservationProto consults Academy.Instance.TrainerCapabilities.
var caps = new UnityRLCapabilities
{
ConcatenatedPngObservations = supportsMultiPngObs
};
Academy.Instance.TrainerCapabilities = caps;
var obsProto = dummySensor.GetObservationProto(obsWriter);
if (expectCompressed)
{
// Compressed path: bytes present, no float data.
Assert.Greater(obsProto.CompressedData.Length, 0);
Assert.AreEqual(obsProto.FloatData, null);
}
else
{
// Uncompressed path: float data present, no compressed bytes.
Assert.Greater(obsProto.FloatData.Data.Count, 0);
Assert.AreEqual(obsProto.CompressedData.Length, 0);
}
}
}
}
}

11
ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py


name='mlagents_envs/communicator_objects/capabilities.proto',
package='communicator_objects',
syntax='proto3',
serialized_pb=_b('\n5mlagents_envs/communicator_objects/capabilities.proto\x12\x14\x63ommunicator_objects\"6\n\x18UnityRLCapabilitiesProto\x12\x1a\n\x12\x62\x61seRLCapabilities\x18\x01 \x01(\x08\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
serialized_pb=_b('\n5mlagents_envs/communicator_objects/capabilities.proto\x12\x14\x63ommunicator_objects\"[\n\x18UnityRLCapabilitiesProto\x12\x1a\n\x12\x62\x61seRLCapabilities\x18\x01 \x01(\x08\x12#\n\x1b\x63oncatenatedPngObservations\x18\x02 \x01(\x08\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
)

message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='concatenatedPngObservations', full_name='communicator_objects.UnityRLCapabilitiesProto.concatenatedPngObservations', index=1,
number=2, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
],
extensions=[
],

oneofs=[
],
serialized_start=79,
serialized_end=133,
serialized_end=170,
)
DESCRIPTOR.message_types_by_name['UnityRLCapabilitiesProto'] = _UNITYRLCAPABILITIESPROTO

6
ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi


class UnityRLCapabilitiesProto(google___protobuf___message___Message):
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
baseRLCapabilities = ... # type: builtin___bool
concatenatedPngObservations = ... # type: builtin___bool
concatenatedPngObservations : typing___Optional[builtin___bool] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> UnityRLCapabilitiesProto: ...

def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",u"concatenatedPngObservations"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",b"baseRLCapabilities"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",b"baseRLCapabilities",u"concatenatedPngObservations",b"concatenatedPngObservations"]) -> None: ...

6
ml-agents-envs/mlagents_envs/environment.py


# We follow semantic versioning on the communication version, so existing
# functionality will work as long as the major versions match.
# This should be changed whenever a change is made to the communication protocol.
API_VERSION = "1.0.0"
# Revision history:
# * 1.0.0 - initial version
# * 1.1.0 - support concatenated PNGs for compressed observations.
API_VERSION = "1.1.0"
# Default port that the editor listens on. If an environment executable
# isn't specified, this port will be used.

def _get_capabilities_proto() -> UnityRLCapabilitiesProto:
    """Build the capabilities message this Python client advertises to Unity.

    Both the base RL capabilities and concatenated-PNG support (for
    compressed visual observations with >3 channels) are enabled.

    :return: A populated UnityRLCapabilitiesProto.
    """
    capabilities = UnityRLCapabilitiesProto()
    capabilities.baseRLCapabilities = True
    capabilities.concatenatedPngObservations = True
    return capabilities
@staticmethod

3
protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto


// Feature flags exchanged between Unity and the trainer so each side can
// adapt to what the other supports.
message UnityRLCapabilitiesProto {
// These are the 1.0 capabilities.
bool baseRLCapabilities = 1;
// concatenated PNG files for compressed visual observations with >3 channels.
bool concatenatedPngObservations = 2;
}
正在加载...
取消
保存