using System.Collections.Generic;
using System.IO;
using System;
using System.Text;
namespace MLAgents
{
/// <summary>
/// A collection of named float properties that are shared with the Python side of the environment.
/// </summary>
public interface IFloatProperties
{
    /// <summary>
    /// Sets one of the float properties of the environment. This data will be sent to Python.
    /// </summary>
    /// <param name="key">The string identifier of the property.</param>
    /// <param name="value">The float value of the property.</param>
    void SetProperty(string key, float value);

    /// <summary>
    /// Gets an environment property with a default value. If there is a value for this property,
    /// it will be returned; otherwise, the default value will be returned.
    /// </summary>
    /// <param name="key">The string identifier of the property.</param>
    /// <param name="defaultValue">The default value of the property.</param>
    /// <returns>The stored value for <paramref name="key"/>, or <paramref name="defaultValue"/> if none is stored.</returns>
    float GetPropertyWithDefault(string key, float defaultValue);

    /// <summary>
    /// Registers an action to be performed every time the property is changed.
    /// </summary>
    /// <param name="key">The string identifier of the property.</param>
    /// <param name="action">The action that will be performed. Takes the new float value as input.</param>
    void RegisterCallback(string key, Action<float> action);

    /// <summary>
    /// Returns a list of all the string identifiers of the properties currently present.
    /// </summary>
    /// <returns>The list of string identifiers.</returns>
    IList<string> ListProperties();
}
/// <summary>
/// Side channel that keeps a set of named float properties and exchanges them with Python.
/// Every message carries exactly one key/value pair, laid out as written by
/// <see cref="BinaryWriter"/> (little-endian): an Int32 key byte-length, the ASCII bytes of
/// the key, then a 32-bit float value.
/// </summary>
public class FloatPropertiesChannel : SideChannel, IFloatProperties
{
    // Latest known value of every property, keyed by its string identifier.
    readonly Dictionary<string, float> m_FloatProperties = new Dictionary<string, float>();

    // At most one callback per key: RegisterCallback overwrites any previously registered action.
    readonly Dictionary<string, Action<float>> m_RegisteredActions = new Dictionary<string, Action<float>>();

    // Channel id used when the constructor is not given an explicit one.
    private const string k_FloatPropertiesDefaultId = "60ccf7d0-4f7e-11ea-b238-784f4387d1f7";

    /// <summary>
    /// Creates the channel.
    /// </summary>
    /// <param name="channelId">
    /// Id of the channel. When omitted (i.e. <see cref="Guid.Empty"/>), the built-in default id
    /// "60ccf7d0-4f7e-11ea-b238-784f4387d1f7" is used.
    /// </param>
    public FloatPropertiesChannel(Guid channelId = default(Guid))
    {
        ChannelId = channelId == Guid.Empty
            ? new Guid(k_FloatPropertiesDefaultId)
            : channelId;
    }

    /// <summary>
    /// Handles an incoming message: stores the received value under its key and invokes the
    /// callback registered for that key, if any. Incoming values are not echoed back.
    /// </summary>
    /// <param name="data">One serialized key/value pair (layout described on the class).</param>
    public override void OnMessageReceived(byte[] data)
    {
        var kv = DeserializeMessage(data);
        m_FloatProperties[kv.Key] = kv.Value;
        InvokeCallback(kv.Key, kv.Value);
    }

    /// <summary>
    /// Stores <paramref name="value"/> locally, queues it to be sent to Python, and invokes the
    /// callback registered for <paramref name="key"/>, if any.
    /// </summary>
    /// <param name="key">The string identifier of the property.</param>
    /// <param name="value">The float value of the property.</param>
    public void SetProperty(string key, float value)
    {
        m_FloatProperties[key] = value;
        QueueMessageToSend(SerializeMessage(key, value));
        InvokeCallback(key, value);
    }

    /// <summary>
    /// Returns the stored value for <paramref name="key"/>, or <paramref name="defaultValue"/>
    /// if no value has been set or received for it yet.
    /// </summary>
    /// <param name="key">The string identifier of the property.</param>
    /// <param name="defaultValue">Value returned when the property is absent.</param>
    /// <returns>The stored value or <paramref name="defaultValue"/>.</returns>
    public float GetPropertyWithDefault(string key, float defaultValue)
    {
        float value;
        return m_FloatProperties.TryGetValue(key, out value) ? value : defaultValue;
    }

    /// <summary>
    /// Registers <paramref name="action"/> to be invoked with the new value whenever the
    /// property is set locally or received from Python. Replaces any action previously
    /// registered for the same key; passing null effectively removes the callback.
    /// </summary>
    /// <param name="key">The string identifier of the property.</param>
    /// <param name="action">The action to perform. Takes the new float value as input.</param>
    public void RegisterCallback(string key, Action<float> action)
    {
        m_RegisteredActions[key] = action;
    }

    /// <summary>
    /// Returns a snapshot copy of the identifiers of all properties currently stored.
    /// Modifying the returned list does not affect the channel.
    /// </summary>
    /// <returns>A new list containing the property keys.</returns>
    public IList<string> ListProperties()
    {
        return new List<string>(m_FloatProperties.Keys);
    }

    // Invokes the callback registered for key, if there is one (a null entry is ignored).
    void InvokeCallback(string key, float value)
    {
        Action<float> action;
        if (m_RegisteredActions.TryGetValue(key, out action))
        {
            action?.Invoke(value);
        }
    }

    /// <summary>
    /// Decodes one key/value message. The key bytes are decoded as ASCII.
    /// A truncated buffer makes <see cref="BinaryReader"/> throw (e.g. <see cref="EndOfStreamException"/>).
    /// </summary>
    static KeyValuePair<string, float> DeserializeMessage(byte[] data)
    {
        using (var memStream = new MemoryStream(data))
        {
            using (var binaryReader = new BinaryReader(memStream))
            {
                var keyLength = binaryReader.ReadInt32();
                var key = Encoding.ASCII.GetString(binaryReader.ReadBytes(keyLength));
                var value = binaryReader.ReadSingle();
                return new KeyValuePair<string, float>(key, value);
            }
        }
    }

    /// <summary>
    /// Encodes one key/value message. The key is encoded as ASCII, so any non-ASCII
    /// characters are replaced by '?' (Encoding.ASCII's replacement fallback).
    /// </summary>
    static byte[] SerializeMessage(string key, float value)
    {
        using (var memStream = new MemoryStream())
        {
            using (var binaryWriter = new BinaryWriter(memStream))
            {
                var stringEncoded = Encoding.ASCII.GetBytes(key);
                binaryWriter.Write(stringEncoded.Length);
                binaryWriter.Write(stringEncoded);
                binaryWriter.Write(value);
                // Make sure everything the writer produced has reached the stream before copying it.
                binaryWriter.Flush();
                return memStream.ToArray();
            }
        }
    }
}
}