Chris Elion
5 年前
当前提交
c57ff268
共有 10 个文件被更改,包括 372 次插入 和 328 次删除
-
196com.unity.ml-agents/Runtime/SideChannels/SideChannel.cs
-
4ml-agents-envs/mlagents_envs/side_channel/__init__.py
-
131ml-agents-envs/mlagents_envs/side_channel/side_channel.py
-
6ml-agents-envs/mlagents_envs/tests/test_side_channel.py
-
101com.unity.ml-agents/Runtime/SideChannels/IncomingMessage.cs
-
11com.unity.ml-agents/Runtime/SideChannels/IncomingMessage.cs.meta
-
109com.unity.ml-agents/Runtime/SideChannels/OutgoingMessage.cs
-
11com.unity.ml-agents/Runtime/SideChannels/OutgoingMessage.cs.meta
-
65ml-agents-envs/mlagents_envs/side_channel/incoming_message.py
-
66ml-agents-envs/mlagents_envs/side_channel/outgoing_message.py
|
|||
from mlagents_envs.side_channel.incoming_message import IncomingMessage # noqa |
|||
from mlagents_envs.side_channel.outgoing_message import OutgoingMessage # noqa |
|||
|
|||
from mlagents_envs.side_channel.side_channel import SideChannel # noqa |
|
|||
using System.Collections.Generic; |
|||
using System; |
|||
using System.IO; |
|||
using System.Text; |
|||
|
|||
namespace MLAgents.SideChannels |
|||
{ |
|||
/// <summary>
|
|||
/// Utility class for reading the data sent to the SideChannel.
|
|||
/// </summary>
|
|||
public class IncomingMessage : IDisposable |
|||
{ |
|||
byte[] m_Data; |
|||
Stream m_Stream; |
|||
BinaryReader m_Reader; |
|||
|
|||
/// <summary>
|
|||
/// Construct an IncomingMessage from the byte array.
|
|||
/// </summary>
|
|||
/// <param name="data"></param>
|
|||
public IncomingMessage(byte[] data) |
|||
{ |
|||
m_Data = data; |
|||
m_Stream = new MemoryStream(data); |
|||
m_Reader = new BinaryReader(m_Stream); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Read a boolan value from the message.
|
|||
/// </summary>
|
|||
/// <returns></returns>
|
|||
public bool ReadBoolean() |
|||
{ |
|||
return m_Reader.ReadBoolean(); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Read an integer value from the message.
|
|||
/// </summary>
|
|||
/// <returns></returns>
|
|||
public int ReadInt32() |
|||
{ |
|||
return m_Reader.ReadInt32(); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Read a float value from the message.
|
|||
/// </summary>
|
|||
/// <returns></returns>
|
|||
public float ReadFloat32() |
|||
{ |
|||
return m_Reader.ReadSingle(); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Read a string value from the message.
|
|||
/// </summary>
|
|||
/// <returns></returns>
|
|||
public string ReadString() |
|||
{ |
|||
var strLength = ReadInt32(); |
|||
var str = Encoding.ASCII.GetString(m_Reader.ReadBytes(strLength)); |
|||
return str; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Reads a list of floats from the message. The length of the list is stored in the message.
|
|||
/// </summary>
|
|||
/// <returns></returns>
|
|||
public IList<float> ReadFloatList() |
|||
{ |
|||
var len = ReadInt32(); |
|||
var output = new float[len]; |
|||
for (var i = 0; i < len; i++) |
|||
{ |
|||
output[i] = ReadFloat32(); |
|||
} |
|||
|
|||
return output; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Gets the original data of the message. Note that this will return all of the data,
|
|||
/// even if part of it has already been read.
|
|||
/// </summary>
|
|||
/// <returns></returns>
|
|||
public byte[] GetRawBytes() |
|||
{ |
|||
return m_Data; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Clean up the internal storage.
|
|||
/// </summary>
|
|||
public void Dispose() |
|||
{ |
|||
m_Reader?.Dispose(); |
|||
m_Stream?.Dispose(); |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: c8043cec65aeb4ec09db1d25ad694328 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using System.Collections.Generic; |
|||
using System; |
|||
using System.IO; |
|||
using System.Text; |
|||
|
|||
namespace MLAgents.SideChannels |
|||
{ |
|||
/// <summary>
|
|||
/// Utility class for forming the data that is sent to the SideChannel.
|
|||
/// </summary>
|
|||
public class OutgoingMessage : IDisposable |
|||
{ |
|||
BinaryWriter m_Writer; |
|||
MemoryStream m_Stream; |
|||
|
|||
/// <summary>
|
|||
/// Create a new empty OutgoingMessage.
|
|||
/// </summary>
|
|||
public OutgoingMessage() |
|||
{ |
|||
m_Stream = new MemoryStream(); |
|||
m_Writer = new BinaryWriter(m_Stream); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Clean up the internal storage.
|
|||
/// </summary>
|
|||
public void Dispose() |
|||
{ |
|||
m_Writer?.Dispose(); |
|||
m_Stream?.Dispose(); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Write a boolean value to the message.
|
|||
/// </summary>
|
|||
/// <param name="b"></param>
|
|||
public void WriteBoolean(bool b) |
|||
{ |
|||
m_Writer.Write(b); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Write an interger value to the message.
|
|||
/// </summary>
|
|||
/// <param name="i"></param>
|
|||
public void WriteInt32(int i) |
|||
{ |
|||
m_Writer.Write(i); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Write a float values to the message.
|
|||
/// </summary>
|
|||
/// <param name="f"></param>
|
|||
public void WriteFloat32(float f) |
|||
{ |
|||
m_Writer.Write(f); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Write a string value to the message.
|
|||
/// </summary>
|
|||
/// <param name="s"></param>
|
|||
public void WriteString(string s) |
|||
{ |
|||
var stringEncoded = Encoding.ASCII.GetBytes(s); |
|||
m_Writer.Write(stringEncoded.Length); |
|||
m_Writer.Write(stringEncoded); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Write a list or array of floats to the message.
|
|||
/// </summary>
|
|||
/// <param name="floatList"></param>
|
|||
public void WriteFloatList(IList<float> floatList) |
|||
{ |
|||
WriteInt32(floatList.Count); |
|||
foreach (var f in floatList) |
|||
{ |
|||
WriteFloat32(f); |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Overwrite the message with a specific byte array.
|
|||
/// </summary>
|
|||
/// <param name="data"></param>
|
|||
public void SetRawBytes(byte[] data) |
|||
{ |
|||
// Reset first.
|
|||
m_Stream.Seek(0, SeekOrigin.Begin); |
|||
m_Stream.SetLength(0); |
|||
|
|||
// Then append the data
|
|||
m_Stream.Capacity = (m_Stream.Capacity < data.Length) ? data.Length : m_Stream.Capacity; |
|||
m_Stream.Write(data, 0, data.Length); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Read the byte array of the message.
|
|||
/// </summary>
|
|||
/// <returns></returns>
|
|||
internal byte[] ToByteArray() |
|||
{ |
|||
return m_Stream.ToArray(); |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 1a007135a9a1e49849eb2d295f4c3879 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
from typing import List |
|||
import struct |
|||
|
|||
|
|||
class IncomingMessage: |
|||
""" |
|||
Utility class for reading the message written to a SideChannel. |
|||
Values must be read in the order they were written. |
|||
""" |
|||
|
|||
def __init__(self, buffer: bytes, offset: int = 0): |
|||
""" |
|||
Create a new IncomingMessage from the bytes. |
|||
""" |
|||
self.buffer = buffer |
|||
self.offset = offset |
|||
|
|||
def read_bool(self) -> bool: |
|||
""" |
|||
Read a boolean value from the message buffer. |
|||
""" |
|||
val = struct.unpack_from("<?", self.buffer, self.offset)[0] |
|||
self.offset += 1 |
|||
return val |
|||
|
|||
def read_int32(self) -> int: |
|||
""" |
|||
Read an integer value from the message buffer. |
|||
""" |
|||
val = struct.unpack_from("<i", self.buffer, self.offset)[0] |
|||
self.offset += 4 |
|||
return val |
|||
|
|||
def read_float32(self) -> float: |
|||
""" |
|||
Read a float value from the message buffer. |
|||
""" |
|||
val = struct.unpack_from("<f", self.buffer, self.offset)[0] |
|||
self.offset += 4 |
|||
return val |
|||
|
|||
def read_float32_list(self) -> List[float]: |
|||
""" |
|||
Read a list of float values from the message buffer. |
|||
""" |
|||
list_len = self.read_int32() |
|||
output = [] |
|||
for _ in range(list_len): |
|||
output.append(self.read_float32()) |
|||
return output |
|||
|
|||
def read_string(self) -> str: |
|||
""" |
|||
Read a string value from the message buffer. |
|||
""" |
|||
encoded_str_len = self.read_int32() |
|||
val = self.buffer[self.offset : self.offset + encoded_str_len].decode("ascii") |
|||
self.offset += encoded_str_len |
|||
return val |
|||
|
|||
def get_raw_bytes(self) -> bytes: |
|||
""" |
|||
Get a copy of the internal bytes used by the message. |
|||
""" |
|||
return bytearray(self.buffer) |
|
|||
from typing import List |
|||
import struct |
|||
|
|||
import logging |
|||
|
|||
logger = logging.getLogger(__name__) |
|||
|
|||
|
|||
class OutgoingMessage: |
|||
""" |
|||
Utility class for forming the message that is written to a SideChannel. |
|||
All data is written in little-endian format using the struct module. |
|||
""" |
|||
|
|||
def __init__(self): |
|||
""" |
|||
Create an OutgoingMessage with an empty buffer. |
|||
""" |
|||
self.buffer = bytearray() |
|||
|
|||
def write_bool(self, b: bool) -> None: |
|||
""" |
|||
Append a boolean value. |
|||
""" |
|||
self.buffer += struct.pack("<?", b) |
|||
|
|||
def write_int32(self, i: int) -> None: |
|||
""" |
|||
Append an integer value. |
|||
""" |
|||
self.buffer += struct.pack("<i", i) |
|||
|
|||
def write_float32(self, f: float) -> None: |
|||
""" |
|||
Append a float value. It will be truncated to 32-bit precision. |
|||
""" |
|||
self.buffer += struct.pack("<f", f) |
|||
|
|||
def write_float32_list(self, float_list: List[float]) -> None: |
|||
""" |
|||
Append a list of float values. They will be truncated to 32-bit precision. |
|||
""" |
|||
self.write_int32(len(float_list)) |
|||
for f in float_list: |
|||
self.write_float32(f) |
|||
|
|||
def write_string(self, s: str) -> None: |
|||
""" |
|||
Append a string value. Internally, it will be encoded to ascii, and the |
|||
encoded length will also be written to the message. |
|||
""" |
|||
encoded_key = s.encode("ascii") |
|||
self.write_int32(len(encoded_key)) |
|||
self.buffer += encoded_key |
|||
|
|||
def set_raw_bytes(self, buffer: bytearray) -> None: |
|||
""" |
|||
Set the internal buffer to a new bytearray. This will overwrite any existing data. |
|||
:param buffer: |
|||
:return: |
|||
""" |
|||
if self.buffer: |
|||
logger.warning( |
|||
"Called set_raw_bytes but the message already has been written to. This will overwrite data." |
|||
) |
|||
self.buffer = bytearray(buffer) |
撰写
预览
正在加载...
取消
保存
Reference in new issue