using System.Collections.Generic;
using System.IO;
using System;
using System.Text;

namespace MLAgents.SideChannels
{
    /// <summary>
    /// Interface for managing a collection of float properties keyed by a string variable.
    /// </summary>
    public interface IFloatProperties
    {
        /// <summary>
        /// Sets one of the float properties of the environment. This data will be sent to Python.
        /// </summary>
        /// <param name="key">The string identifier of the property.</param>
        /// <param name="value">The float value of the property.</param>
        void SetProperty(string key, float value);

        /// <summary>
        /// Get an Environment property with a default value. If there is a value for this property,
        /// it will be returned, otherwise, the default value will be returned.
        /// </summary>
        /// <param name="key">The string identifier of the property.</param>
        /// <param name="defaultValue">The default value of the property.</param>
        /// <returns>
        /// The stored value for <paramref name="key"/>, or <paramref name="defaultValue"/> if none is stored.
        /// </returns>
        float GetPropertyWithDefault(string key, float defaultValue);

        /// <summary>
        /// Registers an action to be performed every time the property is changed.
        /// </summary>
        /// <param name="key">The string identifier of the property.</param>
        /// <param name="action">The action that will be performed. Takes a float as input.</param>
        void RegisterCallback(string key, Action<float> action);

        /// <summary>
        /// Returns a list of all the string identifiers of the properties currently present.
        /// </summary>
        /// <returns>The list of string identifiers.</returns>
        IList<string> ListProperties();
    }

    /// <summary>
    /// Side channel that is comprised of a collection of float variables, represented by
    /// <see cref="IFloatProperties"/>.
    /// </summary>
    public class FloatPropertiesChannel : SideChannel, IFloatProperties
    {
        // Latest value of every property, whether set locally or received from Python.
        private readonly Dictionary<string, float> m_FloatProperties = new Dictionary<string, float>();

        // At most one callback per key; RegisterCallback replaces any previously registered one.
        private readonly Dictionary<string, Action<float>> m_RegisteredActions = new Dictionary<string, Action<float>>();

        private const string k_FloatPropertiesDefaultId = "60ccf7d0-4f7e-11ea-b238-784f4387d1f7";

        /// <summary>
        /// Initializes the side channel with the provided channel ID.
        /// </summary>
        /// <param name="channelId">
        /// ID for the side channel. When left as the default (empty) <see cref="Guid"/>, the fixed ID
        /// <c>60ccf7d0-4f7e-11ea-b238-784f4387d1f7</c> is used instead.
        /// </param>
        public FloatPropertiesChannel(Guid channelId = default(Guid))
        {
            ChannelId = channelId == default(Guid)
                ? new Guid(k_FloatPropertiesDefaultId)
                : channelId;
        }

        /// <inheritdoc/>
        public override void OnMessageReceived(byte[] data)
        {
            var kv = DeserializeMessage(data);
            // A value received from Python is stored and reported to the callback, but is not
            // queued back to Python (unlike SetProperty).
            m_FloatProperties[kv.Key] = kv.Value;
            InvokeCallback(kv.Key, kv.Value);
        }

        /// <inheritdoc/>
        public void SetProperty(string key, float value)
        {
            m_FloatProperties[key] = value;
            QueueMessageToSend(SerializeMessage(key, value));
            // The callback runs after the message is queued, so a throwing callback does not
            // prevent the value from being sent.
            InvokeCallback(key, value);
        }

        /// <inheritdoc/>
        public float GetPropertyWithDefault(string key, float defaultValue)
        {
            float value;
            return m_FloatProperties.TryGetValue(key, out value) ? value : defaultValue;
        }

        /// <summary>
        /// Registers <paramref name="action"/> to be invoked with the new value whenever
        /// <paramref name="key"/> is set via <see cref="SetProperty"/> or updated by a message
        /// received from Python. Replaces any callback previously registered for the same key.
        /// </summary>
        /// <param name="key">The string identifier of the property.</param>
        /// <param name="action">The action that will be performed. Takes a float as input.</param>
        public void RegisterCallback(string key, Action<float> action)
        {
            m_RegisteredActions[key] = action;
        }

        /// <summary>
        /// Returns a new list (a snapshot) of the identifiers of all properties currently stored.
        /// </summary>
        /// <returns>The list of string identifiers.</returns>
        public IList<string> ListProperties()
        {
            return new List<string>(m_FloatProperties.Keys);
        }

        /// <summary>
        /// Invokes the callback registered for <paramref name="key"/>, if any.
        /// A callback registered as null is skipped instead of throwing.
        /// </summary>
        private void InvokeCallback(string key, float value)
        {
            Action<float> action;
            if (m_RegisteredActions.TryGetValue(key, out action) && action != null)
            {
                action(value);
            }
        }

        /// <summary>
        /// Decodes a message laid out as: a 32-bit little-endian key length, that many ASCII key
        /// bytes, then a 32-bit little-endian float value.
        /// </summary>
        private static KeyValuePair<string, float> DeserializeMessage(byte[] data)
        {
            using (var memStream = new MemoryStream(data))
            using (var binaryReader = new BinaryReader(memStream))
            {
                var keyLength = binaryReader.ReadInt32();
                var key = Encoding.ASCII.GetString(binaryReader.ReadBytes(keyLength));
                var value = binaryReader.ReadSingle();
                return new KeyValuePair<string, float>(key, value);
            }
        }

        /// <summary>
        /// Encodes a key/value pair in the same layout that <see cref="DeserializeMessage"/> reads.
        /// </summary>
        private static byte[] SerializeMessage(string key, float value)
        {
            using (var memStream = new MemoryStream())
            using (var binaryWriter = new BinaryWriter(memStream))
            {
                // NOTE(review): Encoding.ASCII replaces non-ASCII characters with '?', so such keys
                // are not transmitted faithfully. Keys are presumably expected to be ASCII; confirm
                // against the Python side before changing the encoding.
                var stringEncoded = Encoding.ASCII.GetBytes(key);
                binaryWriter.Write(stringEncoded.Length);
                binaryWriter.Write(stringEncoded);
                binaryWriter.Write(value);
                binaryWriter.Flush();
                return memStream.ToArray();
            }
        }
    }
}