浏览代码

Develop SideChannel new api (#3425)

* Make ChannelId a property and renamed ReservedChannelId

* Changes on the Python side for consistency

* Modified the tutorial appropriately

* fixing bugs

* Update ml-agents-envs/mlagents_envs/environment.py

Co-Authored-By: Chris Elion <chris.elion@unity3d.com>

* Update com.unity.ml-agents/Runtime/Grpc/RpcCommunicator.cs

Co-Authored-By: Chris Elion <chris.elion@unity3d.com>

* Addressing comments

* Update docs/Python-API.md

Co-Authored-By: Chris Elion <chris.elion@unity3d.com>

* Added a Utils class on the side channel (#3447)

- No change in user facing API
 - Simplifies the code in the side channel implementations as it makes it easier to check if a side channel id is within ranges
 - No changes to tests
 - No changes to Documentation

* Simplifying

* Fixing a bug

* Replace the int ChannelId with a GUID/UUID ChannelId (#3454)

* renaming channel_type to channel_id

* Making the constant GUID const...
/asymm-envs
GitHub 5 年前
当前提交
c38dd44c
共有 14 个文件被更改,包括 136 次插入141 次删除
  1. 32
      com.unity.ml-agents/Runtime/Grpc/RpcCommunicator.cs
  2. 6
      com.unity.ml-agents/Runtime/SideChannel/EngineConfigurationChannel.cs
  3. 11
      com.unity.ml-agents/Runtime/SideChannel/FloatPropertiesChannel.cs
  4. 12
      com.unity.ml-agents/Runtime/SideChannel/RawBytesChannel.cs
  5. 20
      com.unity.ml-agents/Runtime/SideChannel/SideChannel.cs
  6. 20
      com.unity.ml-agents/Tests/Editor/SideChannelTests.cs
  7. 35
      docs/Python-API.md
  8. 36
      ml-agents-envs/mlagents_envs/environment.py
  9. 8
      ml-agents-envs/mlagents_envs/exception.py
  10. 8
      ml-agents-envs/mlagents_envs/side_channel/engine_configuration_channel.py
  11. 13
      ml-agents-envs/mlagents_envs/side_channel/float_properties_channel.py
  12. 16
      ml-agents-envs/mlagents_envs/side_channel/raw_bytes_channel.py
  13. 24
      ml-agents-envs/mlagents_envs/side_channel/side_channel.py
  14. 36
      ml-agents-envs/mlagents_envs/tests/test_side_channel.py

32
com.unity.ml-agents/Runtime/Grpc/RpcCommunicator.cs


/// The communicator parameters sent at construction
CommunicatorInitParameters m_CommunicatorInitParameters;
Dictionary<int, SideChannel> m_SideChannels = new Dictionary<int, SideChannel>();
Dictionary<Guid, SideChannel> m_SideChannels = new Dictionary<Guid, SideChannel>();
/// <summary>
/// Initializes a new instance of the RPCCommunicator class.

/// <param name="sideChannel"> The side channel to be registered.</param>
public void RegisterSideChannel(SideChannel sideChannel)
{
var channelType = sideChannel.ChannelType();
if (m_SideChannels.ContainsKey(channelType))
var channelId = sideChannel.ChannelId;
if (m_SideChannels.ContainsKey(channelId))
"side channels of the same type.", channelType));
"side channels of the same id.", channelId));
m_SideChannels.Add(channelType, sideChannel);
m_SideChannels.Add(channelId, sideChannel);
}
/// <summary>

public void UnregisterSideChannel(SideChannel sideChannel)
{
if (m_SideChannels.ContainsKey(sideChannel.ChannelType()))
if (m_SideChannels.ContainsKey(sideChannel.ChannelId))
m_SideChannels.Remove(sideChannel.ChannelType());
m_SideChannels.Remove(sideChannel.ChannelId);
}
}

/// </summary>
/// <param name="sideChannels"> A dictionary of channel type to channel.</param>
/// <returns></returns>
public static byte[] GetSideChannelMessage(Dictionary<int, SideChannel> sideChannels)
public static byte[] GetSideChannelMessage(Dictionary<Guid, SideChannel> sideChannels)
{
using (var memStream = new MemoryStream())
{

var messageList = sideChannel.MessageQueue;
foreach (var message in messageList)
{
binaryWriter.Write(sideChannel.ChannelType());
binaryWriter.Write(sideChannel.ChannelId.ToByteArray());
binaryWriter.Write(message.Count());
binaryWriter.Write(message);
}

/// </summary>
/// <param name="sideChannels">A dictionary of channel type to channel.</param>
/// <param name="dataReceived">The byte array of data received from Python.</param>
public static void ProcessSideChannelData(Dictionary<int, SideChannel> sideChannels, byte[] dataReceived)
public static void ProcessSideChannelData(Dictionary<Guid, SideChannel> sideChannels, byte[] dataReceived)
{
if (dataReceived.Length == 0)
{

{
while (memStream.Position < memStream.Length)
{
int channelType = 0;
Guid channelId = Guid.Empty;
channelType = binaryReader.ReadInt32();
channelId = new Guid(binaryReader.ReadBytes(16));
var messageLength = binaryReader.ReadInt32();
message = binaryReader.ReadBytes(messageLength);
}

"version of MLAgents in Unity is compatible with the Python version. Original error : "
+ ex.Message);
}
if (sideChannels.ContainsKey(channelType))
if (sideChannels.ContainsKey(channelId))
sideChannels[channelType].OnMessageReceived(message);
sideChannels[channelId].OnMessageReceived(message);
"Unknown side channel data received. Channel type "
+ ": {0}", channelType));
"Unknown side channel data received. Channel Id is "
+ ": {0}", channelId));
}
}
}

6
com.unity.ml-agents/Runtime/SideChannel/EngineConfigurationChannel.cs


using System.IO;
using System;
using UnityEngine;
namespace MLAgents

public override int ChannelType()
private const string k_EngineConfigId = "e951342c-4f7e-11ea-b238-784f4387d1f7";
public EngineConfigurationChannel()
return (int)SideChannelType.EngineSettings;
ChannelId = new Guid(k_EngineConfigId);
}
public override void OnMessageReceived(byte[] data)

11
com.unity.ml-agents/Runtime/SideChannel/FloatPropertiesChannel.cs


{
Dictionary<string, float> m_FloatProperties = new Dictionary<string, float>();
Dictionary<string, Action<float>> m_RegisteredActions = new Dictionary<string, Action<float>>();
private const string k_FloatPropertiesDefaultId = "60ccf7d0-4f7e-11ea-b238-784f4387d1f7";
public override int ChannelType()
public FloatPropertiesChannel(Guid channelId = default(Guid))
return (int)SideChannelType.FloatProperties;
if (channelId == default(Guid))
{
ChannelId = new Guid(k_FloatPropertiesDefaultId);
}
else{
ChannelId = channelId;
}
}
public override void OnMessageReceived(byte[] data)

12
com.unity.ml-agents/Runtime/SideChannel/RawBytesChannel.cs


using System.Collections.Generic;
using System;
int m_ChannelId;
/// <summary>
/// RawBytesChannel provides a way to exchange raw byte arrays between Unity and Python.

public RawBytesChannel(int channelId = 0)
{
m_ChannelId = channelId;
}
public override int ChannelType()
public RawBytesChannel(Guid channelId)
return (int)SideChannelType.RawBytesChannelStart + m_ChannelId;
ChannelId = channelId;
}
public override void OnMessageReceived(byte[] data)

20
com.unity.ml-agents/Runtime/SideChannel/SideChannel.cs


using System.Collections.Generic;
using System;
public enum SideChannelType
{
// Invalid side channel
Invalid = 0,
// Reserved for the FloatPropertiesChannel.
FloatProperties = 1,
//Reserved for the EngineConfigurationChannel.
EngineSettings = 2,
// Raw bytes channels should start here to avoid conflicting with other Unity ones.
RawBytesChannelStart = 1000,
// custom side channels should start here to avoid conflicting with Unity ones.
UserSideChannelStart = 2000,
}
public abstract class SideChannel
{
// The list of messages (byte arrays) that need to be sent to Python via the communicator.

/// of each type. Ensure the Unity side channels will be linked to their Python equivalent.
/// </summary>
/// <returns> The integer identifier of the SideChannel</returns>
public abstract int ChannelType();
public Guid ChannelId{
get;
protected set;
}
/// <summary>
/// Is called by the communicator every time a message is received from Python by the SideChannel.

20
com.unity.ml-agents/Tests/Editor/SideChannelTests.cs


{
public List<int> messagesReceived = new List<int>();
public override int ChannelType() { return -1; }
public TestSideChannel() {
ChannelId = new Guid("6afa2c06-4f82-11ea-b238-784f4387d1f7");
}
public override void OnMessageReceived(byte[] data)
{

{
var intSender = new TestSideChannel();
var intReceiver = new TestSideChannel();
var dictSender = new Dictionary<int, SideChannel> { { intSender.ChannelType(), intSender } };
var dictReceiver = new Dictionary<int, SideChannel> { { intReceiver.ChannelType(), intReceiver } };
var dictSender = new Dictionary<Guid, SideChannel> { { intSender.ChannelId, intSender } };
var dictReceiver = new Dictionary<Guid, SideChannel> { { intReceiver.ChannelId, intReceiver } };
intSender.SendInt(4);
intSender.SendInt(5);

var str1 = "Test string";
var str2 = "Test string, second";
var strSender = new RawBytesChannel();
var strReceiver = new RawBytesChannel();
var dictSender = new Dictionary<int, SideChannel> { { strSender.ChannelType(), strSender } };
var dictReceiver = new Dictionary<int, SideChannel> { { strReceiver.ChannelType(), strReceiver } };
var strSender = new RawBytesChannel(new Guid("9a5b8954-4f82-11ea-b238-784f4387d1f7"));
var strReceiver = new RawBytesChannel(new Guid("9a5b8954-4f82-11ea-b238-784f4387d1f7"));
var dictSender = new Dictionary<Guid, SideChannel> { { strSender.ChannelId, strSender } };
var dictReceiver = new Dictionary<Guid, SideChannel> { { strReceiver.ChannelId, strReceiver } };
strSender.SendRawBytes(Encoding.ASCII.GetBytes(str1));
strSender.SendRawBytes(Encoding.ASCII.GetBytes(str2));

var propA = new FloatPropertiesChannel();
var propB = new FloatPropertiesChannel();
var dictReceiver = new Dictionary<int, SideChannel> { { propA.ChannelType(), propA } };
var dictSender = new Dictionary<int, SideChannel> { { propB.ChannelType(), propB } };
var dictReceiver = new Dictionary<Guid, SideChannel> { { propA.ChannelId, propA } };
var dictSender = new Dictionary<Guid, SideChannel> { { propB.ChannelId, propB } };
propA.RegisterCallback(k1, f => { wasCalled++; });
var tmp = propB.GetPropertyWithDefault(k2, 3.0f);

35
docs/Python-API.md


You can create your own `SideChannel` in C# and Python and use it to communicate data between the two.
##### Unity side
The side channel will have to implement the `SideChannel` abstract class. There are two methods
that must be implemented :
The side channel will have to implement the `SideChannel` abstract class and the following method.
* `ChannelType()` : Must return an integer identifying the side channel (This number must be the same on C#
and Python). There can only be one side channel of a certain type during communication.
The side channel must also assign a `ChannelId` property in the constructor. The `ChannelId` is a Guid
(or UUID in Python) used to uniquely identify a side channel. This Guid must be the same on C# and Python.
There can only be one side channel of a certain id during communication.
To send a byte array from C# to Python, call the `base.QueueMessageToSend(data)` method inside the side channel.
The `data` argument must be a `byte[]`.

##### Python side
The side channel will have to implement the `SideChannel` abstract class. You must implement :
* `channel_type(self) -> int` (property) : Must return an integer identifying the side channel (This number must
be the same on C# and Python). There can only be one side channel of a certain type during communication.
The side channel must also assign a `channel_id` property in the constructor. The `channel_id` is a UUID
(referred in C# as Guid) used to uniquely identify a side channel. This number must be the same on C# and
Python. There can only be one side channel of a certain id during communication.
To assign the `channel_id` call the abstract class constructor with the appropriate `channel_id` as follows:
```python
super().__init__(my_channel_id)
```
To send a byte array from Python to C#, call the `super().queue_message_to_send(bytes_data)` method inside the
side channel. The `bytes_data` argument must be a `bytes` object.

using UnityEngine;
using MLAgents;
using System.Text;
using System;
public override int ChannelType()
public StringLogSideChannel()
return (int)SideChannelType.UserSideChannelStart + 1;
ChannelId = new Guid("621f0a70-4f87-11ea-a6bf-784f4387d1f7");
}
public override void OnMessageReceived(byte[] data)

```python
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.side_channel.side_channel import SideChannel, SideChannelType
from mlagents_envs.side_channel.side_channel import SideChannel
@property
def channel_type(self) -> int:
return SideChannelType.UserSideChannelStart + 1
def __init__(self) -> None:
super().__init__(uuid.UUID("621f0a70-4f87-11ea-a6bf-784f4387d1f7"))
def on_message_received(self, data: bytes) -> None:
"""

env.step() # Move the simulation forward
env.close()
```
Now, if you run this script and press `Play` the Unity Editor when prompted, The console in the Unity Editor will

36
ml-agents-envs/mlagents_envs/environment.py


import atexit
import glob
import uuid
import logging
import numpy as np
import os

self.timeout_wait: int = timeout_wait
self.communicator = self.get_communicator(worker_id, base_port, timeout_wait)
self.worker_id = worker_id
self.side_channels: Dict[int, SideChannel] = {}
self.side_channels: Dict[uuid.UUID, SideChannel] = {}
if _sc.channel_type in self.side_channels:
if _sc.channel_id in self.side_channels:
"There cannot be two side channels with the same channel type {0}.".format(
_sc.channel_type
"There cannot be two side channels with the same channel id {0}.".format(
_sc.channel_id
self.side_channels[_sc.channel_type] = _sc
self.side_channels[_sc.channel_id] = _sc
# If the environment name is None, a new environment will not be launched
# and the communicator will directly try to connect to an existing unity environment.

@staticmethod
def _parse_side_channel_message(
side_channels: Dict[int, SideChannel], data: bytes
side_channels: Dict[uuid.UUID, SideChannel], data: bytes
channel_type, message_len = struct.unpack_from("<ii", data, offset)
offset = offset + 8
channel_id = uuid.UUID(bytes_le=bytes(data[offset : offset + 16]))
offset += 16
message_len, = struct.unpack_from("<i", data, offset)
offset = offset + 4
message_data = data[offset : offset + message_len]
offset = offset + message_len
except Exception:

raise UnityEnvironmentException(
"The message received by the side channel {0} was "
"unexpectedly short. Make sure your Unity Environment "
"sending side channel data properly.".format(channel_type)
"sending side channel data properly.".format(channel_id)
if channel_type in side_channels:
side_channels[channel_type].on_message_received(message_data)
if channel_id in side_channels:
side_channels[channel_id].on_message_received(message_data)
": {0}.".format(channel_type)
": {0}.".format(channel_id)
def _generate_side_channel_data(side_channels: Dict[int, SideChannel]) -> bytearray:
def _generate_side_channel_data(
side_channels: Dict[uuid.UUID, SideChannel]
) -> bytearray:
for channel_type, channel in side_channels.items():
for channel_id, channel in side_channels.items():
result += struct.pack("<ii", channel_type, len(message))
result += channel_id.bytes_le
result += struct.pack("<i", len(message))
result += message
channel.message_queue = []
return result

8
ml-agents-envs/mlagents_envs/exception.py


pass
class UnitySideChannelException(UnityException):
"""
Related to errors with side channels.
"""
pass
class UnityWorkerInUseException(UnityException):
"""
This error occurs when the port for a certain worker ID is already reserved.

8
ml-agents-envs/mlagents_envs/side_channel/engine_configuration_channel.py


from mlagents_envs.side_channel.side_channel import SideChannel, SideChannelType
from mlagents_envs.side_channel.side_channel import SideChannel
import uuid
from typing import NamedTuple

- int targetFrameRate;
"""
@property
def channel_type(self) -> int:
return SideChannelType.EngineSettings
def __init__(self) -> None:
super().__init__(uuid.UUID("e951342c-4f7e-11ea-b238-784f4387d1f7"))
def on_message_received(self, data: bytes) -> None:
"""

13
ml-agents-envs/mlagents_envs/side_channel/float_properties_channel.py


from mlagents_envs.side_channel.side_channel import SideChannel, SideChannelType
from mlagents_envs.side_channel.side_channel import SideChannel
import uuid
from typing import Dict, Tuple, Optional, List

set_property, get_property and list_properties.
"""
def __init__(self) -> None:
def __init__(self, channel_id: uuid.UUID = None) -> None:
super().__init__()
@property
def channel_type(self) -> int:
return SideChannelType.FloatProperties
if channel_id is None:
channel_id = uuid.UUID(("60ccf7d0-4f7e-11ea-b238-784f4387d1f7"))
super().__init__(channel_id)
def on_message_received(self, data: bytes) -> None:
"""

16
ml-agents-envs/mlagents_envs/side_channel/raw_bytes_channel.py


from mlagents_envs.side_channel.side_channel import SideChannel, SideChannelType
from mlagents_envs.side_channel.side_channel import SideChannel
import uuid
class RawBytesChannel(SideChannel):

"""
def __init__(self, channel_id=0):
self._received_messages = []
self._channel_id = channel_id
super().__init__()
@property
def channel_type(self) -> int:
return SideChannelType.RawBytesChannelStart + self._channel_id
def __init__(self, channel_id: uuid.UUID):
self._received_messages: List[bytes] = []
super().__init__(channel_id)
def on_message_received(self, data: bytes) -> None:
"""

"""
self._received_messages.append(data)
def get_and_clear_received_messages(self) -> List[bytearray]:
def get_and_clear_received_messages(self) -> List[bytes]:
"""
returns a list of bytearray received from the environment.
"""

24
ml-agents-envs/mlagents_envs/side_channel/side_channel.py


from abc import ABC, abstractmethod
from enum import IntEnum
class SideChannelType(IntEnum):
FloatProperties = 1
EngineSettings = 2
# Raw bytes channels should start here to avoid conflicting with other
# Unity ones.
RawBytesChannelStart = 1000
# custom side channels should start here to avoid conflicting with Unity
# ones.
UserSideChannelStart = 2000
from typing import List
import uuid
class SideChannel(ABC):

to the Env object at construction.
"""
def __init__(self):
self.message_queue = []
def __init__(self, channel_id):
self._channel_id: uuid.UUID = channel_id
self.message_queue: List[bytearray] = []
def queue_message_to_send(self, data: bytearray) -> None:
"""

pass
@property
@abstractmethod
def channel_type(self) -> int:
def channel_id(self) -> uuid.UUID:
pass
return self._channel_id

36
ml-agents-envs/mlagents_envs/tests/test_side_channel.py


import struct
import uuid
from mlagents_envs.side_channel.side_channel import SideChannel
from mlagents_envs.side_channel.float_properties_channel import FloatPropertiesChannel
from mlagents_envs.side_channel.raw_bytes_channel import RawBytesChannel

class IntChannel(SideChannel):
def __init__(self):
self.list_int = []
super().__init__()
@property
def channel_type(self):
return -1
super().__init__(uuid.UUID("a85ba5c0-4f87-11ea-a517-784f4387d1f7"))
def on_message_received(self, data):
val = struct.unpack_from("<i", data, 0)[0]

receiver = IntChannel()
sender.send_int(5)
sender.send_int(6)
data = UnityEnvironment._generate_side_channel_data({sender.channel_type: sender})
UnityEnvironment._parse_side_channel_message(
{receiver.channel_type: receiver}, data
)
data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender})
UnityEnvironment._parse_side_channel_message({receiver.channel_id: receiver}, data)
assert receiver.list_int[0] == 5
assert receiver.list_int[1] == 6

sender.set_property("prop1", 1.0)
data = UnityEnvironment._generate_side_channel_data({sender.channel_type: sender})
UnityEnvironment._parse_side_channel_message(
{receiver.channel_type: receiver}, data
)
data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender})
UnityEnvironment._parse_side_channel_message({receiver.channel_id: receiver}, data)
val = receiver.get_property("prop1")
assert val == 1.0

data = UnityEnvironment._generate_side_channel_data({sender.channel_type: sender})
UnityEnvironment._parse_side_channel_message(
{receiver.channel_type: receiver}, data
)
data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender})
UnityEnvironment._parse_side_channel_message({receiver.channel_id: receiver}, data)
val = receiver.get_property("prop1")
assert val == 1.0

def test_raw_bytes():
sender = RawBytesChannel()
receiver = RawBytesChannel()
guid = uuid.uuid4()
sender = RawBytesChannel(guid)
receiver = RawBytesChannel(guid)
data = UnityEnvironment._generate_side_channel_data({sender.channel_type: sender})
UnityEnvironment._parse_side_channel_message(
{receiver.channel_type: receiver}, data
)
data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender})
UnityEnvironment._parse_side_channel_message({receiver.channel_id: receiver}, data)
messages = receiver.get_and_clear_received_messages()
assert len(messages) == 2

正在加载...
取消
保存