using System;
using System.Collections.Generic;
using UnityEngine;

namespace Unity.MLAgents.SideChannels
{
    /// <summary>
    /// Lists the different data types supported.
    /// </summary>
    internal enum EnvironmentDataTypes
    {
        Float = 0,
        Sampler = 1
    }

    /// <summary>
    /// The types of distributions from which to sample reset parameters.
    /// </summary>
    internal enum SamplerType
    {
        /// <summary>
        /// Samples a reset parameter from a uniform distribution.
        /// </summary>
        Uniform = 0,

        /// <summary>
        /// Samples a reset parameter from a Gaussian distribution.
        /// </summary>
        Gaussian = 1,

        /// <summary>
        /// Samples a reset parameter from a multi-range uniform distribution,
        /// described by a list of interval values.
        /// </summary>
        MultiRangeUniform = 2
    }

    /// <summary>
    /// A side channel that manages the environment parameter values from Python. Currently
    /// limited to parameters of type float.
    /// </summary>
    internal class EnvironmentParametersChannel : SideChannel
    {
        // Every parameter is stored as a float-producing delegate so that constant values
        // and samplers can be read through the same code path in GetWithDefault.
        Dictionary<string, Func<float>> m_Parameters = new Dictionary<string, Func<float>>();

        // Callbacks invoked when a constant (Float) value arrives for the matching key.
        Dictionary<string, Action<float>> m_RegisteredActions =
            new Dictionary<string, Action<float>>();

        SamplerFactory m_SamplerFactory = new SamplerFactory();

        const string k_EnvParamsId = "534c891e-810f-11ea-a9d0-822485860400";

        /// <summary>
        /// Initializes the side channel. The constructor is internal because only one instance is
        /// supported at a time, and is created by the Academy.
        /// </summary>
        internal EnvironmentParametersChannel()
        {
            ChannelId = new Guid(k_EnvParamsId);
        }

        /// <inheritdoc/>
        protected override void OnMessageReceived(IncomingMessage msg)
        {
            // The message is read strictly in order: key, data type, then a
            // type-specific payload.
            var key = msg.ReadString();
            var type = msg.ReadInt32();
            if ((int)EnvironmentDataTypes.Float == type)
            {
                var value = msg.ReadFloat32();

                // Store the constant as a delegate that always returns it.
                m_Parameters[key] = () => value;

                // Notify the callback registered for this key, if there is one.
                Action<float> action;
                m_RegisteredActions.TryGetValue(key, out action);
                action?.Invoke(value);
            }
            else if ((int)EnvironmentDataTypes.Sampler == type)
            {
                // Sampler payload: seed, sampler type, then sampler-specific arguments.
                int seed = msg.ReadInt32();
                int samplerType = msg.ReadInt32();

                // Fallback used when the sampler type is not recognized.
                Func<float> sampler = () => 0.0f;
                if ((int)SamplerType.Uniform == samplerType)
                {
                    float min = msg.ReadFloat32();
                    float max = msg.ReadFloat32();
                    sampler = m_SamplerFactory.CreateUniformSampler(min, max, seed);
                }
                else if ((int)SamplerType.Gaussian == samplerType)
                {
                    float mean = msg.ReadFloat32();
                    float stddev = msg.ReadFloat32();
                    sampler = m_SamplerFactory.CreateGaussianSampler(mean, stddev, seed);
                }
                else if ((int)SamplerType.MultiRangeUniform == samplerType)
                {
                    IList<float> intervals = msg.ReadFloatList();
                    sampler = m_SamplerFactory.CreateMultiRangeUniformSampler(intervals, seed);
                }
                else
                {
                    // NOTE(review): this is an unknown *sampler* type, yet the message text
                    // matches the unknown data type warning below. The constant-zero
                    // fallback is still stored under the key. Confirm both are intended.
                    Debug.LogWarning("EnvironmentParametersChannel received an unknown data type.");
                }

                // Registered callbacks are not invoked for sampler parameters.
                m_Parameters[key] = sampler;
            }
            else
            {
                Debug.LogWarning("EnvironmentParametersChannel received an unknown data type.");
            }
        }

        /// <summary>
        /// Returns the parameter value associated with the provided key. Returns the default
        /// value if one doesn't exist.
        /// </summary>
        /// <param name="key">Parameter key.</param>
        /// <param name="defaultValue">Default value to return.</param>
        /// <returns>
        /// The result of invoking the delegate stored for <paramref name="key"/>, or
        /// <paramref name="defaultValue"/> when the key has no stored value.
        /// </returns>
        public float GetWithDefault(string key, float defaultValue)
        {
            Func<float> valueOut;
            bool hasKey = m_Parameters.TryGetValue(key, out valueOut);
            return hasKey ? valueOut.Invoke() : defaultValue;
        }

        /// <summary>
        /// Registers a callback for the associated parameter key. Will overwrite any existing
        /// actions for this parameter key.
        /// </summary>
        /// <param name="key">The parameter key.</param>
        /// <param name="action">The callback.</param>
        public void RegisterCallback(string key, Action<float> action)
        {
            m_RegisteredActions[key] = action;
        }

        /// <summary>
        /// Returns all parameter keys that have a registered value.
        /// </summary>
        /// <returns>A new list containing the keys currently stored in the channel.</returns>
        public IList<string> ListParameters()
        {
            return new List<string>(m_Parameters.Keys);
        }
    }
}