using System;
using System.Collections.Generic;
using UnityEngine;
using System.IO;

namespace MLAgents.SideChannels
{
    /// <summary>
    /// Static registry of <see cref="SideChannel"/> instances, plus helpers that serialize the
    /// messages registered channels want to send and dispatch received messages by channel ID.
    /// </summary>
    public static class SideChannelUtils
    {
        /// <summary>
        /// Registered side channels, keyed by <see cref="SideChannel.ChannelId"/>.
        /// </summary>
        private static Dictionary<Guid, SideChannel> RegisteredChannels = new Dictionary<Guid, SideChannel>();

        /// <summary>
        /// A received message whose channel ID had no registered side channel at the time.
        /// </summary>
        private struct CachedSideChannelMessage
        {
            public Guid ChannelId;
            public byte[] Message;
        }

        /// <summary>
        /// Messages received for channel IDs with no registered side channel, in arrival order.
        /// They are delivered when a matching channel is registered (see
        /// <see cref="RegisterSideChannel"/>), or on the next call to ProcessSideChannelData,
        /// which logs and drops any that still have no matching channel.
        /// </summary>
        private static Queue<CachedSideChannelMessage> m_CachedMessages = new Queue<CachedSideChannelMessage>();

        /// <summary>
        /// Registers a side channel to the communicator. The side channel will exchange
        /// messages with its Python equivalent.
        /// </summary>
        /// <remarks>
        /// Any cached messages already received for this channel's ID are delivered to it,
        /// in arrival order, before it is added to the registry.
        /// </remarks>
        /// <param name="sideChannel">The side channel to be registered.</param>
        /// <exception cref="UnityAgentsException">
        /// Thrown if a side channel with the same ID is already registered.
        /// </exception>
        public static void RegisterSideChannel(SideChannel sideChannel)
        {
            var channelId = sideChannel.ChannelId;
            if (RegisteredChannels.ContainsKey(channelId))
            {
                throw new UnityAgentsException(string.Format(
                    "A side channel with type index {0} is already registered. You cannot register multiple " +
                    "side channels of the same id.", channelId));
            }

            // Process any messages that we've already received for this channel ID.
            // Rotate through the queue exactly once: matching messages are delivered, all others
            // are re-enqueued, which keeps their relative order intact.
            var numMessages = m_CachedMessages.Count;
            for (int i = 0; i < numMessages; i++)
            {
                var cachedMessage = m_CachedMessages.Dequeue();
                if (channelId == cachedMessage.ChannelId)
                {
                    using (var incomingMsg = new IncomingMessage(cachedMessage.Message))
                    {
                        sideChannel.OnMessageReceived(incomingMsg);
                    }
                }
                else
                {
                    m_CachedMessages.Enqueue(cachedMessage);
                }
            }

            RegisteredChannels.Add(channelId, sideChannel);
        }

        /// <summary>
        /// Unregisters a side channel from the communicator. Does nothing if no side channel
        /// with the same ID is registered.
        /// </summary>
        /// <param name="sideChannel">The side channel to be unregistered.</param>
        public static void UnregisterSideChannel(SideChannel sideChannel)
        {
            // Dictionary.Remove is a no-op (returns false) when the key is absent.
            RegisteredChannels.Remove(sideChannel.ChannelId);
        }

        /// <summary>
        /// Unregisters all the side channels from the communicator.
        /// </summary>
        /// <remarks>
        /// Messages already cached for unregistered channel IDs are not cleared.
        /// </remarks>
        public static void UnregisterAllSideChannels()
        {
            RegisteredChannels = new Dictionary<Guid, SideChannel>();
        }

        /// <summary>
        /// Returns the SideChannel of Type T if there is one registered, or null if it doesn't.
        /// If there are multiple SideChannels of the same type registered, the returned instance is arbitrary.
        /// </summary>
        /// <remarks>
        /// Matching is on the exact runtime type, so a registered subclass of T is not returned.
        /// </remarks>
        /// <typeparam name="T">The concrete side channel type to look for.</typeparam>
        /// <returns>A registered side channel of exactly type T, or null.</returns>
        public static T GetSideChannel<T>() where T : SideChannel
        {
            foreach (var sc in RegisteredChannels.Values)
            {
                if (sc.GetType() == typeof(T))
                {
                    return (T)sc;
                }
            }
            return null;
        }

        /// <summary>
        /// Returns all SideChannels of Type T that are registered. Use <see cref="GetSideChannel{T}"/>
        /// if possible, as that does not make any memory allocations.
        /// </summary>
        /// <remarks>
        /// Matching is on the exact runtime type, so registered subclasses of T are not included.
        /// </remarks>
        /// <typeparam name="T">The concrete side channel type to look for.</typeparam>
        /// <returns>A new list of all registered side channels of exactly type T (possibly empty).</returns>
        public static List<T> GetSideChannels<T>() where T : SideChannel
        {
            var output = new List<T>();
            foreach (var sc in RegisteredChannels.Values)
            {
                if (sc.GetType() == typeof(T))
                {
                    output.Add((T)sc);
                }
            }
            return output;
        }

        /// <summary>
        /// Grabs the messages that the registered side channels will send to Python at the current step
        /// into a single byte array.
        /// </summary>
        /// <returns>The serialized messages; an empty array if there are none.</returns>
        internal static byte[] GetSideChannelMessage()
        {
            return GetSideChannelMessage(RegisteredChannels);
        }

        /// <summary>
        /// Grabs the messages that the given side channels will send to Python at the current step
        /// into a single byte array, and clears each channel's outgoing queue.
        /// </summary>
        /// <remarks>
        /// Each message is written as the 16 bytes of <c>ChannelId.ToByteArray()</c>, then the
        /// payload length via <c>BinaryWriter.Write(int)</c>, then the payload bytes.
        /// </remarks>
        /// <param name="sideChannels">A dictionary of channel ID to channel.</param>
        /// <returns>The serialized messages; an empty array if there are none.</returns>
        internal static byte[] GetSideChannelMessage(Dictionary<Guid, SideChannel> sideChannels)
        {
            using (var memStream = new MemoryStream())
            using (var binaryWriter = new BinaryWriter(memStream))
            {
                foreach (var sideChannel in sideChannels.Values)
                {
                    var messageList = sideChannel.MessageQueue;
                    foreach (var message in messageList)
                    {
                        binaryWriter.Write(sideChannel.ChannelId.ToByteArray());
                        binaryWriter.Write(message.Length);
                        binaryWriter.Write(message);
                    }
                    sideChannel.MessageQueue.Clear();
                }
                return memStream.ToArray();
            }
        }

        /// <summary>
        /// Separates the data received from Python into individual messages for each registered side channel.
        /// </summary>
        /// <param name="dataReceived">The byte array of data received from Python.</param>
        internal static void ProcessSideChannelData(byte[] dataReceived)
        {
            ProcessSideChannelData(RegisteredChannels, dataReceived);
        }

        /// <summary>
        /// Separates the data received from Python into individual messages for each side channel.
        /// </summary>
        /// <remarks>
        /// Previously cached messages are handled first. Those whose channel is now present are
        /// delivered; the rest are logged and dropped. Then <paramref name="dataReceived"/> is
        /// parsed as a sequence of records, each consisting of a 16-byte channel GUID, an int32
        /// payload length, and the payload bytes. Records for unknown channel IDs are cached for
        /// later delivery.
        /// </remarks>
        /// <param name="sideChannels">A dictionary of channel ID to channel.</param>
        /// <param name="dataReceived">
        /// The byte array of data received from Python. Null or empty means no new messages.
        /// </param>
        /// <exception cref="UnityAgentsException">
        /// Thrown if a record cannot be read completely (malformed or truncated data).
        /// </exception>
        internal static void ProcessSideChannelData(Dictionary<Guid, SideChannel> sideChannels, byte[] dataReceived)
        {
            // Give previously cached messages one more chance to be delivered; anything still
            // unclaimed is logged and discarded.
            while (m_CachedMessages.Count != 0)
            {
                var cachedMessage = m_CachedMessages.Dequeue();
                SideChannel cachedChannel;
                if (sideChannels.TryGetValue(cachedMessage.ChannelId, out cachedChannel))
                {
                    using (var incomingMsg = new IncomingMessage(cachedMessage.Message))
                    {
                        cachedChannel.OnMessageReceived(incomingMsg);
                    }
                }
                else
                {
                    Debug.Log(string.Format(
                        "Unknown side channel data received. Channel Id is "
                        + ": {0}", cachedMessage.ChannelId));
                }
            }

            if (dataReceived == null || dataReceived.Length == 0)
            {
                return;
            }

            using (var memStream = new MemoryStream(dataReceived))
            using (var binaryReader = new BinaryReader(memStream))
            {
                while (memStream.Position < memStream.Length)
                {
                    Guid channelId = Guid.Empty;
                    byte[] message = null;
                    try
                    {
                        channelId = new Guid(binaryReader.ReadBytes(16));
                        var messageLength = binaryReader.ReadInt32();
                        message = binaryReader.ReadBytes(messageLength);
                        // ReadBytes returns a shorter array (instead of throwing) when the stream
                        // ends early; a short read means the record is truncated, so don't pass
                        // the partial payload on as if it were a valid message.
                        if (message.Length != messageLength)
                        {
                            throw new EndOfStreamException(string.Format(
                                "Expected {0} bytes of message data but only {1} were available.",
                                messageLength, message.Length));
                        }
                    }
                    catch (Exception ex)
                    {
                        throw new UnityAgentsException(
                            "There was a problem reading a message in a SideChannel. Please make sure the " +
                            "version of MLAgents in Unity is compatible with the Python version. Original error : "
                            + ex.Message);
                    }

                    SideChannel sideChannel;
                    if (sideChannels.TryGetValue(channelId, out sideChannel))
                    {
                        using (var incomingMsg = new IncomingMessage(message))
                        {
                            sideChannel.OnMessageReceived(incomingMsg);
                        }
                    }
                    else
                    {
                        // Don't recognize this ID, but cache it in case the SideChannel that can handle
                        // it is registered before the next call to ProcessSideChannelData.
                        m_CachedMessages.Enqueue(new CachedSideChannelMessage
                        {
                            ChannelId = channelId,
                            Message = message
                        });
                    }
                }
            }
        }
    }
}