using System;
using System.Collections.Generic;
using UnityEngine;
using System.IO;
namespace Unity.MLAgents.SideChannels
{
/// <summary>
/// Collection of static utilities for managing the registering/unregistering of
/// side channels and the sending/receiving of messages for all the channels.
/// </summary>
public static class SideChannelManager
{
    // Channels that are currently registered, keyed by their channel id.
    static Dictionary<Guid, SideChannel> s_RegisteredChannels = new Dictionary<Guid, SideChannel>();

    // A message that arrived for a channel id that had no registered channel at the time.
    struct CachedSideChannelMessage
    {
        public Guid ChannelId;
        public byte[] Message;
    }

    // Messages waiting for a channel with a matching id to be registered.
    static readonly Queue<CachedSideChannelMessage> s_CachedMessages =
        new Queue<CachedSideChannelMessage>();

    /// <summary>
    /// Register a side channel to begin sending and receiving messages. This method is
    /// available for environments that have custom side channels. All built-in side
    /// channels within the ML-Agents Toolkit are managed internally and do not need to
    /// be explicitly registered/unregistered. A side channel may only be registered once.
    /// Any cached messages addressed to this channel's id are handed to
    /// <c>ProcessMessage</c> on the channel before it is added to the registry.
    /// </summary>
    /// <param name="sideChannel">The side channel to register.</param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="sideChannel"/> is null.
    /// </exception>
    /// <exception cref="UnityAgentsException">
    /// A side channel with the same id is already registered.
    /// </exception>
    public static void RegisterSideChannel(SideChannel sideChannel)
    {
        if (sideChannel == null)
        {
            throw new ArgumentNullException(nameof(sideChannel));
        }

        var channelId = sideChannel.ChannelId;
        if (s_RegisteredChannels.ContainsKey(channelId))
        {
            throw new UnityAgentsException(
                $"A side channel with id {channelId} is already registered. " +
                "You cannot register multiple side channels of the same id.");
        }

        // Process any messages that we've already received for this channel ID.
        // Every cached message is dequeued exactly once; messages for other channels are
        // re-enqueued at the tail, so after a full pass their relative order is unchanged.
        var numMessages = s_CachedMessages.Count;
        for (var i = 0; i < numMessages; i++)
        {
            var cachedMessage = s_CachedMessages.Dequeue();
            if (channelId == cachedMessage.ChannelId)
            {
                sideChannel.ProcessMessage(cachedMessage.Message);
            }
            else
            {
                s_CachedMessages.Enqueue(cachedMessage);
            }
        }

        s_RegisteredChannels.Add(channelId, sideChannel);
    }

    /// <summary>
    /// Unregister a side channel to stop sending and receiving messages. This method is
    /// available for environments that have custom side channels. All built-in side
    /// channels within the ML-Agents Toolkit are managed internally and do not need to
    /// be explicitly registered/unregistered. Unregistering a side channel that has already
    /// been unregistered (or never registered in the first place) has no negative side effects.
    /// Note that unregistering a side channel may not stop the Python side
    /// from sending messages, but it does mean that sent messages will not be delivered to
    /// the channel. Such messages are only held until the next call to
    /// <c>ProcessSideChannelData</c>; if no channel with that id has been registered by
    /// then, they are logged and discarded and will, in essence, be lost.
    /// </summary>
    /// <remarks>
    /// The lookup is done by channel id only, so whichever channel is registered under
    /// <paramref name="sideChannel"/>'s id is the one that gets removed.
    /// </remarks>
    /// <param name="sideChannel">The side channel to unregister.</param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="sideChannel"/> is null.
    /// </exception>
    public static void UnregisterSideChannel(SideChannel sideChannel)
    {
        if (sideChannel == null)
        {
            throw new ArgumentNullException(nameof(sideChannel));
        }

        // Remove simply returns false when the key is absent, so no ContainsKey check is needed.
        s_RegisteredChannels.Remove(sideChannel.ChannelId);
    }

    /// <summary>
    /// Unregisters all the side channels from the communicator.
    /// Messages that are already cached for unknown channel ids are left untouched.
    /// </summary>
    internal static void UnregisterAllSideChannels()
    {
        s_RegisteredChannels = new Dictionary<Guid, SideChannel>();
    }

    /// <summary>
    /// Returns the SideChannel of Type T if there is one registered, or null if there isn't.
    /// If there are multiple SideChannels of the same type registered, the returned instance is arbitrary.
    /// The match is on the exact runtime type; channels whose type derives from T are not returned.
    /// </summary>
    /// <typeparam name="T">The exact side channel type to look for.</typeparam>
    /// <returns>A registered channel of type T, or null if none is registered.</returns>
    internal static T GetSideChannel<T>() where T : SideChannel
    {
        foreach (var sc in s_RegisteredChannels.Values)
        {
            if (sc.GetType() == typeof(T))
            {
                return (T)sc;
            }
        }
        return null;
    }

    /// <summary>
    /// Grabs the messages that the registered side channels will send to Python at the current step
    /// into a single byte array.
    /// </summary>
    /// <returns>The serialized outgoing messages, or an empty array if there are none.</returns>
    internal static byte[] GetSideChannelMessage()
    {
        return GetSideChannelMessage(s_RegisteredChannels);
    }

    /// <summary>
    /// Grabs the messages that the given side channels will send to Python at the current step
    /// into a single byte array. Each message is written as the 16-byte channel id, followed by
    /// the message length as a 32-bit integer, followed by the message bytes. The message queue
    /// of every channel is emptied as a side effect.
    /// </summary>
    /// <param name="sideChannels">A dictionary of channel id to channel.</param>
    /// <returns>The serialized outgoing messages, or an empty array if there are none.</returns>
    internal static byte[] GetSideChannelMessage(Dictionary<Guid, SideChannel> sideChannels)
    {
        if (!HasOutgoingMessages(sideChannels))
        {
            // Early out so that we don't create the MemoryStream or BinaryWriter.
            // This is the most common case.
            return Array.Empty<byte>();
        }

        using (var memStream = new MemoryStream())
        using (var binaryWriter = new BinaryWriter(memStream))
        {
            foreach (var sideChannel in sideChannels.Values)
            {
                var messageList = sideChannel.MessageQueue;
                if (messageList.Count == 0)
                {
                    continue;
                }

                // Guid.ToByteArray allocates on every call, so do it once per channel
                // rather than once per message.
                var channelIdBytes = sideChannel.ChannelId.ToByteArray();
                foreach (var message in messageList)
                {
                    binaryWriter.Write(channelIdBytes);
                    binaryWriter.Write(message.Length);
                    binaryWriter.Write(message);
                }
                messageList.Clear();
            }
            return memStream.ToArray();
        }
    }

    /// <summary>
    /// Check whether any of the side channels have queued messages.
    /// </summary>
    /// <param name="sideChannels">A dictionary of channel id to channel.</param>
    /// <returns>True if at least one channel has a non-empty message queue.</returns>
    static bool HasOutgoingMessages(Dictionary<Guid, SideChannel> sideChannels)
    {
        foreach (var sideChannel in sideChannels.Values)
        {
            if (sideChannel.MessageQueue.Count > 0)
            {
                return true;
            }
        }
        return false;
    }

    /// <summary>
    /// Separates the data received from Python into individual messages for each registered side channel.
    /// </summary>
    /// <param name="dataReceived">The byte array of data received from Python.</param>
    internal static void ProcessSideChannelData(byte[] dataReceived)
    {
        ProcessSideChannelData(s_RegisteredChannels, dataReceived);
    }

    /// <summary>
    /// Separates the data received from Python into individual messages for each side channel.
    /// Messages cached by a previous call are handled first: they are delivered if a matching
    /// channel is now present and otherwise logged and dropped. Messages in
    /// <paramref name="dataReceived"/> whose channel id is unknown are cached for a later call.
    /// </summary>
    /// <param name="sideChannels">A dictionary of channel id to channel.</param>
    /// <param name="dataReceived">
    /// The byte array of data received from Python. Null is treated like an empty array.
    /// </param>
    /// <exception cref="UnityAgentsException">
    /// The data is malformed or ends in the middle of a message.
    /// </exception>
    internal static void ProcessSideChannelData(Dictionary<Guid, SideChannel> sideChannels, byte[] dataReceived)
    {
        SideChannel channel;

        // Last chance for messages cached on an earlier call: deliver them or drop them.
        while (s_CachedMessages.Count != 0)
        {
            var cachedMessage = s_CachedMessages.Dequeue();
            if (sideChannels.TryGetValue(cachedMessage.ChannelId, out channel))
            {
                channel.ProcessMessage(cachedMessage.Message);
            }
            else
            {
                Debug.Log($"Unknown side channel data received. Channel Id is : {cachedMessage.ChannelId}");
            }
        }

        if (dataReceived == null || dataReceived.Length == 0)
        {
            return;
        }

        using (var memStream = new MemoryStream(dataReceived))
        using (var binaryReader = new BinaryReader(memStream))
        {
            while (memStream.Position < memStream.Length)
            {
                Guid channelId = Guid.Empty;
                byte[] message = null;
                try
                {
                    channelId = new Guid(binaryReader.ReadBytes(16));
                    var messageLength = binaryReader.ReadInt32();
                    message = binaryReader.ReadBytes(messageLength);

                    // ReadBytes silently returns fewer bytes when the stream ends early.
                    // Treat that as a malformed message instead of delivering a truncated one.
                    if (message.Length != messageLength)
                    {
                        throw new EndOfStreamException(
                            $"Expected {messageLength} bytes of message data but only " +
                            $"{message.Length} were available.");
                    }
                }
                catch (Exception ex)
                {
                    throw new UnityAgentsException(
                        "There was a problem reading a message in a SideChannel. Please make sure the " +
                        "version of MLAgents in Unity is compatible with the Python version. Original error : "
                        + ex.Message);
                }

                if (sideChannels.TryGetValue(channelId, out channel))
                {
                    channel.ProcessMessage(message);
                }
                else
                {
                    // Don't recognize this ID, but cache it in case the SideChannel that can handle
                    // it is registered before the next call to ProcessSideChannelData.
                    s_CachedMessages.Enqueue(new CachedSideChannelMessage
                    {
                        ChannelId = channelId,
                        Message = message
                    });
                }
            }
        }
    }
}
/// <summary>
/// Deprecated, use <see cref="SideChannelManager"/> instead.
/// </summary>
[Obsolete("Use SideChannelManager instead.")]
public static class SideChannelsManager
{
    /// <summary>
    /// Deprecated forwarding method; calls <c>SideChannelManager.RegisterSideChannel</c>
    /// with the same argument.
    /// </summary>
    /// <param name="sideChannel">The side channel to register.</param>
    public static void RegisterSideChannel(SideChannel sideChannel) =>
        SideChannelManager.RegisterSideChannel(sideChannel);

    /// <summary>
    /// Deprecated forwarding method; calls <c>SideChannelManager.UnregisterSideChannel</c>
    /// with the same argument.
    /// </summary>
    /// <param name="sideChannel">The side channel to unregister.</param>
    public static void UnregisterSideChannel(SideChannel sideChannel) =>
        SideChannelManager.UnregisterSideChannel(sideChannel);
}
}