using System;
using System.Collections.Generic;
using System.IO;
using System.Text;

namespace MLAgents.SideChannels
{
    /// <summary>
    /// Side channel that holds a collection of named float properties, stored in a
    /// <see cref="Dictionary{TKey, TValue}"/> keyed by property name.
    /// </summary>
    /// <remarks>
    /// Each message carries exactly one key/value pair, laid out as a 32-bit key length,
    /// the ASCII-encoded key bytes, and a 32-bit float value (as written by
    /// <see cref="BinaryWriter"/>).
    /// </remarks>
    public class FloatPropertiesChannel : SideChannel
    {
        // Latest known value for each property, whether set locally or received.
        readonly Dictionary<string, float> m_FloatProperties = new Dictionary<string, float>();

        // At most one callback per property key; registering again replaces the previous one.
        readonly Dictionary<string, Action<float>> m_RegisteredActions = new Dictionary<string, Action<float>>();

        private const string k_FloatPropertiesDefaultId = "60ccf7d0-4f7e-11ea-b238-784f4387d1f7";

        /// <summary>
        /// Initializes the side channel with the provided channel ID.
        /// </summary>
        /// <param name="channelId">
        /// ID for the side channel. When left at <c>default(Guid)</c>, the built-in
        /// default float-properties channel ID is used instead.
        /// </param>
        public FloatPropertiesChannel(Guid channelId = default(Guid))
        {
            if (channelId == default(Guid))
            {
                ChannelId = new Guid(k_FloatPropertiesDefaultId);
            }
            else
            {
                ChannelId = channelId;
            }
        }

        /// <summary>
        /// Decodes an incoming message, stores the property it carries and invokes the
        /// callback registered for that key, if any.
        /// </summary>
        /// <param name="data">Raw message bytes in the format produced by this channel.</param>
        public override void OnMessageReceived(byte[] data)
        {
            var kv = DeserializeMessage(data);
            m_FloatProperties[kv.Key] = kv.Value;
            InvokeRegisteredAction(kv.Key, kv.Value);
        }

        /// <summary>
        /// Sets a property locally, queues a message carrying the new value, and then
        /// invokes the callback registered for that key, if any.
        /// </summary>
        /// <param name="key">Name of the property.</param>
        /// <param name="value">New value of the property.</param>
        public void SetProperty(string key, float value)
        {
            m_FloatProperties[key] = value;
            QueueMessageToSend(SerializeMessage(key, value));
            InvokeRegisteredAction(key, value);
        }

        /// <summary>
        /// Returns the current value of a property, or a fallback if it has never been set.
        /// </summary>
        /// <param name="key">Name of the property.</param>
        /// <param name="defaultValue">Value returned when <paramref name="key"/> is unknown.</param>
        /// <returns>The stored value, or <paramref name="defaultValue"/>.</returns>
        public float GetPropertyWithDefault(string key, float defaultValue)
        {
            float value;
            return m_FloatProperties.TryGetValue(key, out value) ? value : defaultValue;
        }

        /// <summary>
        /// Registers a callback to run whenever the given property is set or received.
        /// Replaces any callback previously registered for the same key.
        /// </summary>
        /// <param name="key">Name of the property to observe.</param>
        /// <param name="action">Callback that receives the new value.</param>
        public void RegisterCallback(string key, Action<float> action)
        {
            m_RegisteredActions[key] = action;
        }

        /// <summary>
        /// Returns a snapshot of the names of all properties currently stored.
        /// </summary>
        /// <returns>A new list containing the property keys.</returns>
        public IList<string> ListProperties()
        {
            return new List<string>(m_FloatProperties.Keys);
        }

        /// <summary>
        /// Invokes the callback registered for <paramref name="key"/>, if there is one.
        /// </summary>
        void InvokeRegisteredAction(string key, float value)
        {
            Action<float> action;
            // Single lookup; the null check tolerates a null delegate having been registered.
            if (m_RegisteredActions.TryGetValue(key, out action) && action != null)
            {
                action.Invoke(value);
            }
        }

        /// <summary>
        /// Decodes a single key/value pair from a message: key length, key bytes, float value.
        /// </summary>
        static KeyValuePair<string, float> DeserializeMessage(byte[] data)
        {
            using (var memStream = new MemoryStream(data))
            {
                using (var binaryReader = new BinaryReader(memStream))
                {
                    var keyLength = binaryReader.ReadInt32();
                    var key = Encoding.ASCII.GetString(binaryReader.ReadBytes(keyLength));
                    var value = binaryReader.ReadSingle();
                    return new KeyValuePair<string, float>(key, value);
                }
            }
        }

        /// <summary>
        /// Encodes a single key/value pair as: key length, ASCII key bytes, float value.
        /// </summary>
        static byte[] SerializeMessage(string key, float value)
        {
            using (var memStream = new MemoryStream())
            {
                using (var binaryWriter = new BinaryWriter(memStream))
                {
                    // The length prefix counts encoded bytes, not characters in the string.
                    var stringEncoded = Encoding.ASCII.GetBytes(key);
                    binaryWriter.Write(stringEncoded.Length);
                    binaryWriter.Write(stringEncoded);
                    binaryWriter.Write(value);
                    return memStream.ToArray();
                }
            }
        }
    }
}