浏览代码
[WIP] Side Channel Design Changes (#3807)
[WIP] Side Channel Design Changes (#3807)
* Make EnvironmentParameters a first-class citizen in the API Missing: Python conterparts and testing. * Minor comment fix to Engine Parameters * A second minor fix. * Make EngineConfigChannel Internal and add a singleton/sealed accessor * Make StatsSideChannel Internal and add a singleton/sealed accessor * Changes to SideChannelUtils - Disallow two sidechannels of the same type to be added - Remove GetSideChannels that return a list as that is now unnecessary - Make most methods except (register/unregister) internal to limit users impacting the “system-level” side channels - Add an improved comment to SideChannel.cs * Added Dispose methods to system-level sidechannel wrappers - Specifically to StatsRecorder, EnvironmentParameters and EngineParameters. - Updated Academy.Dispose to take advantage of these. - Updated Editor tests to cover all three “system-level” side channels. Kudos to Unit Tests (TestAcade.../develop/dockerfile
GitHub
5 年前
当前提交
ea0c6fa0
共有 49 个文件被更改,包括 915 次插入 和 496 次删除
-
8Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs
-
8Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DHardAgent.cs
-
6Project/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerAgent.cs
-
7Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs
-
10Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorSettings.cs
-
9Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs
-
20Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridArea.cs
-
2Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridSettings.cs
-
18Project/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentBasic.cs
-
13Project/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs
-
3Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/ProjectSettingsOverrides.cs
-
8Project/Assets/ML-Agents/Examples/Soccer/Scripts/AgentSoccer.cs
-
6Project/Assets/ML-Agents/Examples/Soccer/Scripts/SoccerFieldArea.cs
-
8Project/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs
-
10Project/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs
-
12Project/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs
-
23com.unity.ml-agents/CHANGELOG.md
-
42com.unity.ml-agents/Runtime/Academy.cs
-
4com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs
-
58com.unity.ml-agents/Runtime/SideChannels/EngineConfigurationChannel.cs
-
30com.unity.ml-agents/Runtime/SideChannels/FloatPropertiesChannel.cs
-
2com.unity.ml-agents/Runtime/SideChannels/RawBytesChannel.cs
-
17com.unity.ml-agents/Runtime/SideChannels/SideChannel.cs
-
45com.unity.ml-agents/Runtime/SideChannels/StatsSideChannel.cs
-
2com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs.meta
-
16com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs
-
32com.unity.ml-agents/Tests/Editor/SideChannelTests.cs
-
10docs/Custom-SideChannels.md
-
27docs/Migrating.md
-
36docs/Python-API.md
-
4docs/Training-Curriculum-Learning.md
-
4docs/Using-Tensorboard.md
-
2ml-agents-envs/mlagents_envs/environment.py
-
88ml-agents-envs/mlagents_envs/side_channel/engine_configuration_channel.py
-
5ml-agents/mlagents/trainers/env_manager.py
-
19ml-agents/mlagents/trainers/learn.py
-
14ml-agents/mlagents/trainers/simple_env_manager.py
-
19ml-agents/mlagents/trainers/subprocess_env_manager.py
-
6ml-agents/mlagents/trainers/tests/test_simple_rl.py
-
70com.unity.ml-agents/Runtime/EnvironmentParameters.cs
-
11com.unity.ml-agents/Runtime/EnvironmentParameters.cs.meta
-
91com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs
-
218com.unity.ml-agents/Runtime/SideChannels/SideChannelsManager.cs
-
11com.unity.ml-agents/Runtime/SideChannels/SideChannelsManager.cs.meta
-
71com.unity.ml-agents/Runtime/StatsRecorder.cs
-
11com.unity.ml-agents/Runtime/StatsRecorder.cs.meta
-
37ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py
-
238com.unity.ml-agents/Runtime/SideChannels/SideChannelUtils.cs
-
0/com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs.meta
|
|||
using System; |
|||
using System.Collections.Generic; |
|||
using MLAgents.SideChannels; |
|||
|
|||
namespace MLAgents |
|||
{ |
|||
/// <summary>
|
|||
/// A container for the Environment Parameters that may be modified during training.
|
|||
/// The keys for those parameters are defined in the trainer configurations and the
|
|||
/// the values are generated from the training process in features such as Curriculum Learning
|
|||
/// and Environment Parameter Randomization.
|
|||
///
|
|||
/// One current assumption for all the environment parameters is that they are of type float.
|
|||
/// </summary>
|
|||
public sealed class EnvironmentParameters |
|||
{ |
|||
/// <summary>
|
|||
/// The side channel that is used to receive the new parameter values.
|
|||
/// </summary>
|
|||
readonly EnvironmentParametersChannel m_Channel; |
|||
|
|||
/// <summary>
|
|||
/// Constructor.
|
|||
/// </summary>
|
|||
internal EnvironmentParameters() |
|||
{ |
|||
m_Channel = new EnvironmentParametersChannel(); |
|||
SideChannelsManager.RegisterSideChannel(m_Channel); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Returns the parameter value for the specified key. Returns the default value provided
|
|||
/// if this parameter key does not have a value. Only returns a parameter value if it is
|
|||
/// of type float.
|
|||
/// </summary>
|
|||
/// <param name="key">The parameter key</param>
|
|||
/// <param name="defaultValue">Default value for this parameter.</param>
|
|||
/// <returns></returns>
|
|||
public float GetWithDefault(string key, float defaultValue) |
|||
{ |
|||
return m_Channel.GetWithDefault(key, defaultValue); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Registers a callback action for the provided parameter key. Will overwrite any
|
|||
/// existing action for that parameter. The callback will be called whenever the parameter
|
|||
/// receives a value from the training process.
|
|||
/// </summary>
|
|||
/// <param name="key">The parameter key</param>
|
|||
/// <param name="action">The callback action</param>
|
|||
public void RegisterCallback(string key, Action<float> action) |
|||
{ |
|||
m_Channel.RegisterCallback(key, action); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Returns a list of all the parameter keys that have received values.
|
|||
/// </summary>
|
|||
/// <returns>List of parameter keys.</returns>
|
|||
public IList<string> Keys() |
|||
{ |
|||
return m_Channel.ListParameters(); |
|||
} |
|||
|
|||
internal void Dispose() |
|||
{ |
|||
SideChannelsManager.UnregisterSideChannel(m_Channel); |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 90ce0b26bef35484890eac0633b85eed |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using System.Collections.Generic; |
|||
using System; |
|||
using UnityEngine; |
|||
|
|||
namespace MLAgents.SideChannels |
|||
{ |
|||
/// <summary>
|
|||
/// Lists the different data types supported.
|
|||
/// </summary>
|
|||
internal enum EnvironmentDataTypes |
|||
{ |
|||
Float = 0 |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// A side channel that manages the environment parameter values from Python. Currently
|
|||
/// limited to parameters of type float.
|
|||
/// </summary>
|
|||
internal class EnvironmentParametersChannel : SideChannel |
|||
{ |
|||
Dictionary<string, float> m_Parameters = new Dictionary<string, float>(); |
|||
Dictionary<string, Action<float>> m_RegisteredActions = |
|||
new Dictionary<string, Action<float>>(); |
|||
|
|||
const string k_EnvParamsId = "534c891e-810f-11ea-a9d0-822485860400"; |
|||
|
|||
/// <summary>
|
|||
/// Initializes the side channel. The constructor is internal because only one instance is
|
|||
/// supported at a time, and is created by the Academy.
|
|||
/// </summary>
|
|||
internal EnvironmentParametersChannel() |
|||
{ |
|||
ChannelId = new Guid(k_EnvParamsId); |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
protected override void OnMessageReceived(IncomingMessage msg) |
|||
{ |
|||
var key = msg.ReadString(); |
|||
var type = msg.ReadInt32(); |
|||
if ((int)EnvironmentDataTypes.Float == type) |
|||
{ |
|||
var value = msg.ReadFloat32(); |
|||
|
|||
m_Parameters[key] = value; |
|||
|
|||
Action<float> action; |
|||
m_RegisteredActions.TryGetValue(key, out action); |
|||
action?.Invoke(value); |
|||
} |
|||
else |
|||
{ |
|||
Debug.LogWarning("EnvironmentParametersChannel received an unknown data type."); |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Returns the parameter value associated with the provided key. Returns the default
|
|||
/// value if one doesn't exist.
|
|||
/// </summary>
|
|||
/// <param name="key">Parameter key.</param>
|
|||
/// <param name="defaultValue">Default value to return.</param>
|
|||
/// <returns></returns>
|
|||
public float GetWithDefault(string key, float defaultValue) |
|||
{ |
|||
float valueOut; |
|||
bool hasKey = m_Parameters.TryGetValue(key, out valueOut); |
|||
return hasKey ? valueOut : defaultValue; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Registers a callback for the associated parameter key. Will overwrite any existing
|
|||
/// actions for this parameter key.
|
|||
/// </summary>
|
|||
/// <param name="key">The parameter key.</param>
|
|||
/// <param name="action">The callback.</param>
|
|||
public void RegisterCallback(string key, Action<float> action) |
|||
{ |
|||
m_RegisteredActions[key] = action; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Returns all parameter keys that have a registered value.
|
|||
/// </summary>
|
|||
/// <returns></returns>
|
|||
public IList<string> ListParameters() |
|||
{ |
|||
return new List<string>(m_Parameters.Keys); |
|||
} |
|||
} |
|||
} |
|
|||
using System; |
|||
using System.Collections.Generic; |
|||
using UnityEngine; |
|||
using System.IO; |
|||
|
|||
namespace MLAgents.SideChannels |
|||
{ |
|||
/// <summary>
|
|||
/// Collection of static utilities for managing the registering/unregistering of
|
|||
/// <see cref="SideChannels"/> and the sending/receiving of messages for all the channels.
|
|||
/// </summary>
|
|||
public static class SideChannelsManager |
|||
{ |
|||
private static Dictionary<Guid, SideChannel> RegisteredChannels = new Dictionary<Guid, SideChannel>(); |
|||
|
|||
private struct CachedSideChannelMessage |
|||
{ |
|||
public Guid ChannelId; |
|||
public byte[] Message; |
|||
} |
|||
|
|||
private static readonly Queue<CachedSideChannelMessage> m_CachedMessages = |
|||
new Queue<CachedSideChannelMessage>(); |
|||
|
|||
/// <summary>
|
|||
/// Register a side channel to begin sending and receiving messages. This method is
|
|||
/// available for environments that have custom side channels. All built-in side
|
|||
/// channels within the ML-Agents Toolkit are managed internally and do not need to
|
|||
/// be explicitly registered/unregistered. A side channel may only be registered once.
|
|||
/// </summary>
|
|||
/// <param name="sideChannel">The side channel to register.</param>
|
|||
public static void RegisterSideChannel(SideChannel sideChannel) |
|||
{ |
|||
var channelId = sideChannel.ChannelId; |
|||
if (RegisteredChannels.ContainsKey(channelId)) |
|||
{ |
|||
throw new UnityAgentsException( |
|||
$"A side channel with id {channelId} is already registered. " + |
|||
"You cannot register multiple side channels of the same id."); |
|||
} |
|||
|
|||
// Process any messages that we've already received for this channel ID.
|
|||
var numMessages = m_CachedMessages.Count; |
|||
for (var i = 0; i < numMessages; i++) |
|||
{ |
|||
var cachedMessage = m_CachedMessages.Dequeue(); |
|||
if (channelId == cachedMessage.ChannelId) |
|||
{ |
|||
sideChannel.ProcessMessage(cachedMessage.Message); |
|||
} |
|||
else |
|||
{ |
|||
m_CachedMessages.Enqueue(cachedMessage); |
|||
} |
|||
} |
|||
RegisteredChannels.Add(channelId, sideChannel); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Unregister a side channel to stop sending and receiving messages. This method is
|
|||
/// available for environments that have custom side channels. All built-in side
|
|||
/// channels within the ML-Agents Toolkit are managed internally and do not need to
|
|||
/// be explicitly registered/unregistered. Unregistering a side channel that has already
|
|||
/// been unregistered (or never registered in the first place) has no negative side effects.
|
|||
/// Note that unregistering a side channel may not stop the Python side
|
|||
/// from sending messages, but it does mean that sent messages with not result in a call
|
|||
/// to <see cref="SideChannel.OnMessageReceived(IncomingMessage)"/>. Furthermore,
|
|||
/// those messages will not be buffered and will, in essence, be lost.
|
|||
/// </summary>
|
|||
/// <param name="sideChannel">The side channel to unregister.</param>
|
|||
public static void UnregisterSideChannel(SideChannel sideChannel) |
|||
{ |
|||
if (RegisteredChannels.ContainsKey(sideChannel.ChannelId)) |
|||
{ |
|||
RegisteredChannels.Remove(sideChannel.ChannelId); |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Unregisters all the side channels from the communicator.
|
|||
/// </summary>
|
|||
internal static void UnregisterAllSideChannels() |
|||
{ |
|||
RegisteredChannels = new Dictionary<Guid, SideChannel>(); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Returns the SideChannel of Type T if there is one registered, or null if it doesn't.
|
|||
/// If there are multiple SideChannels of the same type registered, the returned instance is arbitrary.
|
|||
/// </summary>
|
|||
/// <typeparam name="T"></typeparam>
|
|||
/// <returns></returns>
|
|||
internal static T GetSideChannel<T>() where T: SideChannel |
|||
{ |
|||
foreach (var sc in RegisteredChannels.Values) |
|||
{ |
|||
if (sc.GetType() == typeof(T)) |
|||
{ |
|||
return (T) sc; |
|||
} |
|||
} |
|||
return null; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Grabs the messages that the registered side channels will send to Python at the current step
|
|||
/// into a singe byte array.
|
|||
/// </summary>
|
|||
/// <returns></returns>
|
|||
internal static byte[] GetSideChannelMessage() |
|||
{ |
|||
return GetSideChannelMessage(RegisteredChannels); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Grabs the messages that the registered side channels will send to Python at the current step
|
|||
/// into a singe byte array.
|
|||
/// </summary>
|
|||
/// <param name="sideChannels"> A dictionary of channel type to channel.</param>
|
|||
/// <returns></returns>
|
|||
internal static byte[] GetSideChannelMessage(Dictionary<Guid, SideChannel> sideChannels) |
|||
{ |
|||
using (var memStream = new MemoryStream()) |
|||
{ |
|||
using (var binaryWriter = new BinaryWriter(memStream)) |
|||
{ |
|||
foreach (var sideChannel in sideChannels.Values) |
|||
{ |
|||
var messageList = sideChannel.MessageQueue; |
|||
foreach (var message in messageList) |
|||
{ |
|||
binaryWriter.Write(sideChannel.ChannelId.ToByteArray()); |
|||
binaryWriter.Write(message.Length); |
|||
binaryWriter.Write(message); |
|||
} |
|||
sideChannel.MessageQueue.Clear(); |
|||
} |
|||
return memStream.ToArray(); |
|||
} |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Separates the data received from Python into individual messages for each registered side channel.
|
|||
/// </summary>
|
|||
/// <param name="dataReceived">The byte array of data received from Python.</param>
|
|||
internal static void ProcessSideChannelData(byte[] dataReceived) |
|||
{ |
|||
ProcessSideChannelData(RegisteredChannels, dataReceived); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Separates the data received from Python into individual messages for each registered side channel.
|
|||
/// </summary>
|
|||
/// <param name="sideChannels">A dictionary of channel type to channel.</param>
|
|||
/// <param name="dataReceived">The byte array of data received from Python.</param>
|
|||
internal static void ProcessSideChannelData(Dictionary<Guid, SideChannel> sideChannels, byte[] dataReceived) |
|||
{ |
|||
while (m_CachedMessages.Count != 0) |
|||
{ |
|||
var cachedMessage = m_CachedMessages.Dequeue(); |
|||
if (sideChannels.ContainsKey(cachedMessage.ChannelId)) |
|||
{ |
|||
sideChannels[cachedMessage.ChannelId].ProcessMessage(cachedMessage.Message); |
|||
} |
|||
else |
|||
{ |
|||
Debug.Log(string.Format( |
|||
"Unknown side channel data received. Channel Id is " |
|||
+ ": {0}", cachedMessage.ChannelId)); |
|||
} |
|||
} |
|||
|
|||
if (dataReceived.Length == 0) |
|||
{ |
|||
return; |
|||
} |
|||
using (var memStream = new MemoryStream(dataReceived)) |
|||
{ |
|||
using (var binaryReader = new BinaryReader(memStream)) |
|||
{ |
|||
while (memStream.Position < memStream.Length) |
|||
{ |
|||
Guid channelId = Guid.Empty; |
|||
byte[] message = null; |
|||
try |
|||
{ |
|||
channelId = new Guid(binaryReader.ReadBytes(16)); |
|||
var messageLength = binaryReader.ReadInt32(); |
|||
message = binaryReader.ReadBytes(messageLength); |
|||
} |
|||
catch (Exception ex) |
|||
{ |
|||
throw new UnityAgentsException( |
|||
"There was a problem reading a message in a SideChannel. Please make sure the " + |
|||
"version of MLAgents in Unity is compatible with the Python version. Original error : " |
|||
+ ex.Message); |
|||
} |
|||
if (sideChannels.ContainsKey(channelId)) |
|||
{ |
|||
sideChannels[channelId].ProcessMessage(message); |
|||
} |
|||
else |
|||
{ |
|||
// Don't recognize this ID, but cache it in case the SideChannel that can handle
|
|||
// it is registered before the next call to ProcessSideChannelData.
|
|||
m_CachedMessages.Enqueue(new CachedSideChannelMessage |
|||
{ |
|||
ChannelId = channelId, |
|||
Message = message |
|||
}); |
|||
} |
|||
} |
|||
} |
|||
} |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: ccc0d134445f947349c68a6d07e3cdc2 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using MLAgents.SideChannels; |
|||
|
|||
namespace MLAgents |
|||
{ |
|||
/// <summary>
|
|||
/// Determines the behavior of how multiple stats within the same summary period are combined.
|
|||
/// </summary>
|
|||
public enum StatAggregationMethod |
|||
{ |
|||
/// <summary>
|
|||
/// Values within the summary period are averaged before reporting.
|
|||
/// Note that values from the same C# environment in the same step may replace each other.
|
|||
/// </summary>
|
|||
Average = 0, |
|||
|
|||
/// <summary>
|
|||
/// Only the most recent value is reported.
|
|||
/// To avoid conflicts when training with multiple concurrent environments, only
|
|||
/// stats from worker index 0 will be tracked.
|
|||
/// </summary>
|
|||
MostRecent = 1 |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Add stats (key-value pairs) for reporting. These values will sent these to a StatsReporter
|
|||
/// instance, which means the values will appear in the TensorBoard summary, as well as trainer
|
|||
/// gauges. You can nest stats in TensorBoard by adding "/" in the name (e.g. "Agent/Health"
|
|||
/// and "Agent/Wallet"). Note that stats are only written to TensorBoard each summary_frequency
|
|||
/// steps (a trainer configuration). If a stat is received multiple times, within that period
|
|||
/// then the values will be aggregated using the <see cref="StatAggregationMethod"/> provided.
|
|||
/// </summary>
|
|||
public sealed class StatsRecorder |
|||
{ |
|||
/// <summary>
|
|||
/// The side channel that is used to receive the new parameter values.
|
|||
/// </summary>
|
|||
readonly StatsSideChannel m_Channel; |
|||
|
|||
/// <summary>
|
|||
/// Constructor.
|
|||
/// </summary>
|
|||
internal StatsRecorder() |
|||
{ |
|||
m_Channel = new StatsSideChannel(); |
|||
SideChannelsManager.RegisterSideChannel(m_Channel); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Add a stat value for reporting.
|
|||
/// </summary>
|
|||
/// <param name="key">The stat name.</param>
|
|||
/// <param name="value">
|
|||
/// The stat value. You can nest stats in TensorBoard by using "/".
|
|||
/// </param>
|
|||
/// <param name="aggregationMethod">
|
|||
/// How multiple values sent in the same summary window should be treated.
|
|||
/// </param>
|
|||
public void Add( |
|||
string key, |
|||
float value, |
|||
StatAggregationMethod aggregationMethod = StatAggregationMethod.Average) |
|||
{ |
|||
m_Channel.AddStat(key, value, aggregationMethod); |
|||
} |
|||
|
|||
internal void Dispose() |
|||
{ |
|||
SideChannelsManager.UnregisterSideChannel(m_Channel); |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: d9add8900e8a746e6a4cb410cb27d664 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
from mlagents_envs.side_channel import SideChannel, IncomingMessage, OutgoingMessage |
|||
from mlagents_envs.exception import UnityCommunicationException |
|||
import uuid |
|||
from enum import IntEnum |
|||
|
|||
|
|||
class EnvironmentParametersChannel(SideChannel): |
|||
""" |
|||
This is the SideChannel for sending environment parameters to Unity. |
|||
You can send parameters to an environment with the command |
|||
set_float_parameter. |
|||
""" |
|||
|
|||
class EnvironmentDataTypes(IntEnum): |
|||
FLOAT = 0 |
|||
|
|||
def __init__(self) -> None: |
|||
channel_id = uuid.UUID(("534c891e-810f-11ea-a9d0-822485860400")) |
|||
super().__init__(channel_id) |
|||
|
|||
def on_message_received(self, msg: IncomingMessage) -> None: |
|||
raise UnityCommunicationException( |
|||
"The EnvironmentParametersChannel received a message from Unity, " |
|||
+ "this should not have happend." |
|||
) |
|||
|
|||
def set_float_parameter(self, key: str, value: float) -> None: |
|||
""" |
|||
Sets a float environment parameter in the Unity Environment. |
|||
:param key: The string identifier of the parameter. |
|||
:param value: The float value of the parameter. |
|||
""" |
|||
msg = OutgoingMessage() |
|||
msg.write_string(key) |
|||
msg.write_int32(self.EnvironmentDataTypes.FLOAT) |
|||
msg.write_float32(value) |
|||
super().queue_message_to_send(msg) |
|
|||
using System; |
|||
using System.Collections.Generic; |
|||
using UnityEngine; |
|||
using System.IO; |
|||
|
|||
namespace MLAgents.SideChannels |
|||
{ |
|||
/// <summary>
|
|||
/// Collection of static utilities for managing the registering/unregistering of
|
|||
/// <see cref="SideChannels"/> and the sending/receiving of messages for all the channels.
|
|||
/// </summary>
|
|||
public static class SideChannelUtils |
|||
{ |
|||
|
|||
private static Dictionary<Guid, SideChannel> RegisteredChannels = new Dictionary<Guid, SideChannel>(); |
|||
|
|||
private struct CachedSideChannelMessage |
|||
{ |
|||
public Guid ChannelId; |
|||
public byte[] Message; |
|||
} |
|||
|
|||
private static Queue<CachedSideChannelMessage> m_CachedMessages = new Queue<CachedSideChannelMessage>(); |
|||
|
|||
/// <summary>
|
|||
/// Registers a side channel to the communicator. The side channel will exchange
|
|||
/// messages with its Python equivalent.
|
|||
/// </summary>
|
|||
/// <param name="sideChannel"> The side channel to be registered.</param>
|
|||
public static void RegisterSideChannel(SideChannel sideChannel) |
|||
{ |
|||
var channelId = sideChannel.ChannelId; |
|||
if (RegisteredChannels.ContainsKey(channelId)) |
|||
{ |
|||
throw new UnityAgentsException(string.Format( |
|||
"A side channel with type index {0} is already registered. You cannot register multiple " + |
|||
"side channels of the same id.", channelId)); |
|||
} |
|||
|
|||
// Process any messages that we've already received for this channel ID.
|
|||
var numMessages = m_CachedMessages.Count; |
|||
for (int i = 0; i < numMessages; i++) |
|||
{ |
|||
var cachedMessage = m_CachedMessages.Dequeue(); |
|||
if (channelId == cachedMessage.ChannelId) |
|||
{ |
|||
using (var incomingMsg = new IncomingMessage(cachedMessage.Message)) |
|||
{ |
|||
sideChannel.OnMessageReceived(incomingMsg); |
|||
} |
|||
} |
|||
else |
|||
{ |
|||
m_CachedMessages.Enqueue(cachedMessage); |
|||
} |
|||
} |
|||
RegisteredChannels.Add(channelId, sideChannel); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Unregisters a side channel from the communicator.
|
|||
/// </summary>
|
|||
/// <param name="sideChannel"> The side channel to be unregistered.</param>
|
|||
public static void UnregisterSideChannel(SideChannel sideChannel) |
|||
{ |
|||
if (RegisteredChannels.ContainsKey(sideChannel.ChannelId)) |
|||
{ |
|||
RegisteredChannels.Remove(sideChannel.ChannelId); |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Unregisters all the side channels from the communicator.
|
|||
/// </summary>
|
|||
public static void UnregisterAllSideChannels() |
|||
{ |
|||
RegisteredChannels = new Dictionary<Guid, SideChannel>(); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Returns the SideChannel of Type T if there is one registered, or null if it doesn't.
|
|||
/// If there are multiple SideChannels of the same type registered, the returned instance is arbitrary.
|
|||
/// </summary>
|
|||
/// <typeparam name="T"></typeparam>
|
|||
/// <returns></returns>
|
|||
public static T GetSideChannel<T>() where T: SideChannel |
|||
{ |
|||
foreach (var sc in RegisteredChannels.Values) |
|||
{ |
|||
if (sc.GetType() == typeof(T)) |
|||
{ |
|||
return (T) sc; |
|||
} |
|||
} |
|||
return null; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Returns all SideChannels of Type T that are registered. Use <see cref="GetSideChannel{T}()"/> if possible,
|
|||
/// as that does not make any memory allocations.
|
|||
/// </summary>
|
|||
/// <typeparam name="T"></typeparam>
|
|||
/// <returns></returns>
|
|||
public static List<T> GetSideChannels<T>() where T: SideChannel |
|||
{ |
|||
var output = new List<T>(); |
|||
|
|||
foreach (var sc in RegisteredChannels.Values) |
|||
{ |
|||
if (sc.GetType() == typeof(T)) |
|||
{ |
|||
output.Add((T) sc); |
|||
} |
|||
} |
|||
return output; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Grabs the messages that the registered side channels will send to Python at the current step
|
|||
/// into a singe byte array.
|
|||
/// </summary>
|
|||
/// <returns></returns>
|
|||
internal static byte[] GetSideChannelMessage() |
|||
{ |
|||
return GetSideChannelMessage(RegisteredChannels); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Grabs the messages that the registered side channels will send to Python at the current step
|
|||
/// into a singe byte array.
|
|||
/// </summary>
|
|||
/// <param name="sideChannels"> A dictionary of channel type to channel.</param>
|
|||
/// <returns></returns>
|
|||
internal static byte[] GetSideChannelMessage(Dictionary<Guid, SideChannel> sideChannels) |
|||
{ |
|||
using (var memStream = new MemoryStream()) |
|||
{ |
|||
using (var binaryWriter = new BinaryWriter(memStream)) |
|||
{ |
|||
foreach (var sideChannel in sideChannels.Values) |
|||
{ |
|||
var messageList = sideChannel.MessageQueue; |
|||
foreach (var message in messageList) |
|||
{ |
|||
binaryWriter.Write(sideChannel.ChannelId.ToByteArray()); |
|||
binaryWriter.Write(message.Length); |
|||
binaryWriter.Write(message); |
|||
} |
|||
sideChannel.MessageQueue.Clear(); |
|||
} |
|||
return memStream.ToArray(); |
|||
} |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Separates the data received from Python into individual messages for each registered side channel.
|
|||
/// </summary>
|
|||
/// <param name="dataReceived">The byte array of data received from Python.</param>
|
|||
internal static void ProcessSideChannelData(byte[] dataReceived) |
|||
{ |
|||
ProcessSideChannelData(RegisteredChannels, dataReceived); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Separates the data received from Python into individual messages for each registered side channel.
|
|||
/// </summary>
|
|||
/// <param name="sideChannels">A dictionary of channel type to channel.</param>
|
|||
/// <param name="dataReceived">The byte array of data received from Python.</param>
|
|||
internal static void ProcessSideChannelData(Dictionary<Guid, SideChannel> sideChannels, byte[] dataReceived) |
|||
{ |
|||
while (m_CachedMessages.Count != 0) |
|||
{ |
|||
var cachedMessage = m_CachedMessages.Dequeue(); |
|||
if (sideChannels.ContainsKey(cachedMessage.ChannelId)) |
|||
{ |
|||
using (var incomingMsg = new IncomingMessage(cachedMessage.Message)) |
|||
{ |
|||
sideChannels[cachedMessage.ChannelId].OnMessageReceived(incomingMsg); |
|||
} |
|||
} |
|||
else |
|||
{ |
|||
Debug.Log(string.Format( |
|||
"Unknown side channel data received. Channel Id is " |
|||
+ ": {0}", cachedMessage.ChannelId)); |
|||
} |
|||
} |
|||
|
|||
if (dataReceived.Length == 0) |
|||
{ |
|||
return; |
|||
} |
|||
using (var memStream = new MemoryStream(dataReceived)) |
|||
{ |
|||
using (var binaryReader = new BinaryReader(memStream)) |
|||
{ |
|||
while (memStream.Position < memStream.Length) |
|||
{ |
|||
Guid channelId = Guid.Empty; |
|||
byte[] message = null; |
|||
try |
|||
{ |
|||
channelId = new Guid(binaryReader.ReadBytes(16)); |
|||
var messageLength = binaryReader.ReadInt32(); |
|||
message = binaryReader.ReadBytes(messageLength); |
|||
} |
|||
catch (Exception ex) |
|||
{ |
|||
throw new UnityAgentsException( |
|||
"There was a problem reading a message in a SideChannel. Please make sure the " + |
|||
"version of MLAgents in Unity is compatible with the Python version. Original error : " |
|||
+ ex.Message); |
|||
} |
|||
if (sideChannels.ContainsKey(channelId)) |
|||
{ |
|||
using (var incomingMsg = new IncomingMessage(message)) |
|||
{ |
|||
sideChannels[channelId].OnMessageReceived(incomingMsg); |
|||
} |
|||
} |
|||
else |
|||
{ |
|||
// Don't recognize this ID, but cache it in case the SideChannel that can handle
|
|||
// it is registered before the next call to ProcessSideChannelData.
|
|||
m_CachedMessages.Enqueue(new CachedSideChannelMessage |
|||
{ |
|||
ChannelId = channelId, |
|||
Message = message |
|||
}); |
|||
} |
|||
} |
|||
} |
|||
} |
|||
} |
|||
|
|||
} |
|||
} |
撰写
预览
正在加载...
取消
保存
Reference in new issue