浏览代码

seed each sampler individually

/sampler-refactor-copy
Andrew Cohen 4 年前
当前提交
7d52b18f
共有 4 个文件被更改,包括 26 次插入和 20 次删除
  1. 21
      com.unity.ml-agents/Runtime/Sampler.cs
  2. 9
      com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs
  3. 8
      ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py
  4. 8
      ml-agents/mlagents/trainers/settings.py

21
com.unity.ml-agents/Runtime/Sampler.cs


/// </summary>
internal sealed class SamplerFactory
{
int m_Seed;
internal SamplerFactory(int seed)
internal SamplerFactory()
m_Seed = seed;
}
/// <summary>

public Func<float> CreateSampler(IList<float> encoding)
public Func<float> CreateSampler(IList<float> encoding, int seed)
return CreateUniformSampler(encoding[1], encoding[2]);
return CreateUniformSampler(encoding[1], encoding[2], seed);
return CreateGaussianSampler(encoding[1], encoding[2]);
return CreateGaussianSampler(encoding[1], encoding[2], seed);
}
else{
Debug.LogWarning("EnvironmentParametersChannel received an unknown data type.");

}
public Func<float> CreateUniformSampler(float min, float max)
public Func<float> CreateUniformSampler(float min, float max, int seed)
return () => UnityEngine.Random.Range(min, max);
System.Random distr = new System.Random(seed);
return () => min + (float)distr.NextDouble() * (max - min);
public Func<float> CreateGaussianSampler(float mean, float stddev)
public Func<float> CreateGaussianSampler(float mean, float stddev, int seed)
RandomNormal distr = new RandomNormal(m_Seed, mean, stddev);
RandomNormal distr = new RandomNormal(seed, mean, stddev);
return () => (float)distr.NextDouble();
}
}

9
com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs


Dictionary<string, Action<float>> m_RegisteredActions =
new Dictionary<string, Action<float>>();
SamplerFactory m_SamplerFactory = new SamplerFactory(1);
SamplerFactory m_SamplerFactory = new SamplerFactory();
const string k_EnvParamsId = "534c891e-810f-11ea-a9d0-822485860400";

}
else if ((int)EnvironmentDataTypes.Sampler == type)
{
int seed = msg.ReadInt32();
if (seed == -1)
{
seed = UnityEngine.Random.Range(0, 10000);
}
m_Parameters[key] = m_SamplerFactory.CreateSampler(encoding);
m_Parameters[key] = m_SamplerFactory.CreateSampler(encoding, seed);
}
else
{

8
ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py


msg = OutgoingMessage()
msg.write_string(key)
msg.write_int32(self.EnvironmentDataTypes.SAMPLER)
# length of list
msg.write_int32(len(values))
for value in values:
# Write seed
msg.write_int32(int(values[0]))
msg.write_int32(len(values[1:]))
# Sampler encoding
for value in values[1:]:
msg.write_float32(value)
super().queue_message_to_send(msg)

8
ml-agents/mlagents/trainers/settings.py


@attr.s(auto_attribs=True)
class ParameterRandomizationSettings(abc.ABC):
seed: int = parser.get_default("seed")
@staticmethod
def structure(d: Mapping, t: type) -> Any:
"""

def to_float_encoding(self) -> List[float]:
"Returns the sampler type followed by the min and max values"
return [0.0, self.min_value, self.max_value]
return [self.seed, 0.0, self.min_value, self.max_value]
@attr.s(auto_attribs=True)

def to_float_encoding(self) -> List[float]:
"Returns the sampler type followed by the mean and standard deviation"
return [1.0, self.mean, self.st_dev]
return [self.seed, 1.0, self.mean, self.st_dev]
@attr.s(auto_attribs=True)

floats: List[float] = []
for interval in self.intervals:
floats += interval
return [2.0] + floats
return [self.seed, 2.0] + floats
@attr.s(auto_attribs=True)

正在加载...
取消
保存