5 年前
共有 22 个文件被更改,包括 672 次插入 和 454 次删除
using System; |
using System.Collections.Generic; |
using Unity.MLAgents.Inference.Utils; |
using UnityEngine; |
using Random=System.Random; |
namespace Unity.MLAgents
{
    /// <summary>
    /// Builds seeded sampling functions: each factory method takes the parameters of a
    /// distribution and returns a delegate that produces one float per invocation.
    /// </summary>
    internal static class SamplerFactory
    {
        /// <summary>
        /// Creates a sampler that scales a <see cref="System.Random"/> draw onto the range
        /// that starts at <paramref name="min"/> and has width max - min.
        /// </summary>
        /// <param name="min">Lower bound of the range.</param>
        /// <param name="max">Upper bound of the range.</param>
        /// <param name="seed">Seed for the underlying random number generator.</param>
        /// <returns>A function that returns a new sample each time it is called.</returns>
        public static Func<float> CreateUniformSampler(float min, float max, int seed)
        {
            Random distr = new Random(seed);
            return () => min + (float)distr.NextDouble() * (max - min);
        }

        /// <summary>
        /// Creates a sampler backed by a <c>RandomNormal</c> constructed from the given
        /// seed, mean and standard deviation.
        /// </summary>
        /// <param name="mean">Mean passed to the underlying generator.</param>
        /// <param name="stddev">Standard deviation passed to the underlying generator.</param>
        /// <param name="seed">Seed passed to the underlying generator.</param>
        /// <returns>A function that returns a new sample each time it is called.</returns>
        public static Func<float> CreateGaussianSampler(float mean, float stddev, int seed)
        {
            RandomNormal distr = new RandomNormal(seed, mean, stddev);
            return () => (float)distr.NextDouble();
        }

        /// <summary>
        /// Creates a sampler over several ranges. Each call first picks a range (weighted by
        /// the range's share of the total length) and then draws a value inside that range.
        /// </summary>
        /// <param name="intervals">
        /// Flat list of bounds laid out as min0, max0, min1, max1, ... A trailing unpaired
        /// value is ignored.
        /// </param>
        /// <param name="seed">
        /// Seed for the in-range draws; the range selection uses <c>seed + 1</c>.
        /// </param>
        /// <returns>A function that returns a new sample each time it is called.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="intervals"/> is null.</exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="intervals"/> holds fewer than two values, i.e. no complete range.
        /// </exception>
        public static Func<float> CreateMultiRangeUniformSampler(IList<float> intervals, int seed)
        {
            if (intervals == null)
            {
                throw new ArgumentNullException(nameof(intervals));
            }

            // Integer division: an odd trailing element does not form a range.
            int numIntervals = intervals.Count / 2;
            if (numIntervals == 0)
            {
                // Fail here with a clear message; otherwise the returned sampler could only
                // fail later, when it indexes into the empty table of ranges.
                throw new ArgumentException(
                    "At least one (min, max) pair is required.", nameof(intervals));
            }

            Random distr = new Random(seed);
            // Total length of all ranges, used to normalize the individual lengths.
            float sumIntervalSizes = 0;
            // Holds range lengths first, then (after the loops below) the cumulative weights.
            float[] intervalSizes = new float[numIntervals];
            // One uniform sampler per range; all of them share the same generator.
            var intervalFuncs = new Func<float>[numIntervals];

            for (int i = 0; i < numIntervals; i++)
            {
                // Declared inside the loop so every closure captures its own bounds.
                var min = intervals[2 * i];
                var intervalSize = intervals[2 * i + 1] - min;
                sumIntervalSizes += intervalSize;
                intervalSizes[i] = intervalSize;
                intervalFuncs[i] = () => min + (float)distr.NextDouble() * intervalSize;
            }

            // Normalize the lengths so that they sum to one.
            // NOTE(review): when every range has zero length this divides by zero; how
            // Multinomial.Sample treats the resulting values is not visible here - confirm.
            for (int i = 0; i < numIntervals; i++)
            {
                intervalSizes[i] = intervalSizes[i] / sumIntervalSizes;
            }

            // Turn the normalized lengths into a running (cumulative) sum.
            for (int i = 1; i < numIntervals; i++)
            {
                intervalSizes[i] += intervalSizes[i - 1];
            }

            Multinomial intervalDistr = new Multinomial(seed + 1);
            // The range is selected first, then the value is drawn from that range.
            return () => intervalFuncs[intervalDistr.Sample(intervalSizes)].Invoke();
        }
    }
}
fileFormatVersion: 2 |
guid: 39ce0ea5a8b2e47f696f6efc807029f6 |
MonoImporter: |
externalObjects: {} |
serializedVersion: 2 |
defaultReferences: [] |
executionOrder: 0 |
icon: {instanceID: 0} |
userData: |
assetBundleName: |
assetBundleVariant: |
using System; |
using NUnit.Framework; |
using System.IO; |
using System.Collections.Generic; |
using UnityEngine; |
using Unity.MLAgents.SideChannels; |
namespace Unity.MLAgents.Tests
{
    /// <summary>
    /// Sends sampler definitions through the <see cref="EnvironmentParametersChannel"/> and
    /// checks the first two values that each seeded sampler produces.
    /// </summary>
    public class SamplerTests
    {
        const int k_Seed = 1337;
        const double k_Epsilon = 0.0001;
        EnvironmentParametersChannel m_Channel;

        public SamplerTests()
        {
            m_Channel = SideChannelsManager.GetSideChannel<EnvironmentParametersChannel>();
            if (m_Channel != null)
            {
                return;
            }

            // No channel has been registered yet when this fixture runs on its own.
            m_Channel = new EnvironmentParametersChannel();
            SideChannelsManager.RegisterSideChannel(m_Channel);
        }

        /// <summary>
        /// Writes the common sampler header (parameter name, sampler marker, seed, sampler
        /// type), lets the caller append the sampler-specific values, then hands the framed
        /// message to the side channel manager.
        /// </summary>
        void SendSampler(string parameter, SamplerType samplerType, Action<OutgoingMessage> writeArguments)
        {
            using (var outgoingMsg = new OutgoingMessage())
            {
                outgoingMsg.WriteString(parameter);
                // 1 indicates this message is a Sampler
                outgoingMsg.WriteInt32(1);
                outgoingMsg.WriteInt32(k_Seed);
                outgoingMsg.WriteInt32((int)samplerType);
                writeArguments(outgoingMsg);

                var payload = GetByteMessage(m_Channel, outgoingMsg);
                SideChannelsManager.ProcessSideChannelData(payload);
            }
        }

        [Test]
        public void UniformSamplerTest()
        {
            var lowerBound = 1.0f;
            var upperBound = 2.0f;
            var parameter = "parameter1";

            SendSampler(parameter, SamplerType.Uniform, msg =>
            {
                msg.WriteFloat32(lowerBound);
                msg.WriteFloat32(upperBound);
            });

            Assert.AreEqual(1.208888f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
            Assert.AreEqual(1.118017f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
        }

        [Test]
        public void GaussianSamplerTest()
        {
            var mean = 3.0f;
            var stddev = 0.2f;
            var parameter = "parameter2";

            SendSampler(parameter, SamplerType.Gaussian, msg =>
            {
                msg.WriteFloat32(mean);
                msg.WriteFloat32(stddev);
            });

            Assert.AreEqual(2.936162f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
            Assert.AreEqual(2.951348f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
        }

        [Test]
        public void MultiRangeUniformSamplerTest()
        {
            // Two ranges, flattened as min0, max0, min1, max1.
            var intervals = new float[] { 1.2f, 2f, 3.2f, 4.1f };
            var parameter = "parameter3";

            SendSampler(parameter, SamplerType.MultiRangeUniform, msg => msg.WriteFloatList(intervals));

            Assert.AreEqual(3.387999f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
            Assert.AreEqual(1.294413f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
        }

        /// <summary>
        /// Frames a message for <c>SideChannelsManager.ProcessSideChannelData</c>: the channel
        /// id bytes, then the payload length, then the payload itself.
        /// </summary>
        /// <param name="sideChannel">Channel whose id prefixes the frame.</param>
        /// <param name="msg">Message whose bytes form the payload.</param>
        /// <returns>The framed bytes.</returns>
        internal static byte[] GetByteMessage(SideChannel sideChannel, OutgoingMessage msg)
        {
            var payload = msg.ToByteArray();
            using (var buffer = new MemoryStream())
            {
                using (var writer = new BinaryWriter(buffer))
                {
                    writer.Write(sideChannel.ChannelId.ToByteArray());
                    writer.Write(payload.Length);
                    writer.Write(payload);
                }
                // Read the buffer only after the writer has been disposed (and so flushed).
                return buffer.ToArray();
            }
        }
    }
}
fileFormatVersion: 2 |
guid: 7e6609c51018d4132beda8ddedd46d91 |
MonoImporter: |
externalObjects: {} |
serializedVersion: 2 |
defaultReferences: [] |
executionOrder: 0 |
icon: {instanceID: 0} |
userData: |
assetBundleName: |
assetBundleVariant: |
import pytest |
from mlagents.trainers.sampler_class import SamplerManager |
from mlagents.trainers.sampler_class import ( |
UniformSampler, |
MultiRangeUniformSampler, |
GaussianSampler, |
) |
from mlagents.trainers.exception import TrainerError |
def sampler_config_1():
    """Return a config with a uniform "mass" sampler and a multirange-uniform "gravity" sampler."""
    mass_spec = {"sampler-type": "uniform"}
    mass_spec.update(min_value=5, max_value=10)

    gravity_spec = {"sampler-type": "multirange_uniform"}
    gravity_spec["intervals"] = [[8, 11], [15, 20]]

    return {"mass": mass_spec, "gravity": gravity_spec}
def check_value_in_intervals(val, intervals):
    """Return True when val lies inside at least one closed interval (low, high) of intervals."""
    matches = 0
    # Every interval is inspected (no early exit), mirroring a fully built list of checks.
    for low, high in intervals:
        if low <= val <= high:
            matches += 1
    return matches > 0
def test_sampler_config_1():
    """SamplerManager builds the samplers named in the config and draws values inside their ranges."""
    config = sampler_config_1()
    manager = SamplerManager(config)
    assert manager.is_empty() is False
    assert isinstance(manager.samplers["mass"], UniformSampler)
    assert isinstance(manager.samplers["gravity"], MultiRangeUniformSampler)

    drawn = manager.sample_all()

    # Uniform sampler for mass: bounds are stored and the draw falls between them.
    mass_cfg = config["mass"]
    mass_sampler = manager.samplers["mass"]
    assert mass_sampler.min_value == mass_cfg["min_value"]
    assert mass_sampler.max_value == mass_cfg["max_value"]
    assert mass_cfg["min_value"] <= drawn["mass"]
    assert mass_cfg["max_value"] >= drawn["mass"]

    # Multirange-uniform sampler for gravity: intervals are stored and the draw is in one of them.
    gravity_sampler = manager.samplers["gravity"]
    assert gravity_sampler.intervals == config["gravity"]["intervals"]
    assert check_value_in_intervals(drawn["gravity"], gravity_sampler.intervals)
def sampler_config_2():
    """Return a config with a single gaussian "angle" sampler (mean 0, st_dev 1)."""
    angle_spec = {"sampler-type": "gaussian"}
    angle_spec.update(mean=0, st_dev=1)
    return {"angle": angle_spec}
def test_sampler_config_2():
    """SamplerManager builds a GaussianSampler that keeps the configured mean and st_dev."""
    config = sampler_config_2()
    manager = SamplerManager(config)
    assert manager.is_empty() is False

    angle_sampler = manager.samplers["angle"]
    assert isinstance(angle_sampler, GaussianSampler)

    # The gaussian parameters must be carried over unchanged from the config.
    angle_cfg = config["angle"]
    assert manager.samplers["angle"].mean == angle_cfg["mean"]
    assert manager.samplers["angle"].st_dev == angle_cfg["st_dev"]
def test_empty_samplers():
    """An empty dict and None both yield an empty manager whose sample_all returns {}."""
    for empty_config in ({}, None):
        manager = SamplerManager(empty_config)
        assert manager.is_empty()
        drawn = manager.sample_all()
        assert drawn == {}
def incorrect_uniform_sampler():
    """Return a uniform-sampler config that lacks the arguments the sampler requires."""
    # Only a hyphenated "min-value" key is supplied next to the sampler type.
    broken_mass = {"sampler-type": "uniform"}
    broken_mass["min-value"] = 10
    return {"mass": broken_mass}
def incorrect_sampler_config():
    """Return a config whose "mass" entry has no 'sampler-type' key."""
    untyped_mass = {}
    untyped_mass["min-value"] = 2
    untyped_mass["max-value"] = 30
    return {"mass": untyped_mass}
def test_incorrect_uniform_sampler():
    """Building a manager from a uniform config without its required arguments raises TrainerError."""
    bad_config = incorrect_uniform_sampler()
    with pytest.raises(TrainerError):
        SamplerManager(bad_config)
def test_incorrect_sampler():
    """Building a manager from a config entry without 'sampler-type' raises TrainerError."""
    bad_config = incorrect_sampler_config()
    with pytest.raises(TrainerError):
        SamplerManager(bad_config)
import numpy as np |
from typing import Union, Optional, Type, List, Dict, Any |
from abc import ABC, abstractmethod |
from mlagents.trainers.exception import SamplerException |
class Sampler(ABC):
    """Abstract base for samplers; concrete subclasses must implement sample_parameter."""

    @abstractmethod
    def sample_parameter(self) -> float:
        """Produce one sampled value. This base implementation does nothing."""
class UniformSampler(Sampler):
    """
    Sampler that draws single values uniformly from the range [min_value, max_value).
    """

    def __init__(
        self,
        min_value: Union[int, float],
        max_value: Union[int, float],
        seed: Optional[int] = None,
    ):
        """
        :param min_value: lower end of the range to sample from
        :param max_value: upper end of the range to sample from
        :param seed: seed for the sampler's private random state
        """
        # A dedicated RandomState makes the sequence of draws reproducible for a given seed.
        self.random_state = np.random.RandomState(seed)
        self.min_value, self.max_value = min_value, max_value

    def sample_parameter(self) -> float:
        """
        Return one uniform draw between min_value and max_value.
        """
        low, high = self.min_value, self.max_value
        return self.random_state.uniform(low=low, high=high)
class MultiRangeUniformSampler(Sampler):
    """
    Sampler over several intervals. Each draw first selects one interval, with
    selection probability proportional to the interval's length, and then picks
    a value uniformly inside the selected interval.
    """

    def __init__(
        self, intervals: List[List[Union[int, float]]], seed: Optional[int] = None
    ):
        """
        :param intervals: list of [low, high] pairs to sample from
        :param seed: seed for the sampler's private random state
        """
        self.intervals = intervals

        # Length of each interval; abs() keeps the length positive for reversed bounds.
        lengths = []
        for bounds in self.intervals:
            lengths.append(abs(bounds[1] - bounds[0]))
        total_length = sum(lengths)

        # Weight of an interval is its share of the total length.
        self.interval_weights = [length / total_length for length in lengths]

        # A dedicated RandomState makes the sequence of draws reproducible for a given seed.
        self.random_state = np.random.RandomState(seed)

    def sample_parameter(self) -> float:
        """
        Pick an interval according to the weights, then return a uniform draw from it.
        """
        chosen = self.random_state.choice(
            len(self.intervals), p=self.interval_weights
        )
        low, high = self.intervals[chosen]
        return self.random_state.uniform(low, high)
class GaussianSampler(Sampler):
    """
    Sampler that draws single values from a normal (gaussian) distribution
    described by its mean and standard deviation.
    """

    def __init__(
        self,
        mean: Union[float, int],
        st_dev: Union[float, int],
        seed: Optional[int] = None,
    ):
        """
        :param mean: mean of the gaussian distribution to draw from
        :param st_dev: standard deviation of the gaussian distribution to draw from
        :param seed: seed for the sampler's private random state
        """
        # A dedicated RandomState makes the sequence of draws reproducible for a given seed.
        self.random_state = np.random.RandomState(seed)
        self.mean, self.st_dev = mean, st_dev

    def sample_parameter(self) -> float:
        """
        Return one draw from the configured gaussian distribution.
        """
        return self.random_state.normal(loc=self.mean, scale=self.st_dev)
class SamplerFactory:
    """
    Maintain a directory of all samplers available.
    Add new samplers using the register_sampler method.
    """

    # Maps the sampler name used in a config to the class that implements it.
    NAME_TO_CLASS: Dict[str, Type[Sampler]] = {
        "uniform": UniformSampler,
        "gaussian": GaussianSampler,
        "multirange_uniform": MultiRangeUniformSampler,
    }

    @staticmethod
    def register_sampler(name: str, sampler_cls: Type[Sampler]) -> None:
        """
        Registers the sampler in the Sampler Factory to be used later
        :param name: String name to set as key for the sampler_cls in the factory
        :param sampler_cls: Sampler class to associate to the name in the factory
        """
        SamplerFactory.NAME_TO_CLASS[name] = sampler_cls

    @staticmethod
    def init_sampler_class(
        name: str, params: Dict[str, Any], seed: Optional[int] = None
    ) -> Sampler:
        """
        Initializes the sampler class associated with the name with the params
        :param name: Name of the sampler in the factory to initialize
        :param params: Parameters associated to the sampler attached to the name;
            the dict is not modified
        :param seed: Random seed to be used to set deterministic random draws for the sampler
        :return: An instance of the sampler class registered under name
        :raises SamplerException: if name is not registered, or if the sampler class
            rejects the supplied parameters
        """
        if name not in SamplerFactory.NAME_TO_CLASS:
            raise SamplerException(
                name + " sampler is not registered in the SamplerFactory."
                " Use the register_sampler method to register the string"
                " associated to your sampler in the SamplerFactory."
            )
        sampler_cls = SamplerFactory.NAME_TO_CLASS[name]
        # Work on a copy so the caller's config dict does not silently gain a "seed" key.
        sampler_kwargs = dict(params)
        sampler_kwargs["seed"] = seed
        try:
            return sampler_cls(**sampler_kwargs)
        except TypeError as err:
            # A TypeError here means missing or unexpected constructor arguments.
            raise SamplerException(
                "The sampler class associated to the " + name + " key in the factory "
                "was not provided the required arguments. Please ensure that the sampler "
                "config file consists of the appropriate keys for this sampler class."
            ) from err
class SamplerManager: |
def __init__( |
self, reset_param_dict: Dict[str, Any], seed: Optional[int] = None |
) -> None: |
""" |