浏览代码

some cleanups/ interval error checking

/sampler-refactor-copy
Andrew Cohen 5 年前
当前提交
e386b829
共有 5 个文件被更改,包括 26 次插入44 次删除
  1. 5
      Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs
  2. 5
      com.unity.ml-agents/Runtime/EnvironmentParameters.cs
  3. 38
      com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs
  4. 4
      ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py
  5. 18
      ml-agents/mlagents/trainers/settings.py

5
Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs


public void SetBall()
{
//Set the attributes of the ball by fetching the information from the academy
//m_BallRb.mass = m_ResetParams.GetWithDefault("mass", 1.0f);
m_BallRb.mass = m_ResetParams.Sample("mass", 1.0f);
var scale = m_ResetParams.Sample("scale", 1.0f);
m_BallRb.mass = m_ResetParams.GetWithDefault("mass", 1.0f);
var scale = m_ResetParams.GetWithDefault("scale", 1.0f);
ball.transform.localScale = new Vector3(scale, scale, scale);
}

5
com.unity.ml-agents/Runtime/EnvironmentParameters.cs


return m_Channel.GetWithDefault(key, defaultValue);
}
public float Sample(string key, float defaultValue)
{
return m_Channel.Sample(key, defaultValue);
}
/// <summary>
/// Registers a callback action for the provided parameter key. Will overwrite any
/// existing action for that parameter. The callback will be called whenever the parameter

38
com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs


/// </summary>
internal class EnvironmentParametersChannel : SideChannel
{
Dictionary<string, float> m_Parameters = new Dictionary<string, float>();
Dictionary<string, Func<float>> m_Samplers = new Dictionary<string, Func<float>>();
Dictionary<string, Func<float>> m_Parameters = new Dictionary<string, Func<float>>();
Dictionary<string, Action<float>> m_RegisteredActions =
new Dictionary<string, Action<float>>();

{
var value = msg.ReadFloat32();
m_Parameters[key] = value;
m_Parameters[key] = () => value;
Action<float> action;
m_RegisteredActions.TryGetValue(key, out action);

{
var encoding = msg.ReadFloatList();
m_Samplers[key] = m_SamplerFactory.CreateSampler(encoding);
//var samplerType = msg.ReadFloat32();
//var statOne = msg.ReadFloat32();
//var statTwo = msg.ReadFloat32();
//m_Parameters[key+"-sampler-type"] = samplerType;
//m_Parameters[key+"-min"] = statOne;
//m_Parameters[key+"-max"] = statTwo;
m_Parameters[key] = m_SamplerFactory.CreateSampler(encoding);
}
else
{

/// <returns></returns>
public float GetWithDefault(string key, float defaultValue)
{
float valueOut;
bool hasKey = m_Parameters.TryGetValue(key, out valueOut);
return hasKey ? valueOut : defaultValue;
}
public float Sample(string key, float defaultValue)
{
bool hasKey = m_Samplers.TryGetValue(key, out valueOut);
bool hasKey = m_Parameters.TryGetValue(key, out valueOut);
/// <summary>
/// Returns the parameter value associated with the provided key. Returns the default
/// value if one doesn't exist.
/// </summary>
/// <param name="key">Parameter key.</param>
/// <param name="defaultValue">Default value to return.</param>
/// <returns></returns>
public float GetListWithDefault(string key, float defaultValue)
{
float valueOut;
bool hasKey = m_Parameters.TryGetValue(key, out valueOut);
return hasKey ? valueOut : defaultValue;
}
/// <summary>
/// Registers a callback for the associated parameter key. Will overwrite any existing
/// actions for this parameter key.

4
ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py


def set_sampler_parameters(self, key: str, values: List[float]) -> None:
"""
Sets a float environment parameter in the Unity Environment.
Sets a float encoding of an environment parameter sampelr.
:param value: The float value of the parameter.
:param value: The float encoding of the sampler.
"""
msg = OutgoingMessage()
msg.write_string(key)

18
ml-agents/mlagents/trainers/settings.py


max_value: float = 1.0
def to_float(self) -> List[float]:
if self.min_value > self.max_value:
raise TrainerConfigError(
"Minimum value is greater than maximum value in uniform sampler."
)
return [0.0, self.min_value, self.max_value]

intervals: List[List[float]] = [[1.0, 1.0]]
def to_float(self) -> List[float]:
return [2.0] + [val for interval in self.intervals for val in interval]
floats: List[float] = []
for interval in self.intervals:
if len(interval) != 2:
raise TrainerConfigError(
f"The sampling interval {interval} must contain exactly two values."
)
[min_value, max_value] = interval
if min_value > max_value:
raise TrainerConfigError(
f"Minimum value is greater than maximum value in interval {interval}."
)
floats += interval
return [2.0] + floats
@attr.s(auto_attribs=True)

正在加载...
取消
保存