浏览代码

added multirange uniform distr

/sampler-refactor-copy
Andrew Cohen 4 年前
当前提交
6a1dccad
共有 2 个文件被更改,包括 59 次插入4 次删除
  1. 61
      com.unity.ml-agents/Runtime/Sampler.cs
  2. 2
      com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs

61
com.unity.ml-agents/Runtime/Sampler.cs


using System;
using System.Linq;
using System.Collections.Generic;
using Unity.MLAgents;
using Unity.MLAgents.Inference.Utils;

/// <summary>
/// Samples a reset parameter from a Gaussian distribution.
/// </summary>
Gaussian = 1
Gaussian = 1,
/// <summary>
/// Samples a reset parameter from a Gaussian distribution.
/// </summary>
MultiRangeUniform = 2
}
/// <summary>

{
return CreateGaussianSampler(encoding[1], encoding[2], seed);
}
else if ((int)encoding[0] == (int)SamplerType.MultiRangeUniform)
{
return CreateMultiRangeUniformSampler(encoding, seed);
}
else{
Debug.LogWarning("EnvironmentParametersChannel received an unknown data type.");
return () => 0;

public Func<float> CreateUniformSampler(float min, float max, int seed)
internal Func<float> CreateUniformSampler(float min, float max, int seed)
public Func<float> CreateGaussianSampler(float mean, float stddev, int seed)
internal Func<float> CreateGaussianSampler(float mean, float stddev, int seed)
}
internal Func<float> CreateMultiRangeUniformSampler(IList<float> encoding, int seed)
{
//RNG
System.Random distr = new System.Random(seed);
// Skip type of distribution since already checked to get into this function
var sampler_encoding = encoding.Skip(1);
// Will be used to normalize intervals
float sum_interval_sizes = 0;
//The number of intervals
int num_intervals = (int)(sampler_encoding.Count()/2);
// List that will store interval lengths
float[] interval_sizes = new float[num_intervals];
// List that will store uniform distributions
IList<Func<float>> intervals = new Func<float>[num_intervals];
// Collect all intervals and store as uniform distrus
// Collect all interval sizes
for(int i = 0; i < num_intervals; i++)
{
var min = sampler_encoding.ElementAt(2 * i);
var max = sampler_encoding.ElementAt(2 * i + 1);
var interval_size = max - min;
sum_interval_sizes += interval_size;
interval_sizes[i] = interval_size;
intervals[i] = () => min + (float)distr.NextDouble() * interval_size;
}
// Normalize interval lengths
for(int i = 0; i < num_intervals; i++)
{
interval_sizes[i] = interval_sizes[i] / sum_interval_sizes;
}
// Build cmf for intervals
for(int i = 1; i < num_intervals; i++)
{
interval_sizes[i] += interval_sizes[i - 1];
}
Multinomial intervalDistr = new Multinomial(seed);
float MultiRange()
{
int sampledInterval = intervalDistr.Sample(interval_sizes);
return intervals[sampledInterval].Invoke();
}
return MultiRange;
}
}
}

2
com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs


{
Func<float> valueOut;
bool hasKey = m_Parameters.TryGetValue(key, out valueOut);
return hasKey ? valueOut() : defaultValue;
return hasKey ? valueOut.Invoke() : defaultValue;
}
/// <summary>

正在加载...
取消
保存