using System.Collections.Generic;
using System;
using UnityEngine;
namespace Unity.MLAgents.SideChannels
{
/// <summary>
/// Identifies the kind of payload carried by an environment-parameter message.
/// The numeric values are the integer type ids read from the incoming message,
/// so they must stay in sync with the sender.
/// </summary>
internal enum EnvironmentDataTypes : int
{
    /// <summary>The message carries a single constant float value.</summary>
    Float = 0,

    /// <summary>The message carries a sampler description (seed, sampler type and sampler parameters).</summary>
    Sampler = 1,
}
/// <summary>
/// The types of distributions from which to sample reset parameters.
/// The numeric values are the integer sampler ids read from a Sampler message,
/// so they must stay in sync with the sender.
/// </summary>
internal enum SamplerType
{
    /// <summary>
    /// Samples a reset parameter from a uniform distribution.
    /// The message supplies a minimum and a maximum as two floats.
    /// </summary>
    Uniform = 0,

    /// <summary>
    /// Samples a reset parameter from a Gaussian distribution.
    /// The message supplies a mean and a standard deviation as two floats.
    /// </summary>
    Gaussian = 1,

    /// <summary>
    /// Samples a reset parameter uniformly from multiple ranges.
    /// The message supplies a list of floats describing the intervals.
    /// NOTE(review): presumably consecutive (min, max) pairs; confirm against
    /// SamplerFactory.CreateMultiRangeUniformSampler.
    /// </summary>
    MultiRangeUniform = 2
}
/// <summary>
/// A side channel that manages the environment parameter values sent from Python.
/// Each parameter is stored as a value provider: a Float message stores a constant,
/// while a Sampler message stores a sampler that is invoked every time the parameter is read.
/// </summary>
internal class EnvironmentParametersChannel : SideChannel
{
    // Parameter key -> function producing the parameter's value on each read.
    readonly Dictionary<string, Func<float>> m_Parameters = new Dictionary<string, Func<float>>();

    // Parameter key -> callback registered via RegisterCallback (at most one per key).
    readonly Dictionary<string, Action<float>> m_RegisteredActions =
        new Dictionary<string, Action<float>>();

    readonly SamplerFactory m_SamplerFactory = new SamplerFactory();

    const string k_EnvParamsId = "534c891e-810f-11ea-a9d0-822485860400";

    /// <summary>
    /// Initializes the side channel. The constructor is internal because only one instance is
    /// supported at a time, and is created by the Academy.
    /// </summary>
    internal EnvironmentParametersChannel()
    {
        ChannelId = new Guid(k_EnvParamsId);
    }

    /// <inheritdoc/>
    /// <remarks>
    /// Message layout: a string key, an int data type (see <c>EnvironmentDataTypes</c>),
    /// then a type-specific payload. Float messages store a constant value and invoke the
    /// callback registered for the key, if any. Sampler messages replace the stored value
    /// provider without invoking any callback. Messages with an unknown data or sampler
    /// type are logged and leave the stored parameters unchanged.
    /// </remarks>
    protected override void OnMessageReceived(IncomingMessage msg)
    {
        var key = msg.ReadString();
        var type = msg.ReadInt32();
        if ((int)EnvironmentDataTypes.Float == type)
        {
            var value = msg.ReadFloat32();
            m_Parameters[key] = () => value;

            Action<float> action;
            m_RegisteredActions.TryGetValue(key, out action);
            action?.Invoke(value);
        }
        else if ((int)EnvironmentDataTypes.Sampler == type)
        {
            var sampler = ReadSampler(msg);
            // An unknown sampler type must not clobber the parameter: storing a
            // placeholder would make GetWithDefault return it instead of the default.
            if (sampler != null)
            {
                m_Parameters[key] = sampler;
            }
        }
        else
        {
            Debug.LogWarning("EnvironmentParametersChannel received an unknown data type.");
        }
    }

    /// <summary>
    /// Decodes a Sampler payload: an int seed, an int sampler type (see <c>SamplerType</c>),
    /// then the sampler-specific parameters.
    /// </summary>
    /// <param name="msg">Message positioned just after the data type field.</param>
    /// <returns>The decoded sampler, or <c>null</c> (after logging a warning) for an unknown sampler type.</returns>
    Func<float> ReadSampler(IncomingMessage msg)
    {
        int seed = msg.ReadInt32();
        int samplerType = msg.ReadInt32();
        if ((int)SamplerType.Uniform == samplerType)
        {
            float min = msg.ReadFloat32();
            float max = msg.ReadFloat32();
            return m_SamplerFactory.CreateUniformSampler(min, max, seed);
        }
        if ((int)SamplerType.Gaussian == samplerType)
        {
            float mean = msg.ReadFloat32();
            float stddev = msg.ReadFloat32();
            return m_SamplerFactory.CreateGaussianSampler(mean, stddev, seed);
        }
        if ((int)SamplerType.MultiRangeUniform == samplerType)
        {
            IList<float> intervals = msg.ReadFloatList();
            return m_SamplerFactory.CreateMultiRangeUniformSampler(intervals, seed);
        }

        Debug.LogWarning($"EnvironmentParametersChannel received an unknown sampler type: {samplerType}.");
        return null;
    }

    /// <summary>
    /// Returns the parameter value associated with the provided key. Returns the default
    /// value if one doesn't exist.
    /// </summary>
    /// <remarks>
    /// The stored value provider is invoked on every call. For sampler parameters,
    /// successive calls may therefore return different values.
    /// </remarks>
    /// <param name="key">Parameter key.</param>
    /// <param name="defaultValue">Default value to return.</param>
    /// <returns>The current value of the parameter, or <paramref name="defaultValue"/> if none was received.</returns>
    public float GetWithDefault(string key, float defaultValue)
    {
        Func<float> valueOut;
        return m_Parameters.TryGetValue(key, out valueOut) ? valueOut.Invoke() : defaultValue;
    }

    /// <summary>
    /// Registers a callback for the associated parameter key. Will overwrite any existing
    /// actions for this parameter key.
    /// </summary>
    /// <remarks>
    /// The callback is invoked only when a Float message for the key is received after
    /// registration. It is not invoked for values received earlier or for Sampler messages.
    /// </remarks>
    /// <param name="key">The parameter key.</param>
    /// <param name="action">The callback.</param>
    public void RegisterCallback(string key, Action<float> action)
    {
        m_RegisteredActions[key] = action;
    }

    /// <summary>
    /// Returns all parameter keys for which a value or sampler has been received.
    /// </summary>
    /// <returns>A new list (a snapshot copy) of the parameter keys.</returns>
    public IList<string> ListParameters()
    {
        return new List<string>(m_Parameters.Keys);
    }
}
}