浏览代码
Develop side channel (#2956)
Develop side channel (#2956)
* [WIP] Side Channel initial layout * Working prototype for raw bytes * fixing format mistake * Added some errors and some unit tests in C# * Added the side channel for the Engine Configuration. (#2958) * Added the side channel for the Engine Configuration. Note that this change does not require modifying a lot of files : - Adding a sender in Python - Adding a receiver in C# - subscribe the receiver to the communicator (here is a one liner in the Academy) - Add the side channel to the Python UnityEnvironment (not represented here) Adding the side channel to the environment would look like such : ```python from mlagents.envs.environment import UnityEnvironment from mlagents.envs.side_channel.raw_bytes_channel import RawBytesChannel from mlagents.envs.side_channel.engine_configuration_channel import EngineConfigurationChannel channel0 = RawBytesChannel() channel1 = EngineConfigurationChannel() env = UnityEnvironme.../develop/tanhsquash
GitHub
5 年前
当前提交
11243348
共有 31 个文件被更改,包括 1047 次插入 和 32 次删除
-
7UnitySDK/Assets/ML-Agents/Scripts/Academy.cs
-
45UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/UnityRlInput.cs
-
44UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/UnityRlOutput.cs
-
104UnitySDK/Assets/ML-Agents/Scripts/Grpc/RpcCommunicator.cs
-
7UnitySDK/Assets/ML-Agents/Scripts/ICommunicator.cs
-
19ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_input_pb2.py
-
6ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_input_pb2.pyi
-
19ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_output_pb2.py
-
6ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_output_pb2.pyi
-
58ml-agents-envs/mlagents/envs/environment.py
-
1protobuf-definitions/proto/mlagents/envs/communicator_objects/unity_rl_input.proto
-
1protobuf-definitions/proto/mlagents/envs/communicator_objects/unity_rl_output.proto
-
108UnitySDK/Assets/ML-Agents/Editor/Tests/SideChannelTests.cs
-
11UnitySDK/Assets/ML-Agents/Editor/Tests/SideChannelTests.cs.meta
-
8UnitySDK/Assets/ML-Agents/Scripts/SideChannel.meta
-
91ml-agents-envs/mlagents/envs/tests/test_side_channel.py
-
36UnitySDK/Assets/ML-Agents/Scripts/SideChannel/EngineConfigurationChannel.cs
-
11UnitySDK/Assets/ML-Agents/Scripts/SideChannel/EngineConfigurationChannel.cs.meta
-
123UnitySDK/Assets/ML-Agents/Scripts/SideChannel/FloatPropertiesChannel.cs
-
11UnitySDK/Assets/ML-Agents/Scripts/SideChannel/FloatPropertiesChannel.cs.meta
-
65UnitySDK/Assets/ML-Agents/Scripts/SideChannel/RawBytesChannel.cs
-
11UnitySDK/Assets/ML-Agents/Scripts/SideChannel/RawBytesChannel.cs.meta
-
49UnitySDK/Assets/ML-Agents/Scripts/SideChannel/SideChannel.cs
-
11UnitySDK/Assets/ML-Agents/Scripts/SideChannel/SideChannel.cs.meta
-
0ml-agents-envs/mlagents/envs/side_channel/__init__.py
-
61ml-agents-envs/mlagents/envs/side_channel/engine_configuration_channel.py
-
74ml-agents-envs/mlagents/envs/side_channel/float_properties_channel.py
-
41ml-agents-envs/mlagents/envs/side_channel/raw_bytes_channel.py
-
51ml-agents-envs/mlagents/envs/side_channel/side_channel.py
|
|||
using System; |
|||
using NUnit.Framework; |
|||
using MLAgents; |
|||
using System.Collections.Generic; |
|||
using System.Text; |
|||
|
|||
namespace MLAgents.Tests |
|||
{ |
|||
public class SideChannelTests |
|||
{ |
|||
|
|||
// This test side channel only deals in integers
|
|||
public class TestSideChannel : SideChannel |
|||
{ |
|||
|
|||
public List<int> m_MessagesReceived = new List<int>(); |
|||
|
|||
public override int ChannelType() { return -1; } |
|||
|
|||
public override void OnMessageReceived(byte[] data) |
|||
{ |
|||
m_MessagesReceived.Add(BitConverter.ToInt32(data, 0)); |
|||
} |
|||
|
|||
public void SendInt(int data) |
|||
{ |
|||
QueueMessageToSend(BitConverter.GetBytes(data)); |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void TestIntegerSideChannel() |
|||
{ |
|||
var intSender = new TestSideChannel(); |
|||
var intReceiver = new TestSideChannel(); |
|||
var dictSender = new Dictionary<int, SideChannel> { { intSender.ChannelType(), intSender } }; |
|||
var dictReceiver = new Dictionary<int, SideChannel> { { intReceiver.ChannelType(), intReceiver } }; |
|||
|
|||
intSender.SendInt(4); |
|||
intSender.SendInt(5); |
|||
intSender.SendInt(6); |
|||
|
|||
byte[] fakeData = RpcCommunicator.GetSideChannelMessage(dictSender); |
|||
RpcCommunicator.ProcessSideChannelData(dictReceiver, fakeData); |
|||
|
|||
Assert.AreEqual(intReceiver.m_MessagesReceived[0], 4); |
|||
Assert.AreEqual(intReceiver.m_MessagesReceived[1], 5); |
|||
Assert.AreEqual(intReceiver.m_MessagesReceived[2], 6); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestRawBytesSideChannel() |
|||
{ |
|||
var str1 = "Test string"; |
|||
var str2 = "Test string, second"; |
|||
|
|||
var strSender = new RawBytesChannel(); |
|||
var strReceiver = new RawBytesChannel(); |
|||
var dictSender = new Dictionary<int, SideChannel> { { strSender.ChannelType(), strSender } }; |
|||
var dictReceiver = new Dictionary<int, SideChannel> { { strReceiver.ChannelType(), strReceiver } }; |
|||
|
|||
strSender.SendRawBytes(Encoding.ASCII.GetBytes(str1)); |
|||
strSender.SendRawBytes(Encoding.ASCII.GetBytes(str2)); |
|||
|
|||
byte[] fakeData = RpcCommunicator.GetSideChannelMessage(dictSender); |
|||
RpcCommunicator.ProcessSideChannelData(dictReceiver, fakeData); |
|||
|
|||
var messages = strReceiver.GetAndClearReceivedMessages(); |
|||
|
|||
Assert.AreEqual(messages.Count, 2); |
|||
Assert.AreEqual(Encoding.ASCII.GetString(messages[0]), str1); |
|||
Assert.AreEqual(Encoding.ASCII.GetString(messages[1]), str2); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestFloatPropertiesSideChannel() |
|||
{ |
|||
var k1 = "gravity"; |
|||
var k2 = "length"; |
|||
int wasCalled = 0; |
|||
|
|||
var propA = new FloatPropertiesChannel(); |
|||
var propB = new FloatPropertiesChannel(); |
|||
var dictReceiver = new Dictionary<int, SideChannel> { { propA.ChannelType(), propA } }; |
|||
var dictSender = new Dictionary<int, SideChannel> { { propB.ChannelType(), propB } }; |
|||
|
|||
propA.RegisterCallback(k1, f => { wasCalled++; }); |
|||
var tmp = propB.GetPropertyWithDefault(k2, 3.0f); |
|||
Assert.AreEqual(tmp, 3.0f); |
|||
propB.SetProperty(k2, 1.0f); |
|||
tmp = propB.GetPropertyWithDefault(k2, 3.0f); |
|||
Assert.AreEqual(tmp, 1.0f); |
|||
|
|||
byte[] fakeData = RpcCommunicator.GetSideChannelMessage(dictSender); |
|||
RpcCommunicator.ProcessSideChannelData(dictReceiver, fakeData); |
|||
|
|||
tmp = propA.GetPropertyWithDefault(k2, 3.0f); |
|||
Assert.AreEqual(tmp, 1.0f); |
|||
|
|||
Assert.AreEqual(wasCalled, 0); |
|||
propB.SetProperty(k1, 1.0f); |
|||
Assert.AreEqual(wasCalled, 0); |
|||
fakeData = RpcCommunicator.GetSideChannelMessage(dictSender); |
|||
RpcCommunicator.ProcessSideChannelData(dictReceiver, fakeData); |
|||
Assert.AreEqual(wasCalled, 1); |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 589f475debcdb479295a24799777b5e5 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: cb2f03ed7ea59456380730bd0f9b5bcb |
|||
folderAsset: yes |
|||
DefaultImporter: |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
import struct |
|||
from mlagents.envs.side_channel.side_channel import SideChannel |
|||
from mlagents.envs.side_channel.float_properties_channel import FloatPropertiesChannel |
|||
from mlagents.envs.side_channel.raw_bytes_channel import RawBytesChannel |
|||
from mlagents.envs.environment import UnityEnvironment |
|||
|
|||
|
|||
class IntChannel(SideChannel): |
|||
def __init__(self): |
|||
self.list_int = [] |
|||
super().__init__() |
|||
|
|||
@property |
|||
def channel_type(self): |
|||
return -1 |
|||
|
|||
def on_message_received(self, data): |
|||
val = struct.unpack_from("<i", data, 0)[0] |
|||
self.list_int += [val] |
|||
|
|||
def send_int(self, value): |
|||
data = bytearray() |
|||
data += struct.pack("<i", value) |
|||
super().queue_message_to_send(data) |
|||
|
|||
|
|||
def test_int_channel(): |
|||
sender = IntChannel() |
|||
receiver = IntChannel() |
|||
sender.send_int(5) |
|||
sender.send_int(6) |
|||
data = UnityEnvironment._generate_side_channel_data({sender.channel_type: sender}) |
|||
UnityEnvironment._parse_side_channel_message( |
|||
{receiver.channel_type: receiver}, data |
|||
) |
|||
assert receiver.list_int[0] == 5 |
|||
assert receiver.list_int[1] == 6 |
|||
|
|||
|
|||
def test_float_properties(): |
|||
sender = FloatPropertiesChannel() |
|||
receiver = FloatPropertiesChannel() |
|||
|
|||
sender.set_property("prop1", 1.0) |
|||
|
|||
data = UnityEnvironment._generate_side_channel_data({sender.channel_type: sender}) |
|||
UnityEnvironment._parse_side_channel_message( |
|||
{receiver.channel_type: receiver}, data |
|||
) |
|||
|
|||
val = receiver.get_property("prop1") |
|||
assert val == 1.0 |
|||
val = receiver.get_property("prop2") |
|||
assert val is None |
|||
sender.set_property("prop2", 2.0) |
|||
|
|||
data = UnityEnvironment._generate_side_channel_data({sender.channel_type: sender}) |
|||
UnityEnvironment._parse_side_channel_message( |
|||
{receiver.channel_type: receiver}, data |
|||
) |
|||
|
|||
val = receiver.get_property("prop1") |
|||
assert val == 1.0 |
|||
val = receiver.get_property("prop2") |
|||
assert val == 2.0 |
|||
assert len(receiver.list_properties()) == 2 |
|||
assert "prop1" in receiver.list_properties() |
|||
assert "prop2" in receiver.list_properties() |
|||
val = sender.get_property("prop1") |
|||
assert val == 1.0 |
|||
|
|||
|
|||
def test_raw_bytes(): |
|||
sender = RawBytesChannel() |
|||
receiver = RawBytesChannel() |
|||
|
|||
sender.send_raw_data("foo".encode("ascii")) |
|||
sender.send_raw_data("bar".encode("ascii")) |
|||
|
|||
data = UnityEnvironment._generate_side_channel_data({sender.channel_type: sender}) |
|||
UnityEnvironment._parse_side_channel_message( |
|||
{receiver.channel_type: receiver}, data |
|||
) |
|||
|
|||
messages = receiver.get_and_clear_received_messages() |
|||
assert len(messages) == 2 |
|||
assert messages[0].decode("ascii") == "foo" |
|||
assert messages[1].decode("ascii") == "bar" |
|||
|
|||
messages = receiver.get_and_clear_received_messages() |
|||
assert len(messages) == 0 |
|
|||
using System.Collections.Generic; |
|||
using System.IO; |
|||
using UnityEngine; |
|||
|
|||
namespace MLAgents |
|||
{ |
|||
public class EngineConfigurationChannel : SideChannel |
|||
{ |
|||
|
|||
public override int ChannelType() |
|||
{ |
|||
return (int)SideChannelType.EngineSettings; |
|||
} |
|||
|
|||
public override void OnMessageReceived(byte[] data) |
|||
{ |
|||
using (var memStream = new MemoryStream(data)) |
|||
{ |
|||
using (var binaryReader = new BinaryReader(memStream)) |
|||
{ |
|||
var width = binaryReader.ReadInt32(); |
|||
var height = binaryReader.ReadInt32(); |
|||
var qualityLevel = binaryReader.ReadInt32(); |
|||
var timeScale = binaryReader.ReadSingle(); |
|||
var targetFrameRate = binaryReader.ReadInt32(); |
|||
|
|||
Screen.SetResolution(width, height, false); |
|||
QualitySettings.SetQualityLevel(qualityLevel, true); |
|||
Time.timeScale = timeScale; |
|||
Time.captureFramerate = 60; |
|||
Application.targetFrameRate = targetFrameRate; |
|||
} |
|||
} |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 18ccdf3ce76784f2db68016fa284c33f |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using System.Collections.Generic; |
|||
using System.IO; |
|||
using System; |
|||
using System.Text; |
|||
|
|||
namespace MLAgents |
|||
{ |
|||
|
|||
public interface IFloatProperties |
|||
{ |
|||
/// <summary>
|
|||
/// Sets one of the float properties of the environment. This data will be sent to Python.
|
|||
/// </summary>
|
|||
/// <param name="key"> The string identifier of the property.</param>
|
|||
/// <param name="value"> The float value of the property.</param>
|
|||
void SetProperty(string key, float value); |
|||
|
|||
/// <summary>
|
|||
/// Get an Environment property with a default value. If there is a value for this property,
|
|||
/// it will be returned, otherwise, the default value will be returned.
|
|||
/// </summary>
|
|||
/// <param name="key"> The string identifier of the property.</param>
|
|||
/// <param name="defaultValue"> The default value of the property.</param>
|
|||
/// <returns></returns>
|
|||
float GetPropertyWithDefault(string key, float defaultValue); |
|||
|
|||
/// <summary>
|
|||
/// Registers an action to be performed everytime the property is changed.
|
|||
/// </summary>
|
|||
/// <param name="key"> The string identifier of the property.</param>
|
|||
/// <param name="action"> The action that ill be performed. Takes a float as input.</param>
|
|||
void RegisterCallback(string key, Action<float> action); |
|||
|
|||
/// <summary>
|
|||
/// Returns a list of all the string identifiers of the properties currently present.
|
|||
/// </summary>
|
|||
/// <returns> The list of string identifiers </returns>
|
|||
IList<string> ListProperties(); |
|||
} |
|||
|
|||
public class FloatPropertiesChannel : SideChannel, IFloatProperties |
|||
{ |
|||
|
|||
private Dictionary<string, float> m_FloatProperties = new Dictionary<string, float>(); |
|||
private Dictionary<string, Action<float>> m_RegisteredActions = new Dictionary<string, Action<float>>(); |
|||
|
|||
public override int ChannelType() |
|||
{ |
|||
return (int)SideChannelType.FloatProperties; |
|||
} |
|||
|
|||
public override void OnMessageReceived(byte[] data) |
|||
{ |
|||
var kv = DeserializeMessage(data); |
|||
m_FloatProperties[kv.Key] = kv.Value; |
|||
if (m_RegisteredActions.ContainsKey(kv.Key)) |
|||
{ |
|||
m_RegisteredActions[kv.Key].Invoke(kv.Value); |
|||
} |
|||
} |
|||
|
|||
public void SetProperty(string key, float value) |
|||
{ |
|||
m_FloatProperties[key] = value; |
|||
QueueMessageToSend(SerializeMessage(key, value)); |
|||
if (m_RegisteredActions.ContainsKey(key)) |
|||
{ |
|||
m_RegisteredActions[key].Invoke(value); |
|||
} |
|||
} |
|||
|
|||
public float GetPropertyWithDefault(string key, float defaultValue) |
|||
{ |
|||
if (m_FloatProperties.ContainsKey(key)) |
|||
{ |
|||
return m_FloatProperties[key]; |
|||
} |
|||
else |
|||
{ |
|||
return defaultValue; |
|||
} |
|||
} |
|||
|
|||
public void RegisterCallback(string key, Action<float> action) |
|||
{ |
|||
m_RegisteredActions[key] = action; |
|||
} |
|||
|
|||
public IList<string> ListProperties() |
|||
{ |
|||
return new List<string>(m_FloatProperties.Keys); |
|||
} |
|||
|
|||
private static KeyValuePair<string, float> DeserializeMessage(byte[] data) |
|||
{ |
|||
using (var memStream = new MemoryStream(data)) |
|||
{ |
|||
using (var binaryReader = new BinaryReader(memStream)) |
|||
{ |
|||
var keyLength = binaryReader.ReadInt32(); |
|||
var key = Encoding.ASCII.GetString(binaryReader.ReadBytes(keyLength)); |
|||
var value = binaryReader.ReadSingle(); |
|||
return new KeyValuePair<string, float>(key, value); |
|||
} |
|||
} |
|||
} |
|||
|
|||
private static byte[] SerializeMessage(string key, float value) |
|||
{ |
|||
using (var memStream = new MemoryStream()) |
|||
{ |
|||
using (var binaryWriter = new BinaryWriter(memStream)) |
|||
{ |
|||
var stringEncoded = Encoding.ASCII.GetBytes(key); |
|||
binaryWriter.Write(stringEncoded.Length); |
|||
binaryWriter.Write(stringEncoded); |
|||
binaryWriter.Write(value); |
|||
return memStream.ToArray(); |
|||
} |
|||
} |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 452f8b3c01c4642aba645dcf0b6bfc6e |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using System.Collections.Generic; |
|||
namespace MLAgents |
|||
{ |
|||
public class RawBytesChannel : SideChannel |
|||
{ |
|||
|
|||
private List<byte[]> m_MessagesReceived = new List<byte[]>(); |
|||
private int m_ChannelId; |
|||
|
|||
/// <summary>
|
|||
/// RawBytesChannel provides a way to exchange raw byte arrays between Unity and Python.
|
|||
/// </summary>
|
|||
/// <param name="channelId"> The identifier for the RawBytesChannel. Must be
|
|||
/// the same on Python and Unity.</param>
|
|||
public RawBytesChannel(int channelId = 0) |
|||
{ |
|||
m_ChannelId = channelId; |
|||
} |
|||
public override int ChannelType() |
|||
{ |
|||
return (int)SideChannelType.RawBytesChannelStart + m_ChannelId; |
|||
} |
|||
|
|||
public override void OnMessageReceived(byte[] data) |
|||
{ |
|||
m_MessagesReceived.Add(data); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Sends the byte array message to the Python side channel. The message will be sent
|
|||
/// alongside the simulation step.
|
|||
/// </summary>
|
|||
/// <param name="data"> The byte array of data to send to Python.</param>
|
|||
public void SendRawBytes(byte[] data) |
|||
{ |
|||
QueueMessageToSend(data); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Gets the messages that were sent by python since the last call to
|
|||
/// GetAndClearReceivedMessages.
|
|||
/// </summary>
|
|||
/// <returns> a list of byte array messages that Python has sent.</returns>
|
|||
public IList<byte[]> GetAndClearReceivedMessages() |
|||
{ |
|||
var result = new List<byte[]>(); |
|||
result.AddRange(m_MessagesReceived); |
|||
m_MessagesReceived.Clear(); |
|||
return result; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Gets the messages that were sent by python since the last call to
|
|||
/// GetAndClearReceivedMessages. Note that the messages received will not
|
|||
/// be cleared with a call to GetReceivedMessages.
|
|||
/// </summary>
|
|||
/// <returns> a list of byte array messages that Python has sent.</returns>
|
|||
public IList<byte[]> GetReceivedMessages() |
|||
{ |
|||
var result = new List<byte[]>(); |
|||
result.AddRange(m_MessagesReceived); |
|||
return result; |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 40b01e9cdbfd94865b54ebeb4e5aeaa5 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using System.Collections.Generic; |
|||
|
|||
namespace MLAgents |
|||
{ |
|||
public enum SideChannelType |
|||
{ |
|||
// Invalid side channel
|
|||
Invalid = 0, |
|||
// Reserved for the FloatPropertiesChannel.
|
|||
FloatProperties = 1, |
|||
//Reserved for the EngineConfigurationChannel.
|
|||
EngineSettings = 2, |
|||
// Raw bytes channels should start here to avoid conflicting with other Unity ones.
|
|||
RawBytesChannelStart = 1000, |
|||
// custom side channels should start here to avoid conflicting with Unity ones.
|
|||
UserSideChannelStart = 2000, |
|||
} |
|||
|
|||
public abstract class SideChannel |
|||
{ |
|||
// The list of messages (byte arrays) that need to be sent to Python via the communicator.
|
|||
// Should only ever be read and cleared by a ICommunicator object.
|
|||
public List<byte[]> MessageQueue = new List<byte[]>(); |
|||
|
|||
/// <summary>
|
|||
/// An int identifier for the SideChannel. Ensures that there is only ever one side channel
|
|||
/// of each type. Ensure the Unity side channels will be linked to their Python equivalent.
|
|||
/// </summary>
|
|||
/// <returns> The integer identifier of the SideChannel</returns>
|
|||
public abstract int ChannelType(); |
|||
|
|||
/// <summary>
|
|||
/// Is called by the communicator every time a message is received from Python by the SideChannel.
|
|||
/// Can be called multiple times per simulation step if multiple messages were sent.
|
|||
/// </summary>
|
|||
/// <param name="data"> the payload of the message.</param>
|
|||
public abstract void OnMessageReceived(byte[] data); |
|||
|
|||
/// <summary>
|
|||
/// Queues a message to be sent to Python during the next simulation step.
|
|||
/// </summary>
|
|||
/// <param name="data"> The byte array of data to be sent to Python.</param>
|
|||
protected void QueueMessageToSend(byte[] data) |
|||
{ |
|||
MessageQueue.Add(data); |
|||
} |
|||
|
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 77b7d19dd6ce343eeba907540b5a2286 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
from mlagents.envs.side_channel.side_channel import SideChannel, SideChannelType |
|||
from mlagents.envs.exception import UnityCommunicationException |
|||
import struct |
|||
|
|||
|
|||
class EngineConfigurationChannel(SideChannel): |
|||
""" |
|||
This is the SideChannel for engine configuration exchange. The data in the |
|||
engine configuration is as follows : |
|||
- int width; |
|||
- int height; |
|||
- int qualityLevel; |
|||
- float timeScale; |
|||
- int targetFrameRate; |
|||
""" |
|||
|
|||
@property |
|||
def channel_type(self) -> int: |
|||
return SideChannelType.EngineSettings |
|||
|
|||
def on_message_received(self, data: bytearray) -> None: |
|||
""" |
|||
Is called by the environment to the side channel. Can be called |
|||
multiple times per step if multiple messages are meant for that |
|||
SideChannel. |
|||
Note that Python should never receive an engine configuration from |
|||
Unity |
|||
""" |
|||
raise UnityCommunicationException( |
|||
"The EngineConfigurationChannel received a message from Unity, " |
|||
+ "this should not have happend." |
|||
) |
|||
|
|||
def set_configuration( |
|||
self, |
|||
width: int = 80, |
|||
height: int = 80, |
|||
quality_level: int = 1, |
|||
time_scale: float = 20.0, |
|||
target_frame_rate: int = -1, |
|||
) -> None: |
|||
""" |
|||
Sets the engine configuration. Takes as input the configurations of the |
|||
engine. |
|||
:param width: Defines the width of the display. Default 80. |
|||
:param height: Defines the height of the display. Default 80. |
|||
:param quality_level: Defines the quality level of the simulation. |
|||
Default 1. |
|||
:param time_scale: Defines the multiplier for the deltatime in the |
|||
simulation. If set to a higher value, time will pass faaster in the |
|||
simulation but the physics might break. Default 20. |
|||
:param target_frame_rate: Instructs simulation to try to render at a |
|||
specified frame rate. Default -1. |
|||
""" |
|||
data = bytearray() |
|||
data += struct.pack("<i", width) |
|||
data += struct.pack("<i", height) |
|||
data += struct.pack("<i", quality_level) |
|||
data += struct.pack("<f", time_scale) |
|||
data += struct.pack("<i", target_frame_rate) |
|||
super().queue_message_to_send(data) |
|
|||
from mlagents.envs.side_channel.side_channel import SideChannel, SideChannelType |
|||
import struct |
|||
from typing import Tuple, Optional, List |
|||
|
|||
|
|||
class FloatPropertiesChannel(SideChannel): |
|||
""" |
|||
This is the SideChannel for float properties shared with Unity. |
|||
You can modify the float properties of an environment with the commands |
|||
set_property, get_property and list_properties. |
|||
""" |
|||
|
|||
def __init__(self): |
|||
self._float_properties = {} |
|||
super().__init__() |
|||
|
|||
@property |
|||
def channel_type(self) -> int: |
|||
return SideChannelType.FloatProperties |
|||
|
|||
def on_message_received(self, data: bytearray) -> None: |
|||
""" |
|||
Is called by the environment to the side channel. Can be called |
|||
multiple times per step if multiple messages are meant for that |
|||
SideChannel. |
|||
Note that Python should never receive an engine configuration from |
|||
Unity |
|||
""" |
|||
k, v = self.deserialize_float_prop(data) |
|||
self._float_properties[k] = v |
|||
|
|||
def set_property(self, key: str, value: float) -> None: |
|||
""" |
|||
Sets a property in the Unity Environment. |
|||
:param key: The string identifier of the property. |
|||
:param value: The float value of the property. |
|||
""" |
|||
self._float_properties[key] = value |
|||
super().queue_message_to_send(self.serialize_float_prop(key, value)) |
|||
|
|||
def get_property(self, key: str) -> Optional[float]: |
|||
""" |
|||
Gets a property in the Unity Environment. If the property was not |
|||
found, will return None. |
|||
:param key: The string identifier of the property. |
|||
:return: The float value of the property or None. |
|||
""" |
|||
return self._float_properties.get(key) |
|||
|
|||
def list_properties(self) -> List[str]: |
|||
""" |
|||
Returns a list of all the string identifiers of the properties |
|||
currently present in the Unity Environment. |
|||
""" |
|||
return self._float_properties.keys() |
|||
|
|||
@staticmethod |
|||
def serialize_float_prop(key: str, value: float) -> bytearray: |
|||
result = bytearray() |
|||
encoded_key = key.encode("ascii") |
|||
result += struct.pack("<i", len(encoded_key)) |
|||
result += encoded_key |
|||
result += struct.pack("<f", value) |
|||
return result |
|||
|
|||
@staticmethod |
|||
def deserialize_float_prop(data: bytearray) -> Tuple[str, float]: |
|||
offset = 0 |
|||
encoded_key_len = struct.unpack_from("<i", data, offset)[0] |
|||
offset = offset + 4 |
|||
key = data[offset : offset + encoded_key_len].decode("ascii") |
|||
offset = offset + encoded_key_len |
|||
value = struct.unpack_from("<f", data, offset)[0] |
|||
return key, value |
|
|||
from mlagents.envs.side_channel.side_channel import SideChannel, SideChannelType |
|||
from typing import List |
|||
|
|||
|
|||
class RawBytesChannel(SideChannel): |
|||
""" |
|||
This is an example of what the SideChannel for raw bytes exchange would |
|||
look like. Is meant to be used for general research purpose. |
|||
""" |
|||
|
|||
def __init__(self, channel_id=0): |
|||
self._received_messages = [] |
|||
self._channel_id = channel_id |
|||
super().__init__() |
|||
|
|||
@property |
|||
def channel_type(self) -> int: |
|||
return SideChannelType.RawBytesChannelStart + self._channel_id |
|||
|
|||
def on_message_received(self, data: bytearray) -> None: |
|||
""" |
|||
Is called by the environment to the side channel. Can be called |
|||
multiple times per step if multiple messages are meant for that |
|||
SideChannel. |
|||
""" |
|||
self._received_messages.append(data) |
|||
|
|||
def get_and_clear_received_messages(self) -> List[bytearray]: |
|||
""" |
|||
returns a list of bytearray received from the environment. |
|||
""" |
|||
result = list(self._received_messages) |
|||
self._received_messages = [] |
|||
return result |
|||
|
|||
def send_raw_data(self, data: bytearray) -> None: |
|||
""" |
|||
Queues a message to be sent by the environment at the next call to |
|||
step. |
|||
""" |
|||
super().queue_message_to_send(data) |
|
|||
from abc import ABC, abstractmethod |
|||
from enum import IntEnum |
|||
|
|||
|
|||
class SideChannelType(IntEnum): |
|||
FloatProperties = 1 |
|||
EngineSettings = 2 |
|||
# Raw bytes channels should start here to avoid conflicting with other |
|||
# Unity ones. |
|||
RawBytesChannelStart = 1000 |
|||
# custom side channels should start here to avoid conflicting with Unity |
|||
# ones. |
|||
UserSideChannelStart = 2000 |
|||
|
|||
|
|||
class SideChannel(ABC): |
|||
""" |
|||
The side channel just get access to a bytes buffer that will be shared |
|||
between C# and Python. For example, We will create a specific side channel |
|||
for properties that will be a list of string (fixed size) to float number, |
|||
that can be modified by both C# and Python. All side channels are passed |
|||
to the Env object at construction. |
|||
""" |
|||
|
|||
def __init__(self): |
|||
self.message_queue = [] |
|||
|
|||
def queue_message_to_send(self, data: bytearray) -> None: |
|||
""" |
|||
Queues a message to be sent by the environment at the next call to |
|||
step. |
|||
""" |
|||
self.message_queue.append(data) |
|||
|
|||
@abstractmethod |
|||
def on_message_received(self, data: bytearray) -> None: |
|||
""" |
|||
Is called by the environment to the side channel. Can be called |
|||
multiple times per step if multiple messages are meant for that |
|||
SideChannel. |
|||
""" |
|||
pass |
|||
|
|||
@property |
|||
@abstractmethod |
|||
def channel_type(self) -> int: |
|||
""" |
|||
:return:The type of side channel used. Will influence how the data is |
|||
processed in the environment. |
|||
""" |
|||
pass |
撰写
预览
正在加载...
取消
保存
Reference in new issue