浏览代码

Develop side channel (#2956)

* [WIP] Side Channel initial layout

* Working prototype for raw bytes

* fixing format mistake

* Added some errors and some unit tests in C#

* Added the side channel for the Engine Configuration. (#2958)

* Added the side channel for the Engine Configuration.

Note that this change does not require modifying a lot of files:
 - Add a sender in Python
 - Add a receiver in C#
 - Subscribe the receiver to the communicator (a one-liner in the Academy)
 - Add the side channel to the Python UnityEnvironment (not shown here)

Adding the side channel to the environment would look like this:

```python
from mlagents.envs.environment import UnityEnvironment
from mlagents.envs.side_channel.raw_bytes_channel import RawBytesChannel
from mlagents.envs.side_channel.engine_configuration_channel import EngineConfigurationChannel

channel0 = RawBytesChannel()
channel1 = EngineConfigurationChannel()

env = UnityEnvironme...
/develop/tanhsquash
GitHub 5 年前
当前提交
11243348
共有 31 个文件被更改,包括 1047 次插入32 次删除
  1. 7
      UnitySDK/Assets/ML-Agents/Scripts/Academy.cs
  2. 45
      UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/UnityRlInput.cs
  3. 44
      UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/UnityRlOutput.cs
  4. 104
      UnitySDK/Assets/ML-Agents/Scripts/Grpc/RpcCommunicator.cs
  5. 7
      UnitySDK/Assets/ML-Agents/Scripts/ICommunicator.cs
  6. 19
      ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_input_pb2.py
  7. 6
      ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_input_pb2.pyi
  8. 19
      ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_output_pb2.py
  9. 6
      ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_output_pb2.pyi
  10. 58
      ml-agents-envs/mlagents/envs/environment.py
  11. 1
      protobuf-definitions/proto/mlagents/envs/communicator_objects/unity_rl_input.proto
  12. 1
      protobuf-definitions/proto/mlagents/envs/communicator_objects/unity_rl_output.proto
  13. 108
      UnitySDK/Assets/ML-Agents/Editor/Tests/SideChannelTests.cs
  14. 11
      UnitySDK/Assets/ML-Agents/Editor/Tests/SideChannelTests.cs.meta
  15. 8
      UnitySDK/Assets/ML-Agents/Scripts/SideChannel.meta
  16. 91
      ml-agents-envs/mlagents/envs/tests/test_side_channel.py
  17. 36
      UnitySDK/Assets/ML-Agents/Scripts/SideChannel/EngineConfigurationChannel.cs
  18. 11
      UnitySDK/Assets/ML-Agents/Scripts/SideChannel/EngineConfigurationChannel.cs.meta
  19. 123
      UnitySDK/Assets/ML-Agents/Scripts/SideChannel/FloatPropertiesChannel.cs
  20. 11
      UnitySDK/Assets/ML-Agents/Scripts/SideChannel/FloatPropertiesChannel.cs.meta
  21. 65
      UnitySDK/Assets/ML-Agents/Scripts/SideChannel/RawBytesChannel.cs
  22. 11
      UnitySDK/Assets/ML-Agents/Scripts/SideChannel/RawBytesChannel.cs.meta
  23. 49
      UnitySDK/Assets/ML-Agents/Scripts/SideChannel/SideChannel.cs
  24. 11
      UnitySDK/Assets/ML-Agents/Scripts/SideChannel/SideChannel.cs.meta
  25. 0
      ml-agents-envs/mlagents/envs/side_channel/__init__.py
  26. 61
      ml-agents-envs/mlagents/envs/side_channel/engine_configuration_channel.py
  27. 74
      ml-agents-envs/mlagents/envs/side_channel/float_properties_channel.py
  28. 41
      ml-agents-envs/mlagents/envs/side_channel/raw_bytes_channel.py
  29. 51
      ml-agents-envs/mlagents/envs/side_channel/side_channel.py

7
UnitySDK/Assets/ML-Agents/Scripts/Academy.cs


[Tooltip("List of custom parameters that can be changed in the " +
"environment when it resets.")]
public ResetParameters resetParameters;
public IFloatProperties FloatProperties;
public CommunicatorObjects.CustomResetParametersProto customResetParameters;
// Fields not provided in the Inspector.

m_OriginalMaximumDeltaTime = Time.maximumDeltaTime;
InitializeAcademy();
var floatProperties = new FloatPropertiesChannel();
FloatProperties = floatProperties;
// Try to launch the communicator by using the arguments passed at launch
try

Communicator.QuitCommandReceived += OnQuitCommandReceived;
Communicator.ResetCommandReceived += OnResetCommand;
Communicator.RLInputReceived += OnRLInputReceived;
Communicator.RegisterSideChannel(new EngineConfigurationChannel());
Communicator.RegisterSideChannel(floatProperties);
}
}

45
UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/UnityRlInput.cs


"ZW52cy9jb21tdW5pY2F0b3Jfb2JqZWN0cy9hZ2VudF9hY3Rpb24ucHJvdG8a",
"P21sYWdlbnRzL2VudnMvY29tbXVuaWNhdG9yX29iamVjdHMvZW52aXJvbm1l",
"bnRfcGFyYW1ldGVycy5wcm90bxowbWxhZ2VudHMvZW52cy9jb21tdW5pY2F0",
"b3Jfb2JqZWN0cy9jb21tYW5kLnByb3RvIsMDChFVbml0eVJMSW5wdXRQcm90",
"b3Jfb2JqZWN0cy9jb21tYW5kLnByb3RvItkDChFVbml0eVJMSW5wdXRQcm90",
"cy5Db21tYW5kUHJvdG8aTQoUTGlzdEFnZW50QWN0aW9uUHJvdG8SNQoFdmFs",
"dWUYASADKAsyJi5jb21tdW5pY2F0b3Jfb2JqZWN0cy5BZ2VudEFjdGlvblBy",
"b3RvGnEKEUFnZW50QWN0aW9uc0VudHJ5EgsKA2tleRgBIAEoCRJLCgV2YWx1",
"ZRgCIAEoCzI8LmNvbW11bmljYXRvcl9vYmplY3RzLlVuaXR5UkxJbnB1dFBy",
"b3RvLkxpc3RBZ2VudEFjdGlvblByb3RvOgI4AUIfqgIcTUxBZ2VudHMuQ29t",
"bXVuaWNhdG9yT2JqZWN0c2IGcHJvdG8z"));
"cy5Db21tYW5kUHJvdG8SFAoMc2lkZV9jaGFubmVsGAUgASgMGk0KFExpc3RB",
"Z2VudEFjdGlvblByb3RvEjUKBXZhbHVlGAEgAygLMiYuY29tbXVuaWNhdG9y",
"X29iamVjdHMuQWdlbnRBY3Rpb25Qcm90bxpxChFBZ2VudEFjdGlvbnNFbnRy",
"eRILCgNrZXkYASABKAkSSwoFdmFsdWUYAiABKAsyPC5jb21tdW5pY2F0b3Jf",
"b2JqZWN0cy5Vbml0eVJMSW5wdXRQcm90by5MaXN0QWdlbnRBY3Rpb25Qcm90",
"bzoCOAFCH6oCHE1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3Rv",
"Mw=="));
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLInputProto), global::MLAgents.CommunicatorObjects.UnityRLInputProto.Parser, new[]{ "AgentActions", "EnvironmentParameters", "IsTraining", "Command" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLInputProto.Types.ListAgentActionProto), global::MLAgents.CommunicatorObjects.UnityRLInputProto.Types.ListAgentActionProto.Parser, new[]{ "Value" }, null, null, null),
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLInputProto), global::MLAgents.CommunicatorObjects.UnityRLInputProto.Parser, new[]{ "AgentActions", "EnvironmentParameters", "IsTraining", "Command", "SideChannel" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLInputProto.Types.ListAgentActionProto), global::MLAgents.CommunicatorObjects.UnityRLInputProto.Types.ListAgentActionProto.Parser, new[]{ "Value" }, null, null, null),
null, })
}));
}

EnvironmentParameters = other.environmentParameters_ != null ? other.EnvironmentParameters.Clone() : null;
isTraining_ = other.isTraining_;
command_ = other.command_;
sideChannel_ = other.sideChannel_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

}
}
/// <summary>Field number for the "side_channel" field.</summary>
public const int SideChannelFieldNumber = 5;
// Backing field for SideChannel; starts as the empty ByteString rather than null.
private pb::ByteString sideChannel_ = pb::ByteString.Empty;
/// <summary>
/// Raw bytes of the "side_channel" field. The setter passes the value through
/// ProtoPreconditions.CheckNotNull before storing it.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pb::ByteString SideChannel {
get { return sideChannel_; }
set {
sideChannel_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as UnityRLInputProto);

if (!object.Equals(EnvironmentParameters, other.EnvironmentParameters)) return false;
if (IsTraining != other.IsTraining) return false;
if (Command != other.Command) return false;
if (SideChannel != other.SideChannel) return false;
return Equals(_unknownFields, other._unknownFields);
}

if (environmentParameters_ != null) hash ^= EnvironmentParameters.GetHashCode();
if (IsTraining != false) hash ^= IsTraining.GetHashCode();
if (Command != 0) hash ^= Command.GetHashCode();
if (SideChannel.Length != 0) hash ^= SideChannel.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}

output.WriteRawTag(32);
output.WriteEnum((int) Command);
}
if (SideChannel.Length != 0) {
output.WriteRawTag(42);
output.WriteBytes(SideChannel);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}

}
if (Command != 0) {
size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Command);
}
if (SideChannel.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeBytesSize(SideChannel);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();

if (other.Command != 0) {
Command = other.Command;
}
if (other.SideChannel.Length != 0) {
SideChannel = other.SideChannel;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

}
case 32: {
command_ = (global::MLAgents.CommunicatorObjects.CommandProto) input.ReadEnum();
break;
}
case 42: {
SideChannel = input.ReadBytes();
break;
}
}

44
UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/UnityRlOutput.cs


string.Concat(
"CjhtbGFnZW50cy9lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL3VuaXR5X3Js",
"X291dHB1dC5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMaM21sYWdlbnRz",
"L2VudnMvY29tbXVuaWNhdG9yX29iamVjdHMvYWdlbnRfaW5mby5wcm90byKj",
"L2VudnMvY29tbXVuaWNhdG9yX29iamVjdHMvYWdlbnRfaW5mby5wcm90byK5",
"bmZvc0VudHJ5GkkKEkxpc3RBZ2VudEluZm9Qcm90bxIzCgV2YWx1ZRgBIAMo",
"CzIkLmNvbW11bmljYXRvcl9vYmplY3RzLkFnZW50SW5mb1Byb3RvGm4KD0Fn",
"ZW50SW5mb3NFbnRyeRILCgNrZXkYASABKAkSSgoFdmFsdWUYAiABKAsyOy5j",
"b21tdW5pY2F0b3Jfb2JqZWN0cy5Vbml0eVJMT3V0cHV0UHJvdG8uTGlzdEFn",
"ZW50SW5mb1Byb3RvOgI4AUoECAEQAkIfqgIcTUxBZ2VudHMuQ29tbXVuaWNh",
"dG9yT2JqZWN0c2IGcHJvdG8z"));
"bmZvc0VudHJ5EhQKDHNpZGVfY2hhbm5lbBgDIAEoDBpJChJMaXN0QWdlbnRJ",
"bmZvUHJvdG8SMwoFdmFsdWUYASADKAsyJC5jb21tdW5pY2F0b3Jfb2JqZWN0",
"cy5BZ2VudEluZm9Qcm90bxpuCg9BZ2VudEluZm9zRW50cnkSCwoDa2V5GAEg",
"ASgJEkoKBXZhbHVlGAIgASgLMjsuY29tbXVuaWNhdG9yX29iamVjdHMuVW5p",
"dHlSTE91dHB1dFByb3RvLkxpc3RBZ2VudEluZm9Qcm90bzoCOAFKBAgBEAJC",
"H6oCHE1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw=="));
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLOutputProto), global::MLAgents.CommunicatorObjects.UnityRLOutputProto.Parser, new[]{ "AgentInfos" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLOutputProto.Types.ListAgentInfoProto), global::MLAgents.CommunicatorObjects.UnityRLOutputProto.Types.ListAgentInfoProto.Parser, new[]{ "Value" }, null, null, null),
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLOutputProto), global::MLAgents.CommunicatorObjects.UnityRLOutputProto.Parser, new[]{ "AgentInfos", "SideChannel" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLOutputProto.Types.ListAgentInfoProto), global::MLAgents.CommunicatorObjects.UnityRLOutputProto.Types.ListAgentInfoProto.Parser, new[]{ "Value" }, null, null, null),
null, })
}));
}

/// <summary>
/// Copy constructor. Clones the agent-info map and the unknown-field set of
/// <paramref name="other"/>; the side-channel ByteString reference is copied as is.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public UnityRLOutputProto(UnityRLOutputProto other) : this() {
agentInfos_ = other.agentInfos_.Clone();
sideChannel_ = other.sideChannel_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

get { return agentInfos_; }
}
/// <summary>Field number for the "side_channel" field.</summary>
public const int SideChannelFieldNumber = 3;
// Backing field for SideChannel; starts as the empty ByteString rather than null.
private pb::ByteString sideChannel_ = pb::ByteString.Empty;
/// <summary>
/// Raw bytes of the "side_channel" field. The setter passes the value through
/// ProtoPreconditions.CheckNotNull before storing it.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pb::ByteString SideChannel {
get { return sideChannel_; }
set {
sideChannel_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as UnityRLOutputProto);

return true;
}
if (!AgentInfos.Equals(other.AgentInfos)) return false;
if (SideChannel != other.SideChannel) return false;
return Equals(_unknownFields, other._unknownFields);
}

hash ^= AgentInfos.GetHashCode();
if (SideChannel.Length != 0) hash ^= SideChannel.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
agentInfos_.WriteTo(output, _map_agentInfos_codec);
if (SideChannel.Length != 0) {
output.WriteRawTag(26);
output.WriteBytes(SideChannel);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}

public int CalculateSize() {
int size = 0;
size += agentInfos_.CalculateSize(_map_agentInfos_codec);
if (SideChannel.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeBytesSize(SideChannel);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}

return;
}
agentInfos_.Add(other.agentInfos_);
if (other.SideChannel.Length != 0) {
SideChannel = other.SideChannel;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

break;
case 18: {
agentInfos_.AddEntriesFrom(input, _map_agentInfos_codec);
break;
}
case 26: {
SideChannel = input.ReadBytes();
break;
}
}

104
UnitySDK/Assets/ML-Agents/Scripts/Grpc/RpcCommunicator.cs


using System.Linq;
using UnityEngine;
using MLAgents.CommunicatorObjects;
using System.IO;
using Google.Protobuf;
namespace MLAgents
{

#endif
/// The communicator parameters sent at construction
CommunicatorInitParameters m_CommunicatorInitParameters;
Dictionary<int, SideChannel> m_SideChannels = new Dictionary<int, SideChannel>();
/// <summary>
/// Initializes a new instance of the RPCCommunicator class.

{
SendRLInputReceivedEvent(rlInput.IsTraining);
SendCommandEvent(rlInput.Command, rlInput.EnvironmentParameters);
ProcessSideChannelData(m_SideChannels, rlInput.SideChannel.ToArray());
}
UnityInputProto Initialize(UnityOutputProto unityOutput,

message.RlInitializationOutput = tempUnityRlInitializationOutput;
}
byte[] messageAggregated = GetSideChannelMessage(m_SideChannels);
message.RlOutput.SideChannel = ByteString.CopyFrom(messageAggregated);
var input = Exchange(message);
UpdateSentBrainParameters(tempUnityRlInitializationOutput);

{
m_SentBrainKeys.Add(brainProto.BrainName);
m_UnsentBrainKeys.Remove(brainProto.BrainName);
}
}
#endregion
#region Handling side channels
/// <summary>
/// Registers a side channel to the communicator. The side channel will exchange
/// messages with its Python equivalent.
/// </summary>
/// <param name="sideChannel"> The side channel to be registered.</param>
/// <exception cref="UnityAgentsException">
/// Thrown when a side channel with the same channel type is already registered.
/// </exception>
public void RegisterSideChannel(SideChannel sideChannel)
{
    var channelType = sideChannel.ChannelType();
    if (m_SideChannels.ContainsKey(channelType))
    {
        // The placeholder must be indexed ("{0}") and given an argument; a bare "{}"
        // makes string.Format throw a FormatException and hides the real error.
        throw new UnityAgentsException(string.Format(
            "A side channel with type index {0} is already registered. You cannot register multiple " +
            "side channels of the same type.", channelType));
    }
    m_SideChannels.Add(channelType, sideChannel);
}
/// <summary>
/// Drains the outgoing message queue of every registered side channel and packs the
/// messages into a single byte array to be sent to Python at the current step.
/// Each message is framed as: channel type (int), payload length (int), payload bytes.
/// </summary>
/// <param name="sideChannels"> A dictionary of channel type to channel.</param>
/// <returns>The concatenation of all framed messages; empty when nothing was queued.</returns>
public static byte[] GetSideChannelMessage(Dictionary<int, SideChannel> sideChannels)
{
    using (var buffer = new MemoryStream())
    using (var writer = new BinaryWriter(buffer))
    {
        foreach (var channel in sideChannels.Values)
        {
            foreach (var payload in channel.MessageQueue)
            {
                // Frame header: the owning channel, then the payload size.
                writer.Write(channel.ChannelType());
                writer.Write(payload.Count());
                writer.Write(payload);
            }
            // Everything queued on this channel has been serialized.
            channel.MessageQueue.Clear();
        }
        return buffer.ToArray();
    }
}
/// <summary>
/// Separates the data received from Python into individual messages for each registered side channel.
/// Each message is expected to be framed as: channel type (int), payload length (int), payload bytes.
/// Messages addressed to a channel type that is not registered are logged and dropped.
/// </summary>
/// <param name="sideChannels">A dictionary of channel type to channel.</param>
/// <param name="dataReceived">The byte array of data received from Python.</param>
/// <exception cref="UnityAgentsException">
/// Thrown when a message header cannot be read or a payload is shorter than its header announces.
/// </exception>
public static void ProcessSideChannelData(Dictionary<int, SideChannel> sideChannels, byte[] dataReceived)
{
    if (dataReceived.Length == 0)
    {
        return;
    }
    using (var memStream = new MemoryStream(dataReceived))
    using (var binaryReader = new BinaryReader(memStream))
    {
        while (memStream.Position < memStream.Length)
        {
            int channelType = 0;
            int messageLength = 0;
            byte[] message = null;
            try
            {
                channelType = binaryReader.ReadInt32();
                messageLength = binaryReader.ReadInt32();
                message = binaryReader.ReadBytes(messageLength);
            }
            catch (Exception ex)
            {
                throw new UnityAgentsException(
                    "There was a problem reading a message in a SideChannel. Please make sure the " +
                    "version of MLAgents in Unity is compatible with the Python version. Original error : "
                    + ex.Message);
            }
            // ReadBytes does not throw at end of stream; it returns fewer bytes than requested.
            // Reject the truncated payload instead of handing it to the channel.
            if (message.Length != messageLength)
            {
                throw new UnityAgentsException(string.Format(
                    "The message received by the side channel {0} was unexpectedly short. " +
                    "Make sure the Python side is sending side channel data properly.",
                    channelType));
            }
            if (sideChannels.ContainsKey(channelType))
            {
                sideChannels[channelType].OnMessageReceived(message);
            }
            else
            {
                Debug.Log(string.Format(
                    "Unknown side channel data received. Channel type "
                    + ": {0}", channelType));
            }
        }
    }
}

7
UnitySDK/Assets/ML-Agents/Scripts/ICommunicator.cs


/// <param name="key">A key to identify which actions to get</param>
/// <returns></returns>
Dictionary<Agent, AgentAction> GetActions(string key);
/// <summary>
/// Registers a side channel to the communicator. The side channel will exchange
/// messages with its Python equivalent.
/// </summary>
/// <param name="sideChannel"> The side channel to be registered.</param>
void RegisterSideChannel(SideChannel sideChannel);
}
}

19
ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_input_pb2.py


name='mlagents/envs/communicator_objects/unity_rl_input.proto',
package='communicator_objects',
syntax='proto3',
serialized_pb=_b('\n7mlagents/envs/communicator_objects/unity_rl_input.proto\x12\x14\x63ommunicator_objects\x1a\x35mlagents/envs/communicator_objects/agent_action.proto\x1a?mlagents/envs/communicator_objects/environment_parameters.proto\x1a\x30mlagents/envs/communicator_objects/command.proto\"\xc3\x03\n\x11UnityRLInputProto\x12P\n\ragent_actions\x18\x01 \x03(\x0b\x32\x39.communicator_objects.UnityRLInputProto.AgentActionsEntry\x12P\n\x16\x65nvironment_parameters\x18\x02 \x01(\x0b\x32\x30.communicator_objects.EnvironmentParametersProto\x12\x13\n\x0bis_training\x18\x03 \x01(\x08\x12\x33\n\x07\x63ommand\x18\x04 \x01(\x0e\x32\".communicator_objects.CommandProto\x1aM\n\x14ListAgentActionProto\x12\x35\n\x05value\x18\x01 \x03(\x0b\x32&.communicator_objects.AgentActionProto\x1aq\n\x11\x41gentActionsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12K\n\x05value\x18\x02 \x01(\x0b\x32<.communicator_objects.UnityRLInputProto.ListAgentActionProto:\x02\x38\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
serialized_pb=_b('\n7mlagents/envs/communicator_objects/unity_rl_input.proto\x12\x14\x63ommunicator_objects\x1a\x35mlagents/envs/communicator_objects/agent_action.proto\x1a?mlagents/envs/communicator_objects/environment_parameters.proto\x1a\x30mlagents/envs/communicator_objects/command.proto\"\xd9\x03\n\x11UnityRLInputProto\x12P\n\ragent_actions\x18\x01 \x03(\x0b\x32\x39.communicator_objects.UnityRLInputProto.AgentActionsEntry\x12P\n\x16\x65nvironment_parameters\x18\x02 \x01(\x0b\x32\x30.communicator_objects.EnvironmentParametersProto\x12\x13\n\x0bis_training\x18\x03 \x01(\x08\x12\x33\n\x07\x63ommand\x18\x04 \x01(\x0e\x32\".communicator_objects.CommandProto\x12\x14\n\x0cside_channel\x18\x05 \x01(\x0c\x1aM\n\x14ListAgentActionProto\x12\x35\n\x05value\x18\x01 \x03(\x0b\x32&.communicator_objects.AgentActionProto\x1aq\n\x11\x41gentActionsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12K\n\x05value\x18\x02 \x01(\x0b\x32<.communicator_objects.UnityRLInputProto.ListAgentActionProto:\x02\x38\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
,
dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_agent__action__pb2.DESCRIPTOR,mlagents_dot_envs_dot_communicator__objects_dot_environment__parameters__pb2.DESCRIPTOR,mlagents_dot_envs_dot_communicator__objects_dot_command__pb2.DESCRIPTOR,])

extension_ranges=[],
oneofs=[
],
serialized_start=511,
serialized_end=588,
serialized_start=533,
serialized_end=610,
)
_UNITYRLINPUTPROTO_AGENTACTIONSENTRY = _descriptor.Descriptor(

extension_ranges=[],
oneofs=[
],
serialized_start=590,
serialized_end=703,
serialized_start=612,
serialized_end=725,
)
_UNITYRLINPUTPROTO = _descriptor.Descriptor(

message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='side_channel', full_name='communicator_objects.UnityRLInputProto.side_channel', index=4,
number=5, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=_b(""),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
],
extensions=[
],

oneofs=[
],
serialized_start=252,
serialized_end=703,
serialized_end=725,
)
_UNITYRLINPUTPROTO_LISTAGENTACTIONPROTO.fields_by_name['value'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_agent__action__pb2._AGENTACTIONPROTO

6
ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_input_pb2.pyi


is_training = ... # type: builtin___bool
command = ... # type: mlagents___envs___communicator_objects___command_pb2___CommandProto
side_channel = ... # type: builtin___bytes
@property
def agent_actions(self) -> typing___MutableMapping[typing___Text, UnityRLInputProto.ListAgentActionProto]: ...

environment_parameters : typing___Optional[mlagents___envs___communicator_objects___environment_parameters_pb2___EnvironmentParametersProto] = None,
is_training : typing___Optional[builtin___bool] = None,
command : typing___Optional[mlagents___envs___communicator_objects___command_pb2___CommandProto] = None,
side_channel : typing___Optional[builtin___bytes] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> UnityRLInputProto: ...

def HasField(self, field_name: typing_extensions___Literal[u"environment_parameters"]) -> builtin___bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"agent_actions",u"command",u"environment_parameters",u"is_training"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"agent_actions",u"command",u"environment_parameters",u"is_training",u"side_channel"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"agent_actions",b"agent_actions",u"command",b"command",u"environment_parameters",b"environment_parameters",u"is_training",b"is_training"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"agent_actions",b"agent_actions",u"command",b"command",u"environment_parameters",b"environment_parameters",u"is_training",b"is_training",u"side_channel",b"side_channel"]) -> None: ...

19
ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_output_pb2.py


name='mlagents/envs/communicator_objects/unity_rl_output.proto',
package='communicator_objects',
syntax='proto3',
serialized_pb=_b('\n8mlagents/envs/communicator_objects/unity_rl_output.proto\x12\x14\x63ommunicator_objects\x1a\x33mlagents/envs/communicator_objects/agent_info.proto\"\xa3\x02\n\x12UnityRLOutputProto\x12L\n\nagentInfos\x18\x02 \x03(\x0b\x32\x38.communicator_objects.UnityRLOutputProto.AgentInfosEntry\x1aI\n\x12ListAgentInfoProto\x12\x33\n\x05value\x18\x01 \x03(\x0b\x32$.communicator_objects.AgentInfoProto\x1an\n\x0f\x41gentInfosEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12J\n\x05value\x18\x02 \x01(\x0b\x32;.communicator_objects.UnityRLOutputProto.ListAgentInfoProto:\x02\x38\x01J\x04\x08\x01\x10\x02\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
serialized_pb=_b('\n8mlagents/envs/communicator_objects/unity_rl_output.proto\x12\x14\x63ommunicator_objects\x1a\x33mlagents/envs/communicator_objects/agent_info.proto\"\xb9\x02\n\x12UnityRLOutputProto\x12L\n\nagentInfos\x18\x02 \x03(\x0b\x32\x38.communicator_objects.UnityRLOutputProto.AgentInfosEntry\x12\x14\n\x0cside_channel\x18\x03 \x01(\x0c\x1aI\n\x12ListAgentInfoProto\x12\x33\n\x05value\x18\x01 \x03(\x0b\x32$.communicator_objects.AgentInfoProto\x1an\n\x0f\x41gentInfosEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12J\n\x05value\x18\x02 \x01(\x0b\x32;.communicator_objects.UnityRLOutputProto.ListAgentInfoProto:\x02\x38\x01J\x04\x08\x01\x10\x02\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
,
dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_agent__info__pb2.DESCRIPTOR,])

extension_ranges=[],
oneofs=[
],
serialized_start=236,
serialized_end=309,
serialized_start=258,
serialized_end=331,
)
_UNITYRLOUTPUTPROTO_AGENTINFOSENTRY = _descriptor.Descriptor(

extension_ranges=[],
oneofs=[
],
serialized_start=311,
serialized_end=421,
serialized_start=333,
serialized_end=443,
)
_UNITYRLOUTPUTPROTO = _descriptor.Descriptor(

message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='side_channel', full_name='communicator_objects.UnityRLOutputProto.side_channel', index=1,
number=3, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=_b(""),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
],
extensions=[
],

oneofs=[
],
serialized_start=136,
serialized_end=427,
serialized_end=449,
)
_UNITYRLOUTPUTPROTO_LISTAGENTINFOPROTO.fields_by_name['value'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_agent__info__pb2._AGENTINFOPROTO

6
ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_output_pb2.pyi


def HasField(self, field_name: typing_extensions___Literal[u"value",b"value"]) -> builtin___bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"key",b"key",u"value",b"value"]) -> None: ...
side_channel = ... # type: builtin___bytes
@property
def agentInfos(self) -> typing___MutableMapping[typing___Text, UnityRLOutputProto.ListAgentInfoProto]: ...

agentInfos : typing___Optional[typing___Mapping[typing___Text, UnityRLOutputProto.ListAgentInfoProto]] = None,
side_channel : typing___Optional[builtin___bytes] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> UnityRLOutputProto: ...

def ClearField(self, field_name: typing_extensions___Literal[u"agentInfos"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"agentInfos",u"side_channel"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"agentInfos",b"agentInfos"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"agentInfos",b"agentInfos",u"side_channel",b"side_channel"]) -> None: ...

58
ml-agents-envs/mlagents/envs/environment.py


import subprocess
from typing import Dict, List, Optional, Any
from mlagents.envs.side_channel.side_channel import SideChannel
from mlagents.envs.base_unity_environment import BaseUnityEnvironment
from mlagents.envs.timers import timed, hierarchical_timer
from .brain import AllBrainInfo, BrainInfo, BrainParameters

from .rpc_communicator import RpcCommunicator
from sys import platform
import signal
import struct
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("mlagents.envs")

no_graphics: bool = False,
timeout_wait: int = 60,
args: Optional[List[str]] = None,
side_channels: Optional[List[SideChannel]] = None,
):
"""
Starts a new unity environment and establishes a connection with the environment.

:int timeout_wait: Time (in seconds) to wait for connection from environment.
:bool train_mode: Whether to run in training mode, speeding up the simulation, by default.
:list args: Addition Unity command line arguments
:list side_channels: Additional side channel for no-rl communication with Unity
"""
args = args or []
atexit.register(self._close)

self.timeout_wait: int = timeout_wait
self.communicator = self.get_communicator(worker_id, base_port, timeout_wait)
self.worker_id = worker_id
self.side_channels: Dict[int, SideChannel] = {}
if side_channels is not None:
for _sc in side_channels:
if _sc.channel_type in self.side_channels:
raise UnityEnvironmentException(
"There cannot be two side channels with the same channel type {0}.".format(
_sc.channel_type
)
)
self.side_channels[_sc.channel_type] = _sc
# If the environment name is None, a new environment will not be launched
# and the communicator will directly try to connect to an existing unity environment.

_data[brain_name] = BrainInfo.from_agent_proto(
self.worker_id, agent_info_list, self.brains[brain_name]
)
self._parse_side_channel_message(self.side_channels, output.side_channel)
@staticmethod
def _parse_side_channel_message(
    side_channels: Dict[int, SideChannel], data: bytearray
) -> None:
    """
    Split the raw side channel payload into individual messages and hand each
    one to the side channel registered for its channel type.

    The payload is a concatenation of frames. Each frame starts with two
    little-endian 32-bit signed integers (channel type, message length)
    followed by that many bytes of message data. Messages for a channel type
    that is not registered are logged as a warning and dropped.

    :param side_channels: Registered side channels, keyed by channel type.
    :param data: The raw side channel bytes to parse.
    :raises UnityEnvironmentException: If a frame header cannot be unpacked or
        a message is not exactly as long as its header announces.
    """
    offset = 0
    while offset < len(data):
        # Only the header unpack can fail here (fewer than 8 bytes left);
        # slicing below never raises, so it stays out of the try body.
        try:
            channel_type, message_len = struct.unpack_from("<ii", data, offset)
        except struct.error as err:
            raise UnityEnvironmentException(
                "There was a problem reading a message in a SideChannel. "
                "Please make sure the version of MLAgents in Unity is "
                "compatible with the Python version."
            ) from err
        offset = offset + 8
        message_data = data[offset : offset + message_len]
        offset = offset + message_len
        # A slice past the end of the buffer is silently shortened, so the
        # length has to be checked explicitly.
        if len(message_data) != message_len:
            raise UnityEnvironmentException(
                "The message received by the side channel {0} was "
                "unexpectedly short. Make sure your Unity Environment "
                "sending side channel data properly.".format(channel_type)
            )
        if channel_type in side_channels:
            side_channels[channel_type].on_message_received(message_data)
        else:
            logger.warning(
                "Unknown side channel data received. Channel type "
                ": {0}.".format(channel_type)
            )
@staticmethod
def _generate_side_channel_data(side_channels: Dict[int, SideChannel]) -> bytearray:
    """
    Serialize every queued message of every side channel into one buffer and
    reset each channel's queue.

    Each message is written as two little-endian 32-bit signed integers
    (channel type, message length) followed by the message itself.

    :param side_channels: Registered side channels, keyed by channel type.
    :return: The concatenated frames; empty if no message was queued.
    """
    payload = bytearray()
    for channel_type, channel in side_channels.items():
        for message in channel.message_queue:
            header = struct.pack("<ii", channel_type, len(message))
            payload += header
            payload += message
        # Rebind to a fresh list once the queue has been serialized.
        channel.message_queue = []
    return payload
def _update_brain_parameters(self, output: UnityOutputProto) -> None:
init_output = output.rl_initialization_output

action.value = float(value[b][i])
rl_in.agent_actions[b].value.extend([action])
rl_in.command = 0
rl_in.side_channel = bytes(self._generate_side_channel_data(self.side_channels))
return self.wrap_unity_input(rl_in)
def _generate_reset_input(

custom_reset_parameters
)
rl_in.command = 1
rl_in.side_channel = bytes(self._generate_side_channel_data(self.side_channels))
return self.wrap_unity_input(rl_in)
def send_academy_parameters(

1
protobuf-definitions/proto/mlagents/envs/communicator_objects/unity_rl_input.proto


EnvironmentParametersProto environment_parameters = 2;
bool is_training = 3;
CommandProto command = 4;
bytes side_channel = 5;
}

1
protobuf-definitions/proto/mlagents/envs/communicator_objects/unity_rl_output.proto


}
reserved 1; // deprecated bool global_done field
map<string, ListAgentInfoProto> agentInfos = 2;
bytes side_channel = 3;
}

108
UnitySDK/Assets/ML-Agents/Editor/Tests/SideChannelTests.cs


using System;
using NUnit.Framework;
using MLAgents;
using System.Collections.Generic;
using System.Text;
namespace MLAgents.Tests
{
public class SideChannelTests
{
// Minimal side channel for the tests below: every message carries one 32-bit integer.
public class TestSideChannel : SideChannel
{
    // Integers decoded from the messages received so far, in arrival order.
    public List<int> m_MessagesReceived = new List<int>();

    public override int ChannelType()
    {
        return -1;
    }

    public override void OnMessageReceived(byte[] data)
    {
        var value = BitConverter.ToInt32(data, 0);
        m_MessagesReceived.Add(value);
    }

    public void SendInt(int data)
    {
        var payload = BitConverter.GetBytes(data);
        QueueMessageToSend(payload);
    }
}
/// <summary>
/// Sends three integers through a TestSideChannel, round-trips them through
/// GetSideChannelMessage / ProcessSideChannelData and checks they arrive in order.
/// </summary>
[Test]
public void TestIntegerSideChannel()
{
    var intSender = new TestSideChannel();
    var intReceiver = new TestSideChannel();
    var dictSender = new Dictionary<int, SideChannel> { { intSender.ChannelType(), intSender } };
    var dictReceiver = new Dictionary<int, SideChannel> { { intReceiver.ChannelType(), intReceiver } };

    intSender.SendInt(4);
    intSender.SendInt(5);
    intSender.SendInt(6);

    byte[] fakeData = RpcCommunicator.GetSideChannelMessage(dictSender);
    RpcCommunicator.ProcessSideChannelData(dictReceiver, fakeData);

    // NUnit expects (expected, actual); this order keeps failure messages accurate.
    Assert.AreEqual(4, intReceiver.m_MessagesReceived[0]);
    Assert.AreEqual(5, intReceiver.m_MessagesReceived[1]);
    Assert.AreEqual(6, intReceiver.m_MessagesReceived[2]);
}
/// <summary>
/// Sends two ASCII-encoded strings through a RawBytesChannel, round-trips them through
/// GetSideChannelMessage / ProcessSideChannelData and checks both arrive unchanged and in order.
/// </summary>
[Test]
public void TestRawBytesSideChannel()
{
    var str1 = "Test string";
    var str2 = "Test string, second";

    var strSender = new RawBytesChannel();
    var strReceiver = new RawBytesChannel();
    var dictSender = new Dictionary<int, SideChannel> { { strSender.ChannelType(), strSender } };
    var dictReceiver = new Dictionary<int, SideChannel> { { strReceiver.ChannelType(), strReceiver } };

    strSender.SendRawBytes(Encoding.ASCII.GetBytes(str1));
    strSender.SendRawBytes(Encoding.ASCII.GetBytes(str2));

    byte[] fakeData = RpcCommunicator.GetSideChannelMessage(dictSender);
    RpcCommunicator.ProcessSideChannelData(dictReceiver, fakeData);

    var messages = strReceiver.GetAndClearReceivedMessages();

    // NUnit expects (expected, actual); this order keeps failure messages accurate.
    Assert.AreEqual(2, messages.Count);
    Assert.AreEqual(str1, Encoding.ASCII.GetString(messages[0]));
    Assert.AreEqual(str2, Encoding.ASCII.GetString(messages[1]));
}
[Test]
public void TestFloatPropertiesSideChannel()
{
var k1 = "gravity";
var k2 = "length";
int wasCalled = 0;
var propA = new FloatPropertiesChannel();
var propB = new FloatPropertiesChannel();
var dictReceiver = new Dictionary<int, SideChannel> { { propA.ChannelType(), propA } };
var dictSender = new Dictionary<int, SideChannel> { { propB.ChannelType(), propB } };
propA.RegisterCallback(k1, f => { wasCalled++; });
var tmp = propB.GetPropertyWithDefault(k2, 3.0f);
Assert.AreEqual(tmp, 3.0f);
propB.SetProperty(k2, 1.0f);
tmp = propB.GetPropertyWithDefault(k2, 3.0f);
Assert.AreEqual(tmp, 1.0f);
byte[] fakeData = RpcCommunicator.GetSideChannelMessage(dictSender);
RpcCommunicator.ProcessSideChannelData(dictReceiver, fakeData);
tmp = propA.GetPropertyWithDefault(k2, 3.0f);
Assert.AreEqual(tmp, 1.0f);
Assert.AreEqual(wasCalled, 0);
propB.SetProperty(k1, 1.0f);
Assert.AreEqual(wasCalled, 0);
fakeData = RpcCommunicator.GetSideChannelMessage(dictSender);
RpcCommunicator.ProcessSideChannelData(dictReceiver, fakeData);
Assert.AreEqual(wasCalled, 1);
}
}
}

11
UnitySDK/Assets/ML-Agents/Editor/Tests/SideChannelTests.cs.meta


fileFormatVersion: 2
guid: 589f475debcdb479295a24799777b5e5
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

8
UnitySDK/Assets/ML-Agents/Scripts/SideChannel.meta


fileFormatVersion: 2
guid: cb2f03ed7ea59456380730bd0f9b5bcb
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

91
ml-agents-envs/mlagents/envs/tests/test_side_channel.py


import struct
from mlagents.envs.side_channel.side_channel import SideChannel
from mlagents.envs.side_channel.float_properties_channel import FloatPropertiesChannel
from mlagents.envs.side_channel.raw_bytes_channel import RawBytesChannel
from mlagents.envs.environment import UnityEnvironment
class IntChannel(SideChannel):
    """Test-only side channel whose messages are single little-endian int32 values."""

    def __init__(self):
        super().__init__()
        # Every integer decoded by on_message_received, in arrival order.
        self.list_int = []

    @property
    def channel_type(self):
        return -1

    def on_message_received(self, data):
        """Decode the int32 at the start of `data` and record it."""
        (received,) = struct.unpack_from("<i", data, 0)
        self.list_int.append(received)

    def send_int(self, value):
        """Queue `value`, packed as a little-endian int32, for sending."""
        payload = bytearray(struct.pack("<i", value))
        super().queue_message_to_send(payload)
def test_int_channel():
    """Two ints queued on one IntChannel show up, in order, on another."""
    outgoing = IntChannel()
    incoming = IntChannel()
    for number in (5, 6):
        outgoing.send_int(number)

    payload = UnityEnvironment._generate_side_channel_data(
        {outgoing.channel_type: outgoing}
    )
    UnityEnvironment._parse_side_channel_message(
        {incoming.channel_type: incoming}, payload
    )

    assert incoming.list_int[0] == 5
    assert incoming.list_int[1] == 6
def test_float_properties():
    """Properties set on one FloatPropertiesChannel become readable on another."""
    source = FloatPropertiesChannel()
    sink = FloatPropertiesChannel()

    def deliver():
        # Feed the side channel data generated for `source` into `sink`.
        blob = UnityEnvironment._generate_side_channel_data(
            {source.channel_type: source}
        )
        UnityEnvironment._parse_side_channel_message({sink.channel_type: sink}, blob)

    source.set_property("prop1", 1.0)
    deliver()

    val = sink.get_property("prop1")
    assert val == 1.0
    # prop2 has not been set yet, so the lookup yields None.
    val = sink.get_property("prop2")
    assert val is None

    source.set_property("prop2", 2.0)
    deliver()

    val = sink.get_property("prop1")
    assert val == 1.0
    val = sink.get_property("prop2")
    assert val == 2.0

    assert len(sink.list_properties()) == 2
    assert "prop1" in sink.list_properties()
    assert "prop2" in sink.list_properties()

    # The sending side keeps its own copy of what it set.
    val = source.get_property("prop1")
    assert val == 1.0
def test_raw_bytes():
    """Raw byte messages arrive in order and are cleared once retrieved."""
    outgoing = RawBytesChannel()
    incoming = RawBytesChannel()
    for text in ("foo", "bar"):
        outgoing.send_raw_data(text.encode("ascii"))

    payload = UnityEnvironment._generate_side_channel_data(
        {outgoing.channel_type: outgoing}
    )
    UnityEnvironment._parse_side_channel_message(
        {incoming.channel_type: incoming}, payload
    )

    messages = incoming.get_and_clear_received_messages()
    assert len(messages) == 2
    assert messages[0].decode("ascii") == "foo"
    assert messages[1].decode("ascii") == "bar"

    # A second retrieval must come back empty.
    messages = incoming.get_and_clear_received_messages()
    assert len(messages) == 0

36
UnitySDK/Assets/ML-Agents/Scripts/SideChannel/EngineConfigurationChannel.cs


using System.Collections.Generic;
using System.IO;
using UnityEngine;
namespace MLAgents
{
/// <summary>
/// Side channel that applies the engine settings carried by each received message.
/// </summary>
public class EngineConfigurationChannel : SideChannel
{
    public override int ChannelType()
    {
        return (int)SideChannelType.EngineSettings;
    }

    /// <summary>
    /// Reads width, height, quality level, time scale and target frame rate from the
    /// payload (in that order) and applies them to the engine.
    /// </summary>
    /// <param name="data"> The payload of the message.</param>
    public override void OnMessageReceived(byte[] data)
    {
        using (var stream = new MemoryStream(data))
        using (var reader = new BinaryReader(stream))
        {
            // All fields are read before any setting is applied.
            var screenWidth = reader.ReadInt32();
            var screenHeight = reader.ReadInt32();
            var quality = reader.ReadInt32();
            var scale = reader.ReadSingle();
            var frameRate = reader.ReadInt32();

            Screen.SetResolution(screenWidth, screenHeight, false);
            QualitySettings.SetQualityLevel(quality, true);
            Time.timeScale = scale;
            // The capture framerate is fixed at 60 and is not part of the message.
            Time.captureFramerate = 60;
            Application.targetFrameRate = frameRate;
        }
    }
}
}

11
UnitySDK/Assets/ML-Agents/Scripts/SideChannel/EngineConfigurationChannel.cs.meta


fileFormatVersion: 2
guid: 18ccdf3ce76784f2db68016fa284c33f
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

123
UnitySDK/Assets/ML-Agents/Scripts/SideChannel/FloatPropertiesChannel.cs


using System.Collections.Generic;
using System.IO;
using System;
using System.Text;
namespace MLAgents
{
/// <summary>
/// Access to a set of float properties identified by string keys.
/// </summary>
public interface IFloatProperties
{
    /// <summary>
    /// Sets one of the float properties of the environment. This data will be sent to Python.
    /// </summary>
    /// <param name="key"> The string identifier of the property.</param>
    /// <param name="value"> The float value of the property.</param>
    void SetProperty(string key, float value);

    /// <summary>
    /// Gets an environment property with a default value. If there is a value for this property,
    /// it will be returned; otherwise, the default value will be returned.
    /// </summary>
    /// <param name="key"> The string identifier of the property.</param>
    /// <param name="defaultValue"> The default value of the property.</param>
    /// <returns> The value of the property, or the default value when the property is not present.</returns>
    float GetPropertyWithDefault(string key, float defaultValue);

    /// <summary>
    /// Registers an action to be performed every time the property is changed.
    /// </summary>
    /// <param name="key"> The string identifier of the property.</param>
    /// <param name="action"> The action that will be performed. Takes a float as input.</param>
    void RegisterCallback(string key, Action<float> action);

    /// <summary>
    /// Returns a list of all the string identifiers of the properties currently present.
    /// </summary>
    /// <returns> The list of string identifiers.</returns>
    IList<string> ListProperties();
}
/// <summary>
/// Side channel that keeps a dictionary of float properties keyed by string. Each
/// message carries one key/value pair; callbacks can be registered per key.
/// </summary>
public class FloatPropertiesChannel : SideChannel, IFloatProperties
{
    private Dictionary<string, float> m_FloatProperties = new Dictionary<string, float>();
    private Dictionary<string, Action<float>> m_RegisteredActions = new Dictionary<string, Action<float>>();

    public override int ChannelType()
    {
        return (int)SideChannelType.FloatProperties;
    }

    /// <summary>
    /// Stores the key/value pair carried by the message and runs the callback
    /// registered for that key, if there is one.
    /// </summary>
    /// <param name="data"> The payload of the message.</param>
    public override void OnMessageReceived(byte[] data)
    {
        var pair = DeserializeMessage(data);
        m_FloatProperties[pair.Key] = pair.Value;

        Action<float> callback;
        if (m_RegisteredActions.TryGetValue(pair.Key, out callback))
        {
            callback.Invoke(pair.Value);
        }
    }

    /// <summary>
    /// Stores the value locally, queues the serialized pair for sending, and runs
    /// the callback registered for that key, if there is one.
    /// </summary>
    /// <param name="key"> The string identifier of the property.</param>
    /// <param name="value"> The float value of the property.</param>
    public void SetProperty(string key, float value)
    {
        m_FloatProperties[key] = value;
        QueueMessageToSend(SerializeMessage(key, value));

        Action<float> callback;
        if (m_RegisteredActions.TryGetValue(key, out callback))
        {
            callback.Invoke(value);
        }
    }

    /// <summary>
    /// Returns the stored value for the key, or the default value when the key is absent.
    /// </summary>
    /// <param name="key"> The string identifier of the property.</param>
    /// <param name="defaultValue"> The value returned when the key is absent.</param>
    public float GetPropertyWithDefault(string key, float defaultValue)
    {
        float stored;
        return m_FloatProperties.TryGetValue(key, out stored) ? stored : defaultValue;
    }

    /// <summary>
    /// Registers the action for the key, replacing any action registered before.
    /// </summary>
    /// <param name="key"> The string identifier of the property.</param>
    /// <param name="action"> The action to run with the new value.</param>
    public void RegisterCallback(string key, Action<float> action)
    {
        m_RegisteredActions[key] = action;
    }

    /// <summary>
    /// Returns a new list holding the keys of all stored properties.
    /// </summary>
    public IList<string> ListProperties()
    {
        return new List<string>(m_FloatProperties.Keys);
    }

    // Message layout: int32 key length, ASCII key bytes, then the float value.
    private static KeyValuePair<string, float> DeserializeMessage(byte[] data)
    {
        using (var stream = new MemoryStream(data))
        using (var reader = new BinaryReader(stream))
        {
            var keyLength = reader.ReadInt32();
            var keyBytes = reader.ReadBytes(keyLength);
            var key = Encoding.ASCII.GetString(keyBytes);
            var value = reader.ReadSingle();
            return new KeyValuePair<string, float>(key, value);
        }
    }

    // Writes the same layout that DeserializeMessage reads.
    private static byte[] SerializeMessage(string key, float value)
    {
        using (var stream = new MemoryStream())
        using (var writer = new BinaryWriter(stream))
        {
            var keyBytes = Encoding.ASCII.GetBytes(key);
            writer.Write(keyBytes.Length);
            writer.Write(keyBytes);
            writer.Write(value);
            return stream.ToArray();
        }
    }
}
}

11
UnitySDK/Assets/ML-Agents/Scripts/SideChannel/FloatPropertiesChannel.cs.meta


fileFormatVersion: 2
guid: 452f8b3c01c4642aba645dcf0b6bfc6e
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

65
UnitySDK/Assets/ML-Agents/Scripts/SideChannel/RawBytesChannel.cs


using System.Collections.Generic;
namespace MLAgents
{
/// <summary>
/// Side channel that stores every received byte array as-is and queues byte arrays
/// for sending without any further encoding.
/// </summary>
public class RawBytesChannel : SideChannel
{
    private List<byte[]> m_MessagesReceived = new List<byte[]>();
    private int m_ChannelId;

    /// <summary>
    /// RawBytesChannel provides a way to exchange raw byte arrays between Unity and Python.
    /// </summary>
    /// <param name="channelId"> The identifier for the RawBytesChannel. Must be
    /// the same on Python and Unity.</param>
    public RawBytesChannel(int channelId = 0)
    {
        m_ChannelId = channelId;
    }

    /// <summary>
    /// The channel type is the start of the raw bytes range offset by the channel id.
    /// </summary>
    public override int ChannelType()
    {
        var rangeStart = (int)SideChannelType.RawBytesChannelStart;
        return rangeStart + m_ChannelId;
    }

    public override void OnMessageReceived(byte[] data)
    {
        m_MessagesReceived.Add(data);
    }

    /// <summary>
    /// Sends the byte array message to the Python side channel. The message will be sent
    /// alongside the simulation step.
    /// </summary>
    /// <param name="data"> The byte array of data to send to Python.</param>
    public void SendRawBytes(byte[] data)
    {
        QueueMessageToSend(data);
    }

    /// <summary>
    /// Gets the messages that were sent by Python since the last call to
    /// GetAndClearReceivedMessages, then empties the internal list.
    /// </summary>
    /// <returns> A list of byte array messages that Python has sent.</returns>
    public IList<byte[]> GetAndClearReceivedMessages()
    {
        // Copy first so that clearing the internal list leaves the result intact.
        var snapshot = new List<byte[]>(m_MessagesReceived);
        m_MessagesReceived.Clear();
        return snapshot;
    }

    /// <summary>
    /// Gets the messages that were sent by Python since the last call to
    /// GetAndClearReceivedMessages. Unlike that method, this one does not clear
    /// the stored messages.
    /// </summary>
    /// <returns> A list of byte array messages that Python has sent.</returns>
    public IList<byte[]> GetReceivedMessages()
    {
        return new List<byte[]>(m_MessagesReceived);
    }
}
}

11
UnitySDK/Assets/ML-Agents/Scripts/SideChannel/RawBytesChannel.cs.meta


fileFormatVersion: 2
guid: 40b01e9cdbfd94865b54ebeb4e5aeaa5
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

49
UnitySDK/Assets/ML-Agents/Scripts/SideChannel/SideChannel.cs


using System.Collections.Generic;
namespace MLAgents
{
/// <summary>
/// Integer identifiers used as side channel types.
/// </summary>
public enum SideChannelType
{
    // Invalid side channel.
    Invalid = 0,
    // Reserved for the FloatPropertiesChannel.
    FloatProperties = 1,
    // Reserved for the EngineConfigurationChannel.
    EngineSettings = 2,
    // Raw bytes channels should start here to avoid conflicting with other Unity ones.
    RawBytesChannelStart = 1000,
    // Custom side channels should start here to avoid conflicting with Unity ones.
    UserSideChannelStart = 2000,
}
/// <summary>
/// Base class for side channels. Holds the queue of outgoing messages and declares
/// the channel identifier and the receive callback that subclasses implement.
/// </summary>
public abstract class SideChannel
{
    // The list of messages (byte arrays) that need to be sent to Python via the communicator.
    // Should only ever be read and cleared by an ICommunicator object.
    public List<byte[]> MessageQueue = new List<byte[]>();

    /// <summary>
    /// An int identifier for the SideChannel. Ensures that there is only ever one side channel
    /// of each type, and that the Unity side channels are linked to their Python equivalents.
    /// </summary>
    /// <returns> The integer identifier of the SideChannel.</returns>
    public abstract int ChannelType();

    /// <summary>
    /// Is called by the communicator every time a message is received from Python by the SideChannel.
    /// Can be called multiple times per simulation step if multiple messages were sent.
    /// </summary>
    /// <param name="data"> The payload of the message.</param>
    public abstract void OnMessageReceived(byte[] data);

    /// <summary>
    /// Queues a message to be sent to Python during the next simulation step.
    /// </summary>
    /// <param name="data"> The byte array of data to be sent to Python.</param>
    protected void QueueMessageToSend(byte[] data)
    {
        // Appends to MessageQueue; nothing is sent from here.
        MessageQueue.Add(data);
    }
}
}

11
UnitySDK/Assets/ML-Agents/Scripts/SideChannel/SideChannel.cs.meta


fileFormatVersion: 2
guid: 77b7d19dd6ce343eeba907540b5a2286
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

0
ml-agents-envs/mlagents/envs/side_channel/__init__.py

61
ml-agents-envs/mlagents/envs/side_channel/engine_configuration_channel.py


from mlagents.envs.side_channel.side_channel import SideChannel, SideChannelType
from mlagents.envs.exception import UnityCommunicationException
import struct
class EngineConfigurationChannel(SideChannel):
    """
    Side channel used to send an engine configuration to Unity. One message
    holds the following little-endian fields, in this order:
     - int width;
     - int height;
     - int qualityLevel;
     - float timeScale;
     - int targetFrameRate;
    """

    @property
    def channel_type(self) -> int:
        return SideChannelType.EngineSettings

    def on_message_received(self, data: bytearray) -> None:
        """
        Called by the environment when a message arrives for this channel.
        Python should never receive an engine configuration from Unity, so
        this always raises.
        :raises UnityCommunicationException: on every call.
        """
        message = (
            "The EngineConfigurationChannel received a message from Unity, "
            "this should not have happend."
        )
        raise UnityCommunicationException(message)

    def set_configuration(
        self,
        width: int = 80,
        height: int = 80,
        quality_level: int = 1,
        time_scale: float = 20.0,
        target_frame_rate: int = -1,
    ) -> None:
        """
        Queues a message carrying the given engine configuration.
        :param width: Defines the width of the display. Default 80.
        :param height: Defines the height of the display. Default 80.
        :param quality_level: Defines the quality level of the simulation.
        Default 1.
        :param time_scale: Defines the multiplier for the deltatime in the
        simulation. If set to a higher value, time will pass faster in the
        simulation but the physics might break. Default 20.
        :param target_frame_rate: Instructs simulation to try to render at a
        specified frame rate. Default -1.
        """
        # "<" selects little-endian with no padding, so one combined format
        # yields the same bytes as packing each field separately.
        payload = bytearray(
            struct.pack(
                "<iiifi", width, height, quality_level, time_scale, target_frame_rate
            )
        )
        super().queue_message_to_send(payload)

74
ml-agents-envs/mlagents/envs/side_channel/float_properties_channel.py


from mlagents.envs.side_channel.side_channel import SideChannel, SideChannelType
import struct
from typing import Tuple, Optional, List
class FloatPropertiesChannel(SideChannel):
    """
    This is the SideChannel for float properties shared with Unity.
    You can modify the float properties of an environment with the commands
    set_property, get_property and list_properties.
    """

    def __init__(self):
        # Local copy of every property set here or received in a message.
        self._float_properties = {}
        super().__init__()

    @property
    def channel_type(self) -> int:
        return SideChannelType.FloatProperties

    def on_message_received(self, data: bytearray) -> None:
        """
        Is called by the environment to the side channel. Can be called
        multiple times per step if multiple messages are meant for that
        SideChannel.
        Each message carries one property; it is decoded and stored locally.
        :param data: The payload of the message.
        """
        k, v = self.deserialize_float_prop(data)
        self._float_properties[k] = v

    def set_property(self, key: str, value: float) -> None:
        """
        Sets a property in the Unity Environment.
        :param key: The string identifier of the property.
        :param value: The float value of the property.
        """
        self._float_properties[key] = value
        super().queue_message_to_send(self.serialize_float_prop(key, value))

    def get_property(self, key: str) -> Optional[float]:
        """
        Gets a property in the Unity Environment. If the property was not
        found, will return None.
        :param key: The string identifier of the property.
        :return: The float value of the property or None.
        """
        return self._float_properties.get(key)

    def list_properties(self) -> List[str]:
        """
        Returns a list of all the string identifiers of the properties
        currently present in the Unity Environment.
        :return: A new list; later property changes do not alter it.
        """
        # Return a real list as annotated, not a live dict view.
        return list(self._float_properties)

    @staticmethod
    def serialize_float_prop(key: str, value: float) -> bytearray:
        """
        Encodes one property as: little-endian int32 key length, the
        ASCII-encoded key, then the value as a little-endian float32.
        :param key: The string identifier of the property.
        :param value: The float value of the property.
        :return: The encoded message.
        """
        result = bytearray()
        encoded_key = key.encode("ascii")
        result += struct.pack("<i", len(encoded_key))
        result += encoded_key
        result += struct.pack("<f", value)
        return result

    @staticmethod
    def deserialize_float_prop(data: bytearray) -> Tuple[str, float]:
        """
        Decodes a message in the layout written by serialize_float_prop.
        :param data: The encoded message.
        :return: The (key, value) pair.
        """
        offset = 0
        encoded_key_len = struct.unpack_from("<i", data, offset)[0]
        offset = offset + 4  # skip the int32 length prefix
        key = data[offset : offset + encoded_key_len].decode("ascii")
        offset = offset + encoded_key_len
        value = struct.unpack_from("<f", data, offset)[0]
        return key, value

41
ml-agents-envs/mlagents/envs/side_channel/raw_bytes_channel.py


from mlagents.envs.side_channel.side_channel import SideChannel, SideChannelType
from typing import List
class RawBytesChannel(SideChannel):
    """
    Example SideChannel that exchanges raw byte messages. Is meant to be used
    for general research purpose.
    """

    def __init__(self, channel_id=0):
        # Messages received so far and not yet handed out.
        self._received_messages = []
        # Offset added to RawBytesChannelStart to form the channel type.
        self._channel_id = channel_id
        super().__init__()

    @property
    def channel_type(self) -> int:
        base = SideChannelType.RawBytesChannelStart
        return base + self._channel_id

    def on_message_received(self, data: bytearray) -> None:
        """
        Called by the environment for each message meant for this channel;
        may run several times per step. The message is stored untouched.
        """
        self._received_messages.append(data)

    def get_and_clear_received_messages(self) -> List[bytearray]:
        """
        Returns the messages received since the previous call and resets the
        internal list.
        """
        drained = self._received_messages[:]
        self._received_messages = []
        return drained

    def send_raw_data(self, data: bytearray) -> None:
        """
        Queues a message to be sent by the environment at the next call to
        step.
        """
        super().queue_message_to_send(data)

51
ml-agents-envs/mlagents/envs/side_channel/side_channel.py


from abc import ABC, abstractmethod
from enum import IntEnum
class SideChannelType(IntEnum):
    """Integer identifiers used as side channel types."""

    # Used by the FloatPropertiesChannel.
    FloatProperties = 1
    # Used by the EngineConfigurationChannel.
    EngineSettings = 2
    # Raw bytes channels should start here to avoid conflicting with other
    # Unity ones.
    RawBytesChannelStart = 1000
    # Custom side channels should start here to avoid conflicting with Unity
    # ones.
    UserSideChannelStart = 2000
class SideChannel(ABC):
    """
    Base class for side channels. A side channel only gets access to a bytes
    buffer that is shared between C# and Python. For example, a side channel
    for properties could map strings to float numbers that both C# and Python
    can modify. All side channels are passed to the Env object at
    construction.
    """

    def __init__(self):
        # Outgoing messages, in the order they were queued.
        self.message_queue = list()

    def queue_message_to_send(self, data: bytearray) -> None:
        """
        Appends a message to the queue so the environment can send it at the
        next call to step.
        """
        self.message_queue.append(data)

    @abstractmethod
    def on_message_received(self, data: bytearray) -> None:
        """
        Called by the environment for each message meant for this side
        channel; may run several times per step.
        """

    @property
    @abstractmethod
    def channel_type(self) -> int:
        """
        :return: The type of side channel used. Will influence how the data is
        processed in the environment.
        """
正在加载...
取消
保存