浏览代码

Merge pull request #3596 from Unity-Technologies/develop-sidechannel-usability

SideChannel helpers
/bug-failed-api-check
GitHub 5 年前
当前提交
6b846478
共有 20 个文件被更改,包括 619 次插入161 次删除
  1. 16
      com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs
  2. 31
      com.unity.ml-agents/Runtime/SideChannels/EngineConfigurationChannel.cs
  3. 69
      com.unity.ml-agents/Runtime/SideChannels/FloatPropertiesChannel.cs
  4. 10
      com.unity.ml-agents/Runtime/SideChannels/RawBytesChannel.cs
  5. 13
      com.unity.ml-agents/Runtime/SideChannels/SideChannel.cs
  6. 69
      com.unity.ml-agents/Tests/Editor/SideChannelTests.cs
  7. 49
      docs/Custom-SideChannels.md
  8. 5
      ml-agents-envs/mlagents_envs/environment.py
  9. 4
      ml-agents-envs/mlagents_envs/side_channel/__init__.py
  10. 23
      ml-agents-envs/mlagents_envs/side_channel/engine_configuration_channel.py
  11. 36
      ml-agents-envs/mlagents_envs/side_channel/float_properties_channel.py
  12. 10
      ml-agents-envs/mlagents_envs/side_channel/raw_bytes_channel.py
  13. 13
      ml-agents-envs/mlagents_envs/side_channel/side_channel.py
  14. 68
      ml-agents-envs/mlagents_envs/tests/test_side_channel.py
  15. 101
      com.unity.ml-agents/Runtime/SideChannels/IncomingMessage.cs
  16. 11
      com.unity.ml-agents/Runtime/SideChannels/IncomingMessage.cs.meta
  17. 110
      com.unity.ml-agents/Runtime/SideChannels/OutgoingMessage.cs
  18. 11
      com.unity.ml-agents/Runtime/SideChannels/OutgoingMessage.cs.meta
  19. 65
      ml-agents-envs/mlagents_envs/side_channel/incoming_message.py
  20. 66
      ml-agents-envs/mlagents_envs/side_channel/outgoing_message.py

16
com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs


"side channels of the same id.", channelId));
}
// Process any messages that we've already received for this channel ID.
var numMessages = m_CachedMessages.Count;
for (int i = 0; i < numMessages; i++)
{

sideChannel.OnMessageReceived(cachedMessage.Message);
using (var incomingMsg = new IncomingMessage(cachedMessage.Message))
{
sideChannel.OnMessageReceived(incomingMsg);
}
}
else
{

var cachedMessage = m_CachedMessages.Dequeue();
if (sideChannels.ContainsKey(cachedMessage.ChannelId))
{
sideChannels[cachedMessage.ChannelId].OnMessageReceived(cachedMessage.Message);
using (var incomingMsg = new IncomingMessage(cachedMessage.Message))
{
sideChannels[cachedMessage.ChannelId].OnMessageReceived(incomingMsg);
}
}
else
{

}
if (sideChannels.ContainsKey(channelId))
{
sideChannels[channelId].OnMessageReceived(message);
using (var incomingMsg = new IncomingMessage(message))
{
sideChannels[channelId].OnMessageReceived(incomingMsg);
}
}
else
{

31
com.unity.ml-agents/Runtime/SideChannels/EngineConfigurationChannel.cs


using System.IO;
using System;
using UnityEngine;

}
/// <inheritdoc/>
public override void OnMessageReceived(byte[] data)
public override void OnMessageReceived(IncomingMessage msg)
using (var memStream = new MemoryStream(data))
{
using (var binaryReader = new BinaryReader(memStream))
{
var width = binaryReader.ReadInt32();
var height = binaryReader.ReadInt32();
var qualityLevel = binaryReader.ReadInt32();
var timeScale = binaryReader.ReadSingle();
var targetFrameRate = binaryReader.ReadInt32();
var width = msg.ReadInt32();
var height = msg.ReadInt32();
var qualityLevel = msg.ReadInt32();
var timeScale = msg.ReadFloat32();
var targetFrameRate = msg.ReadInt32();
timeScale = Mathf.Clamp(timeScale, 1, 100);
timeScale = Mathf.Clamp(timeScale, 1, 100);
Screen.SetResolution(width, height, false);
QualitySettings.SetQualityLevel(qualityLevel, true);
Time.timeScale = timeScale;
Time.captureFramerate = 60;
Application.targetFrameRate = targetFrameRate;
}
}
Screen.SetResolution(width, height, false);
QualitySettings.SetQualityLevel(qualityLevel, true);
Time.timeScale = timeScale;
Time.captureFramerate = 60;
Application.targetFrameRate = targetFrameRate;
}
}
}

69
com.unity.ml-agents/Runtime/SideChannels/FloatPropertiesChannel.cs


using System.Collections.Generic;
using System.IO;
using System.Text;
namespace MLAgents.SideChannels
{

}
/// <inheritdoc/>
public override void OnMessageReceived(byte[] data)
public override void OnMessageReceived(IncomingMessage msg)
var kv = DeserializeMessage(data);
m_FloatProperties[kv.Key] = kv.Value;
if (m_RegisteredActions.ContainsKey(kv.Key))
{
m_RegisteredActions[kv.Key].Invoke(kv.Value);
}
var key = msg.ReadString();
var value = msg.ReadFloat32();
m_FloatProperties[key] = value;
Action<float> action;
m_RegisteredActions.TryGetValue(key, out action);
action?.Invoke(value);
}
/// <inheritdoc/>

QueueMessageToSend(SerializeMessage(key, value));
if (m_RegisteredActions.ContainsKey(key))
using (var msgOut = new OutgoingMessage())
m_RegisteredActions[key].Invoke(value);
msgOut.WriteString(key);
msgOut.WriteFloat32(value);
QueueMessageToSend(msgOut);
Action<float> action;
m_RegisteredActions.TryGetValue(key, out action);
action?.Invoke(value);
if (m_FloatProperties.ContainsKey(key))
{
return m_FloatProperties[key];
}
else
{
return defaultValue;
}
float valueOut;
bool hasKey = m_FloatProperties.TryGetValue(key, out valueOut);
return hasKey ? valueOut : defaultValue;
}
/// <inheritdoc/>

public IList<string> ListProperties()
{
return new List<string>(m_FloatProperties.Keys);
}
static KeyValuePair<string, float> DeserializeMessage(byte[] data)
{
using (var memStream = new MemoryStream(data))
{
using (var binaryReader = new BinaryReader(memStream))
{
var keyLength = binaryReader.ReadInt32();
var key = Encoding.ASCII.GetString(binaryReader.ReadBytes(keyLength));
var value = binaryReader.ReadSingle();
return new KeyValuePair<string, float>(key, value);
}
}
}
static byte[] SerializeMessage(string key, float value)
{
using (var memStream = new MemoryStream())
{
using (var binaryWriter = new BinaryWriter(memStream))
{
var stringEncoded = Encoding.ASCII.GetBytes(key);
binaryWriter.Write(stringEncoded.Length);
binaryWriter.Write(stringEncoded);
binaryWriter.Write(value);
return memStream.ToArray();
}
}
}
}
}

10
com.unity.ml-agents/Runtime/SideChannels/RawBytesChannel.cs


}
/// <inheritdoc/>
public override void OnMessageReceived(byte[] data)
public override void OnMessageReceived(IncomingMessage msg)
m_MessagesReceived.Add(data);
m_MessagesReceived.Add(msg.GetRawBytes());
}
/// <summary>

/// <param name="data"> The byte array of data to send to Python.</param>
public void SendRawBytes(byte[] data)
{
QueueMessageToSend(data);
using (var msg = new OutgoingMessage())
{
msg.SetRawBytes(data);
QueueMessageToSend(msg);
}
}
/// <summary>

13
com.unity.ml-agents/Runtime/SideChannels/SideChannel.cs


using System.Collections.Generic;
using System;
using System.IO;
using System.Text;
namespace MLAgents.SideChannels
{

/// of each type. Ensure the Unity side channels will be linked to their Python equivalent.
/// </summary>
/// <returns> The integer identifier of the SideChannel.</returns>
public Guid ChannelId{
public Guid ChannelId
{
get;
protected set;
}

/// Can be called multiple times per simulation step if multiple messages were sent.
/// </summary>
/// <param name="data"> the payload of the message.</param>
public abstract void OnMessageReceived(byte[] data);
/// <param name="msg">The incoming message.</param>
public abstract void OnMessageReceived(IncomingMessage msg);
protected void QueueMessageToSend(byte[] data)
protected void QueueMessageToSend(OutgoingMessage msg)
MessageQueue.Add(data);
MessageQueue.Add(msg.ToByteArray());
}
}
}

69
com.unity.ml-agents/Tests/Editor/SideChannelTests.cs


{
public List<int> messagesReceived = new List<int>();
public TestSideChannel() {
public TestSideChannel()
{
public override void OnMessageReceived(byte[] data)
public override void OnMessageReceived(IncomingMessage msg)
messagesReceived.Add(BitConverter.ToInt32(data, 0));
messagesReceived.Add(msg.ReadInt32());
public void SendInt(int data)
public void SendInt(int value)
QueueMessageToSend(BitConverter.GetBytes(data));
using (var msg = new OutgoingMessage())
{
msg.WriteInt32(value);
QueueMessageToSend(msg);
}
}
}

fakeData = RpcCommunicator.GetSideChannelMessage(dictSender);
RpcCommunicator.ProcessSideChannelData(dictReceiver, fakeData);
Assert.AreEqual(wasCalled, 1);
var keysA = propA.ListProperties();
Assert.AreEqual(2, keysA.Count);
Assert.IsTrue(keysA.Contains(k1));
Assert.IsTrue(keysA.Contains(k2));
var keysB = propA.ListProperties();
Assert.AreEqual(2, keysB.Count);
Assert.IsTrue(keysB.Contains(k1));
Assert.IsTrue(keysB.Contains(k2));
}
[Test]
public void TestOutgoingMessageRawBytes()
{
// Make sure that SetRawBytes resets the buffer correctly.
// Write 8 bytes (an int and float) then call SetRawBytes with 4 bytes
var msg = new OutgoingMessage();
msg.WriteInt32(42);
msg.WriteFloat32(1.0f);
var data = new byte[] { 1, 2, 3, 4 };
msg.SetRawBytes(data);
var result = msg.ToByteArray();
Assert.AreEqual(data, result);
}
[Test]
public void TestMessageReadWrites()
{
var boolVal = true;
var intVal = 1337;
var floatVal = 4.2f;
var floatListVal = new float[] { 1001, 1002 };
var stringVal = "mlagents!";
IncomingMessage incomingMsg;
using (var outgoingMsg = new OutgoingMessage())
{
outgoingMsg.WriteBoolean(boolVal);
outgoingMsg.WriteInt32(intVal);
outgoingMsg.WriteFloat32(floatVal);
outgoingMsg.WriteString(stringVal);
outgoingMsg.WriteFloatList(floatListVal);
incomingMsg = new IncomingMessage(outgoingMsg.ToByteArray());
}
Assert.AreEqual(boolVal, incomingMsg.ReadBoolean());
Assert.AreEqual(intVal, incomingMsg.ReadInt32());
Assert.AreEqual(floatVal, incomingMsg.ReadFloat32());
Assert.AreEqual(stringVal, incomingMsg.ReadString());
Assert.AreEqual(floatListVal, incomingMsg.ReadFloatList());
}
}
}

49
docs/Custom-SideChannels.md


### Unity side
The side channel will have to implement the `SideChannel` abstract class and the following method.
* `OnMessageReceived(byte[] data)` : You must implement this method to specify what the side channel will be doing
with the data received from Python. The data is a `byte[]` argument.
* `OnMessageReceived(IncomingMessage msg)` : You must implement this method and read the data from IncomingMessage.
The data must be read in the order that it was written.
To send a byte array from C# to Python, call the `base.QueueMessageToSend(data)` method inside the side channel.
The `data` argument must be a `byte[]`.
To send data from C# to Python, create an `OutgoingMessage` instance, add data to it, call the
`base.QueueMessageToSend(msg)` method inside the side channel, and call the
`OutgoingMessage.Dispose()` method.
To register a side channel on the Unity side, call `Academy.Instance.RegisterSideChannel` with the side channel
as only argument.

* `on_message_received(self, data: bytes) -> None` : You must implement this method to specify what the
side channel will be doing with the data received from Unity. The data is a `byte[]` argument.
* `on_message_received(self, msg: "IncomingMessage") -> None` : You must implement this method and read the data
from IncomingMessage. The data must be read in the order that it was written.
The side channel must also assign a `channel_id` property in the constructor. The `channel_id` is a UUID
(referred in C# as Guid) used to uniquely identify a side channel. This number must be the same on C# and

super().__init__(my_channel_id)
```
To send a byte array from Python to C#, call the `super().queue_message_to_send(bytes_data)` method inside the
side channel. The `bytes_data` argument must be a `bytes` object.
To send a byte array from Python to C#, create an `OutgoingMessage` instance, add data to it, and call the
`super().queue_message_to_send(msg)` method inside the side channel.
To register a side channel on the Python side, pass the side channel as argument when creating the
`UnityEnvironment` object. One of the arguments of the constructor (`side_channels`) is a list of side channels.

ChannelId = new Guid("621f0a70-4f87-11ea-a6bf-784f4387d1f7");
}
public override void OnMessageReceived(byte[] data)
public override void OnMessageReceived(IncomingMessage msg)
var receivedString = Encoding.ASCII.GetString(data);
var receivedString = msg.ReadString();
Debug.Log("From Python : " + receivedString);
}

{
var stringToSend = type.ToString() + ": " + logString + "\n" + stackTrace;
var encodedString = Encoding.ASCII.GetBytes(stringToSend);
base.QueueMessageToSend(encodedString);
using (var msgOut = new OutgoingMessage())
{
msgOut.WriteString(stringToSend);
QueueMessageToSend(msgOut);
}
}
}
}

```python
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.side_channel.side_channel import SideChannel
from mlagents_envs.side_channel.side_channel import (
SideChannel,
IncomingMessage,
OutgoingMessage,
)
import numpy as np
import uuid

def __init__(self) -> None:
super().__init__(uuid.UUID("621f0a70-4f87-11ea-a6bf-784f4387d1f7"))
def on_message_received(self, data: bytes) -> None:
def on_message_received(self, msg: IncomingMessage) -> None:
Note :We must implement this method of the SideChannel interface to
Note: We must implement this method of the SideChannel interface to
# We simply print the data received interpreted as ascii
print(data.decode("ascii"))
# We simply read a string from the message and print it.
print(msg.read_string())
# Convert the string to ascii
bytes_data = data.encode("ascii")
# Add the string to an OutgoingMessage
msg = OutgoingMessage()
msg.write_string(data)
super().queue_message_to_send(bytes_data)
super().queue_message_to_send(msg)
```

5
ml-agents-envs/mlagents_envs/environment.py


from typing import Dict, List, Optional, Any
import mlagents_envs
from mlagents_envs.side_channel.side_channel import SideChannel
from mlagents_envs.side_channel.side_channel import SideChannel, IncomingMessage
from mlagents_envs.base_env import (
BaseEnv,

"sending side channel data properly.".format(channel_id)
)
if channel_id in side_channels:
side_channels[channel_id].on_message_received(message_data)
incoming_message = IncomingMessage(message_data)
side_channels[channel_id].on_message_received(incoming_message)
else:
logger.warning(
"Unknown side channel data received. Channel type "

4
ml-agents-envs/mlagents_envs/side_channel/__init__.py


from mlagents_envs.side_channel.incoming_message import IncomingMessage # noqa
from mlagents_envs.side_channel.outgoing_message import OutgoingMessage # noqa
from mlagents_envs.side_channel.side_channel import SideChannel # noqa

23
ml-agents-envs/mlagents_envs/side_channel/engine_configuration_channel.py


from mlagents_envs.side_channel.side_channel import SideChannel
from mlagents_envs.side_channel import SideChannel, OutgoingMessage, IncomingMessage
import struct
import uuid
from typing import NamedTuple

def __init__(self) -> None:
super().__init__(uuid.UUID("e951342c-4f7e-11ea-b238-784f4387d1f7"))
def on_message_received(self, data: bytes) -> None:
def on_message_received(self, msg: IncomingMessage) -> None:
"""
Is called by the environment to the side channel. Can be called
multiple times per step if multiple messages are meant for that

:param target_frame_rate: Instructs simulation to try to render at a
specified frame rate. Default -1.
"""
data = bytearray()
data += struct.pack("<i", width)
data += struct.pack("<i", height)
data += struct.pack("<i", quality_level)
data += struct.pack("<f", time_scale)
data += struct.pack("<i", target_frame_rate)
super().queue_message_to_send(data)
msg = OutgoingMessage()
msg.write_int32(width)
msg.write_int32(height)
msg.write_int32(quality_level)
msg.write_float32(time_scale)
msg.write_int32(target_frame_rate)
super().queue_message_to_send(msg)
data = bytearray()
data += struct.pack("<iiifi", *config)
super().queue_message_to_send(data)
self.set_configuration_parameters(**config._asdict())

36
ml-agents-envs/mlagents_envs/side_channel/float_properties_channel.py


from mlagents_envs.side_channel.side_channel import SideChannel
import struct
from mlagents_envs.side_channel import SideChannel, IncomingMessage, OutgoingMessage
from typing import Dict, Tuple, Optional, List
from typing import Dict, Optional, List
class FloatPropertiesChannel(SideChannel):

channel_id = uuid.UUID(("60ccf7d0-4f7e-11ea-b238-784f4387d1f7"))
super().__init__(channel_id)
def on_message_received(self, data: bytes) -> None:
def on_message_received(self, msg: IncomingMessage) -> None:
Note that Python should never receive an engine configuration from
Unity
k, v = self.deserialize_float_prop(data)
k = msg.read_string()
v = msg.read_float32()
self._float_properties[k] = v
def set_property(self, key: str, value: float) -> None:

:param value: The float value of the property.
"""
self._float_properties[key] = value
super().queue_message_to_send(self.serialize_float_prop(key, value))
msg = OutgoingMessage()
msg.write_string(key)
msg.write_float32(value)
super().queue_message_to_send(msg)
def get_property(self, key: str) -> Optional[float]:
"""

:return:
"""
return dict(self._float_properties)
@staticmethod
def serialize_float_prop(key: str, value: float) -> bytearray:
result = bytearray()
encoded_key = key.encode("ascii")
result += struct.pack("<i", len(encoded_key))
result += encoded_key
result += struct.pack("<f", value)
return result
@staticmethod
def deserialize_float_prop(data: bytes) -> Tuple[str, float]:
offset = 0
encoded_key_len = struct.unpack_from("<i", data, offset)[0]
offset = offset + 4
key = data[offset : offset + encoded_key_len].decode("ascii")
offset = offset + encoded_key_len
value = struct.unpack_from("<f", data, offset)[0]
return key, value

10
ml-agents-envs/mlagents_envs/side_channel/raw_bytes_channel.py


from mlagents_envs.side_channel.side_channel import SideChannel
from mlagents_envs.side_channel import SideChannel, IncomingMessage, OutgoingMessage
from typing import List
import uuid

self._received_messages: List[bytes] = []
super().__init__(channel_id)
def on_message_received(self, data: bytes) -> None:
def on_message_received(self, msg: IncomingMessage) -> None:
self._received_messages.append(data)
self._received_messages.append(msg.get_raw_bytes())
def get_and_clear_received_messages(self) -> List[bytes]:
"""

Queues a message to be sent by the environment at the next call to
step.
"""
super().queue_message_to_send(data)
msg = OutgoingMessage()
msg.set_raw_bytes(data)
super().queue_message_to_send(msg)

13
ml-agents-envs/mlagents_envs/side_channel/side_channel.py


from abc import ABC, abstractmethod
from typing import List
import uuid
import logging
from mlagents_envs.side_channel import IncomingMessage, OutgoingMessage
logger = logging.getLogger(__name__)
class SideChannel(ABC):

to the Env object at construction.
"""
def __init__(self, channel_id):
def __init__(self, channel_id: uuid.UUID):
def queue_message_to_send(self, data: bytearray) -> None:
def queue_message_to_send(self, msg: OutgoingMessage) -> None:
self.message_queue.append(data)
self.message_queue.append(msg.buffer)
def on_message_received(self, data: bytes) -> None:
def on_message_received(self, msg: IncomingMessage) -> None:
"""
Is called by the environment to the side channel. Can be called
multiple times per step if multiple messages are meant for that

68
ml-agents-envs/mlagents_envs/tests/test_side_channel.py


import struct
from mlagents_envs.side_channel.side_channel import SideChannel
from mlagents_envs.side_channel import SideChannel, IncomingMessage, OutgoingMessage
from mlagents_envs.side_channel.float_properties_channel import FloatPropertiesChannel
from mlagents_envs.side_channel.raw_bytes_channel import RawBytesChannel
from mlagents_envs.environment import UnityEnvironment

self.list_int = []
super().__init__(uuid.UUID("a85ba5c0-4f87-11ea-a517-784f4387d1f7"))
def on_message_received(self, data):
val = struct.unpack_from("<i", data, 0)[0]
def on_message_received(self, msg: IncomingMessage) -> None:
val = msg.read_int32()
data = bytearray()
data += struct.pack("<i", value)
super().queue_message_to_send(data)
msg = OutgoingMessage()
msg.write_int32(value)
super().queue_message_to_send(msg)
def test_int_channel():

messages = receiver.get_and_clear_received_messages()
assert len(messages) == 0
def test_message_bool():
vals = [True, False]
msg_out = OutgoingMessage()
for v in vals:
msg_out.write_bool(v)
msg_in = IncomingMessage(msg_out.buffer)
read_vals = []
for _ in range(len(vals)):
read_vals.append(msg_in.read_bool())
assert vals == read_vals
def test_message_int32():
val = 1337
msg_out = OutgoingMessage()
msg_out.write_int32(val)
msg_in = IncomingMessage(msg_out.buffer)
read_val = msg_in.read_int32()
assert val == read_val
def test_message_float32():
val = 42.0
msg_out = OutgoingMessage()
msg_out.write_float32(val)
msg_in = IncomingMessage(msg_out.buffer)
read_val = msg_in.read_float32()
# These won't be exactly equal in general, since python floats are 64-bit.
assert val == read_val
def test_message_string():
val = "mlagents!"
msg_out = OutgoingMessage()
msg_out.write_string(val)
msg_in = IncomingMessage(msg_out.buffer)
read_val = msg_in.read_string()
assert val == read_val
def test_message_float_list():
val = [1.0, 3.0, 9.0]
msg_out = OutgoingMessage()
msg_out.write_float32_list(val)
msg_in = IncomingMessage(msg_out.buffer)
read_val = msg_in.read_float32_list()
# These won't be exactly equal in general, since python floats are 64-bit.
assert val == read_val

101
com.unity.ml-agents/Runtime/SideChannels/IncomingMessage.cs


using System.Collections.Generic;
using System;
using System.IO;
using System.Text;
namespace MLAgents.SideChannels
{
/// <summary>
/// Utility class for reading the data sent to the SideChannel.
/// </summary>
public class IncomingMessage : IDisposable
{
byte[] m_Data;
Stream m_Stream;
BinaryReader m_Reader;
/// <summary>
/// Construct an IncomingMessage from the byte array.
/// </summary>
/// <param name="data"></param>
public IncomingMessage(byte[] data)
{
m_Data = data;
m_Stream = new MemoryStream(data);
m_Reader = new BinaryReader(m_Stream);
}
/// <summary>
/// Read a boolan value from the message.
/// </summary>
/// <returns></returns>
public bool ReadBoolean()
{
return m_Reader.ReadBoolean();
}
/// <summary>
/// Read an integer value from the message.
/// </summary>
/// <returns></returns>
public int ReadInt32()
{
return m_Reader.ReadInt32();
}
/// <summary>
/// Read a float value from the message.
/// </summary>
/// <returns></returns>
public float ReadFloat32()
{
return m_Reader.ReadSingle();
}
/// <summary>
/// Read a string value from the message.
/// </summary>
/// <returns></returns>
public string ReadString()
{
var strLength = ReadInt32();
var str = Encoding.ASCII.GetString(m_Reader.ReadBytes(strLength));
return str;
}
/// <summary>
/// Reads a list of floats from the message. The length of the list is stored in the message.
/// </summary>
/// <returns></returns>
public IList<float> ReadFloatList()
{
var len = ReadInt32();
var output = new float[len];
for (var i = 0; i < len; i++)
{
output[i] = ReadFloat32();
}
return output;
}
/// <summary>
/// Gets the original data of the message. Note that this will return all of the data,
/// even if part of it has already been read.
/// </summary>
/// <returns></returns>
public byte[] GetRawBytes()
{
return m_Data;
}
/// <summary>
/// Clean up the internal storage.
/// </summary>
public void Dispose()
{
m_Reader?.Dispose();
m_Stream?.Dispose();
}
}
}

11
com.unity.ml-agents/Runtime/SideChannels/IncomingMessage.cs.meta


fileFormatVersion: 2
guid: c8043cec65aeb4ec09db1d25ad694328
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

110
com.unity.ml-agents/Runtime/SideChannels/OutgoingMessage.cs


using System.Collections.Generic;
using System;
using System.IO;
using System.Text;
namespace MLAgents.SideChannels
{
/// <summary>
/// Utility class for forming the data that is sent to the SideChannel.
/// </summary>
public class OutgoingMessage : IDisposable
{
BinaryWriter m_Writer;
MemoryStream m_Stream;
/// <summary>
/// Create a new empty OutgoingMessage.
/// </summary>
public OutgoingMessage()
{
m_Stream = new MemoryStream();
m_Writer = new BinaryWriter(m_Stream);
}
/// <summary>
/// Clean up the internal storage.
/// </summary>
public void Dispose()
{
m_Writer?.Dispose();
m_Stream?.Dispose();
}
/// <summary>
/// Write a boolean value to the message.
/// </summary>
/// <param name="b"></param>
public void WriteBoolean(bool b)
{
m_Writer.Write(b);
}
/// <summary>
/// Write an interger value to the message.
/// </summary>
/// <param name="i"></param>
public void WriteInt32(int i)
{
m_Writer.Write(i);
}
/// <summary>
/// Write a float values to the message.
/// </summary>
/// <param name="f"></param>
public void WriteFloat32(float f)
{
m_Writer.Write(f);
}
/// <summary>
/// Write a string value to the message.
/// </summary>
/// <param name="s"></param>
public void WriteString(string s)
{
var stringEncoded = Encoding.ASCII.GetBytes(s);
m_Writer.Write(stringEncoded.Length);
m_Writer.Write(stringEncoded);
}
/// <summary>
/// Write a list or array of floats to the message.
/// </summary>
/// <param name="floatList"></param>
public void WriteFloatList(IList<float> floatList)
{
WriteInt32(floatList.Count);
foreach (var f in floatList)
{
WriteFloat32(f);
}
}
/// <summary>
/// Overwrite the message with a specific byte array.
/// </summary>
/// <param name="data"></param>
public void SetRawBytes(byte[] data)
{
// Reset first. Set the length to zero so that if there's more data than we're going to
// write, we don't have any of the original data.
m_Stream.Seek(0, SeekOrigin.Begin);
m_Stream.SetLength(0);
// Then append the data. Increase the capacity if needed (but don't shrink it).
m_Stream.Capacity = (m_Stream.Capacity < data.Length) ? data.Length : m_Stream.Capacity;
m_Stream.Write(data, 0, data.Length);
}
/// <summary>
/// Read the byte array of the message.
/// </summary>
/// <returns></returns>
internal byte[] ToByteArray()
{
return m_Stream.ToArray();
}
}
}

11
com.unity.ml-agents/Runtime/SideChannels/OutgoingMessage.cs.meta


fileFormatVersion: 2
guid: 1a007135a9a1e49849eb2d295f4c3879
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

65
ml-agents-envs/mlagents_envs/side_channel/incoming_message.py


from typing import List
import struct
class IncomingMessage:
"""
Utility class for reading the message written to a SideChannel.
Values must be read in the order they were written.
"""
def __init__(self, buffer: bytes, offset: int = 0):
"""
Create a new IncomingMessage from the bytes.
"""
self.buffer = buffer
self.offset = offset
def read_bool(self) -> bool:
"""
Read a boolean value from the message buffer.
"""
val = struct.unpack_from("<?", self.buffer, self.offset)[0]
self.offset += 1
return val
def read_int32(self) -> int:
"""
Read an integer value from the message buffer.
"""
val = struct.unpack_from("<i", self.buffer, self.offset)[0]
self.offset += 4
return val
def read_float32(self) -> float:
"""
Read a float value from the message buffer.
"""
val = struct.unpack_from("<f", self.buffer, self.offset)[0]
self.offset += 4
return val
def read_float32_list(self) -> List[float]:
"""
Read a list of float values from the message buffer.
"""
list_len = self.read_int32()
output = []
for _ in range(list_len):
output.append(self.read_float32())
return output
def read_string(self) -> str:
"""
Read a string value from the message buffer.
"""
encoded_str_len = self.read_int32()
val = self.buffer[self.offset : self.offset + encoded_str_len].decode("ascii")
self.offset += encoded_str_len
return val
def get_raw_bytes(self) -> bytes:
"""
Get a copy of the internal bytes used by the message.
"""
return bytearray(self.buffer)

66
ml-agents-envs/mlagents_envs/side_channel/outgoing_message.py


from typing import List
import struct
import logging
logger = logging.getLogger(__name__)
class OutgoingMessage:
"""
Utility class for forming the message that is written to a SideChannel.
All data is written in little-endian format using the struct module.
"""
def __init__(self):
"""
Create an OutgoingMessage with an empty buffer.
"""
self.buffer = bytearray()
def write_bool(self, b: bool) -> None:
"""
Append a boolean value.
"""
self.buffer += struct.pack("<?", b)
def write_int32(self, i: int) -> None:
"""
Append an integer value.
"""
self.buffer += struct.pack("<i", i)
def write_float32(self, f: float) -> None:
"""
Append a float value. It will be truncated to 32-bit precision.
"""
self.buffer += struct.pack("<f", f)
def write_float32_list(self, float_list: List[float]) -> None:
"""
Append a list of float values. They will be truncated to 32-bit precision.
"""
self.write_int32(len(float_list))
for f in float_list:
self.write_float32(f)
def write_string(self, s: str) -> None:
"""
Append a string value. Internally, it will be encoded to ascii, and the
encoded length will also be written to the message.
"""
encoded_key = s.encode("ascii")
self.write_int32(len(encoded_key))
self.buffer += encoded_key
def set_raw_bytes(self, buffer: bytearray) -> None:
"""
Set the internal buffer to a new bytearray. This will overwrite any existing data.
:param buffer:
:return:
"""
if self.buffer:
logger.warning(
"Called set_raw_bytes but the message already has been written to. This will overwrite data."
)
self.buffer = bytearray(buffer)
正在加载...
取消
保存