浏览代码

using validator to check settings

/sampler-refactor-copy
Andrew Cohen 4 年前
当前提交
95898f37
共有 2 个文件被更改,包括 25 次插入11 次删除
  1. 7
      com.unity.ml-agents/Runtime/Sampler.cs
  2. 29
      ml-agents/mlagents/trainers/settings.py

7
com.unity.ml-agents/Runtime/Sampler.cs


using Unity.MLAgents;
using Unity.MLAgents.Inference.Utils;
using UnityEngine;
using Random=UnityEngine.Random;
namespace Unity.MLAgents
{

public enum SamplerType
internal enum SamplerType
{
/// <summary>
/// Samples a reset parameter from a uniform distribution.

/// <summary>
/// Takes a list of floats that encode a sampling distribution and returns the sampling function.
/// </summary>
public sealed class SamplerFactory
internal sealed class SamplerFactory
{
int m_Seed;

public Func<float> CreateUniformSampler(float min, float max)
{
return () => Random.Range(min, max);
return () => UnityEngine.Random.Range(min, max);
}
public Func<float> CreateGaussianSampler(float mean, float stddev)

29
ml-agents/mlagents/trainers/settings.py


for key, val in param_config.items():
enum_key = ParameterRandomizationType(key)
t = enum_key.to_settings()
d_final[param] = strict_to_cls(val, t).to_float()
d_final[param] = strict_to_cls(val, t).to_float_encoding()
min_value: float = 1.0
min_value: float = attr.ib()
def to_float(self) -> List[float]:
@min_value.default
def _min_value_default(self):
return 1.0
@min_value.validator
def _check_intervals(self, attribute, value):
def to_float_encoding(self) -> List[float]:
return [0.0, self.min_value, self.max_value]

st_dev: float = 1.0
def to_float(self) -> List[float]:
def to_float_encoding(self) -> List[float]:
intervals: List[List[float]] = [[1.0, 1.0]]
intervals: List[List[float]] = attr.ib()
def to_float(self) -> List[float]:
floats: List[float] = []
@intervals.default
def _intervals_default(self):
return [[1.0, 1.0]]
@intervals.validator
def _check_intervals(self, attribute, value):
for interval in self.intervals:
if len(interval) != 2:
raise TrainerConfigError(

raise TrainerConfigError(
f"Minimum value is greater than maximum value in interval {interval}."
)
def to_float_encoding(self) -> List[float]:
floats: List[float] = []
for interval in self.intervals:
floats += interval
return [2.0] + floats

正在加载...
取消
保存