GitHub
5 年前
当前提交
5b0a5b9b
共有 22 个文件被更改,包括 672 次插入 和 454 次删除
-
2com.unity.ml-agents/CHANGELOG.md
-
61com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs
-
19config/ppo/3DBall_randomize.yaml
-
104docs/Training-ML-Agents.md
-
63ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py
-
41ml-agents/mlagents/trainers/learn.py
-
158ml-agents/mlagents/trainers/settings.py
-
6ml-agents/mlagents/trainers/simple_env_manager.py
-
6ml-agents/mlagents/trainers/subprocess_env_manager.py
-
15ml-agents/mlagents/trainers/tests/test_config_conversion.py
-
19ml-agents/mlagents/trainers/tests/test_learn.py
-
86ml-agents/mlagents/trainers/tests/test_settings.py
-
3ml-agents/mlagents/trainers/tests/test_simple_rl.py
-
5ml-agents/mlagents/trainers/tests/test_trainer_controller.py
-
28ml-agents/mlagents/trainers/trainer_controller.py
-
20ml-agents/mlagents/trainers/upgrade_config.py
-
70com.unity.ml-agents/Runtime/Sampler.cs
-
11com.unity.ml-agents/Runtime/Sampler.cs.meta
-
109com.unity.ml-agents/Tests/Editor/SamplerTests.cs
-
11com.unity.ml-agents/Tests/Editor/SamplerTests.cs.meta
-
96ml-agents/mlagents/trainers/tests/test_sampler_class.py
-
193ml-agents/mlagents/trainers/sampler_class.py
|
|||
using System; |
|||
using System.Collections.Generic; |
|||
using Unity.MLAgents.Inference.Utils; |
|||
using UnityEngine; |
|||
using Random=System.Random; |
|||
|
|||
namespace Unity.MLAgents |
|||
{ |
|||
|
|||
/// <summary>
|
|||
/// Takes a list of floats that encode a sampling distribution and returns the sampling function.
|
|||
/// </summary>
|
|||
internal static class SamplerFactory |
|||
{ |
|||
|
|||
public static Func<float> CreateUniformSampler(float min, float max, int seed) |
|||
{ |
|||
Random distr = new Random(seed); |
|||
return () => min + (float)distr.NextDouble() * (max - min); |
|||
} |
|||
|
|||
public static Func<float> CreateGaussianSampler(float mean, float stddev, int seed) |
|||
{ |
|||
RandomNormal distr = new RandomNormal(seed, mean, stddev); |
|||
return () => (float)distr.NextDouble(); |
|||
} |
|||
|
|||
public static Func<float> CreateMultiRangeUniformSampler(IList<float> intervals, int seed) |
|||
{ |
|||
//RNG
|
|||
Random distr = new Random(seed); |
|||
// Will be used to normalize intervalFuncs
|
|||
float sumIntervalSizes = 0; |
|||
//The number of intervals
|
|||
int numIntervals = (int)(intervals.Count/2); |
|||
// List that will store interval lengths
|
|||
float[] intervalSizes = new float[numIntervals]; |
|||
// List that will store uniform distributions
|
|||
IList<Func<float>> intervalFuncs = new Func<float>[numIntervals]; |
|||
// Collect all intervals and store as uniform distrus
|
|||
// Collect all interval sizes
|
|||
for(int i = 0; i < numIntervals; i++) |
|||
{ |
|||
var min = intervals[2 * i]; |
|||
var max = intervals[2 * i + 1]; |
|||
var intervalSize = max - min; |
|||
sumIntervalSizes += intervalSize; |
|||
intervalSizes[i] = intervalSize; |
|||
intervalFuncs[i] = () => min + (float)distr.NextDouble() * intervalSize; |
|||
} |
|||
// Normalize interval lengths
|
|||
for(int i = 0; i < numIntervals; i++) |
|||
{ |
|||
intervalSizes[i] = intervalSizes[i] / sumIntervalSizes; |
|||
} |
|||
// Build cmf for intervals
|
|||
for(int i = 1; i < numIntervals; i++) |
|||
{ |
|||
intervalSizes[i] += intervalSizes[i - 1]; |
|||
} |
|||
Multinomial intervalDistr = new Multinomial(seed + 1); |
|||
float MultiRange() |
|||
{ |
|||
int sampledInterval = intervalDistr.Sample(intervalSizes); |
|||
return intervalFuncs[sampledInterval].Invoke(); |
|||
} |
|||
return MultiRange; |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 39ce0ea5a8b2e47f696f6efc807029f6 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using System; |
|||
using NUnit.Framework; |
|||
using System.IO; |
|||
using System.Collections.Generic; |
|||
using UnityEngine; |
|||
using Unity.MLAgents.SideChannels; |
|||
|
|||
namespace Unity.MLAgents.Tests |
|||
{ |
|||
public class SamplerTests |
|||
{ |
|||
const int k_Seed = 1337; |
|||
const double k_Epsilon = 0.0001; |
|||
EnvironmentParametersChannel m_Channel; |
|||
|
|||
public SamplerTests() |
|||
{ |
|||
m_Channel = SideChannelsManager.GetSideChannel<EnvironmentParametersChannel>(); |
|||
// if running test on its own
|
|||
if (m_Channel == null) |
|||
{ |
|||
m_Channel = new EnvironmentParametersChannel(); |
|||
SideChannelsManager.RegisterSideChannel(m_Channel); |
|||
} |
|||
} |
|||
[Test] |
|||
public void UniformSamplerTest() |
|||
{ |
|||
float min_value = 1.0f; |
|||
float max_value = 2.0f; |
|||
string parameter = "parameter1"; |
|||
using (var outgoingMsg = new OutgoingMessage()) |
|||
{ |
|||
outgoingMsg.WriteString(parameter); |
|||
// 1 indicates this meessage is a Sampler
|
|||
outgoingMsg.WriteInt32(1); |
|||
outgoingMsg.WriteInt32(k_Seed); |
|||
outgoingMsg.WriteInt32((int)SamplerType.Uniform); |
|||
outgoingMsg.WriteFloat32(min_value); |
|||
outgoingMsg.WriteFloat32(max_value); |
|||
byte[] message = GetByteMessage(m_Channel, outgoingMsg); |
|||
SideChannelsManager.ProcessSideChannelData(message); |
|||
} |
|||
Assert.AreEqual(1.208888f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon); |
|||
Assert.AreEqual(1.118017f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon); |
|||
} |
|||
|
|||
[Test] |
|||
public void GaussianSamplerTest() |
|||
{ |
|||
float mean = 3.0f; |
|||
float stddev = 0.2f; |
|||
string parameter = "parameter2"; |
|||
using (var outgoingMsg = new OutgoingMessage()) |
|||
{ |
|||
outgoingMsg.WriteString(parameter); |
|||
// 1 indicates this meessage is a Sampler
|
|||
outgoingMsg.WriteInt32(1); |
|||
outgoingMsg.WriteInt32(k_Seed); |
|||
outgoingMsg.WriteInt32((int)SamplerType.Gaussian); |
|||
outgoingMsg.WriteFloat32(mean); |
|||
outgoingMsg.WriteFloat32(stddev); |
|||
byte[] message = GetByteMessage(m_Channel, outgoingMsg); |
|||
SideChannelsManager.ProcessSideChannelData(message); |
|||
} |
|||
Assert.AreEqual(2.936162f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon); |
|||
Assert.AreEqual(2.951348f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon); |
|||
} |
|||
|
|||
[Test] |
|||
public void MultiRangeUniformSamplerTest() |
|||
{ |
|||
float[] intervals = new float[4]; |
|||
intervals[0] = 1.2f; |
|||
intervals[1] = 2f; |
|||
intervals[2] = 3.2f; |
|||
intervals[3] = 4.1f; |
|||
string parameter = "parameter3"; |
|||
using (var outgoingMsg = new OutgoingMessage()) |
|||
{ |
|||
outgoingMsg.WriteString(parameter); |
|||
// 1 indicates this meessage is a Sampler
|
|||
outgoingMsg.WriteInt32(1); |
|||
outgoingMsg.WriteInt32(k_Seed); |
|||
outgoingMsg.WriteInt32((int)SamplerType.MultiRangeUniform); |
|||
outgoingMsg.WriteFloatList(intervals); |
|||
byte[] message = GetByteMessage(m_Channel, outgoingMsg); |
|||
SideChannelsManager.ProcessSideChannelData(message); |
|||
} |
|||
Assert.AreEqual(3.387999f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon); |
|||
Assert.AreEqual(1.294413f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon); |
|||
} |
|||
|
|||
internal static byte[] GetByteMessage(SideChannel sideChannel, OutgoingMessage msg) |
|||
{ |
|||
byte[] message = msg.ToByteArray(); |
|||
using (var memStream = new MemoryStream()) |
|||
{ |
|||
using (var binaryWriter = new BinaryWriter(memStream)) |
|||
{ |
|||
binaryWriter.Write(sideChannel.ChannelId.ToByteArray()); |
|||
binaryWriter.Write(message.Length); |
|||
binaryWriter.Write(message); |
|||
} |
|||
return memStream.ToArray(); |
|||
} |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 7e6609c51018d4132beda8ddedd46d91 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
import pytest |
|||
|
|||
from mlagents.trainers.sampler_class import SamplerManager |
|||
from mlagents.trainers.sampler_class import ( |
|||
UniformSampler, |
|||
MultiRangeUniformSampler, |
|||
GaussianSampler, |
|||
) |
|||
from mlagents.trainers.exception import TrainerError |
|||
|
|||
|
|||
def sampler_config_1(): |
|||
return { |
|||
"mass": {"sampler-type": "uniform", "min_value": 5, "max_value": 10}, |
|||
"gravity": { |
|||
"sampler-type": "multirange_uniform", |
|||
"intervals": [[8, 11], [15, 20]], |
|||
}, |
|||
} |
|||
|
|||
|
|||
def check_value_in_intervals(val, intervals): |
|||
check_in_bounds = [a <= val <= b for a, b in intervals] |
|||
return any(check_in_bounds) |
|||
|
|||
|
|||
def test_sampler_config_1(): |
|||
config = sampler_config_1() |
|||
sampler = SamplerManager(config) |
|||
|
|||
assert sampler.is_empty() is False |
|||
assert isinstance(sampler.samplers["mass"], UniformSampler) |
|||
assert isinstance(sampler.samplers["gravity"], MultiRangeUniformSampler) |
|||
|
|||
cur_sample = sampler.sample_all() |
|||
|
|||
# Check uniform sampler for mass |
|||
assert sampler.samplers["mass"].min_value == config["mass"]["min_value"] |
|||
assert sampler.samplers["mass"].max_value == config["mass"]["max_value"] |
|||
assert config["mass"]["min_value"] <= cur_sample["mass"] |
|||
assert config["mass"]["max_value"] >= cur_sample["mass"] |
|||
|
|||
# Check multirange_uniform sampler for gravity |
|||
assert sampler.samplers["gravity"].intervals == config["gravity"]["intervals"] |
|||
assert check_value_in_intervals( |
|||
cur_sample["gravity"], sampler.samplers["gravity"].intervals |
|||
) |
|||
|
|||
|
|||
def sampler_config_2(): |
|||
return {"angle": {"sampler-type": "gaussian", "mean": 0, "st_dev": 1}} |
|||
|
|||
|
|||
def test_sampler_config_2(): |
|||
config = sampler_config_2() |
|||
sampler = SamplerManager(config) |
|||
assert sampler.is_empty() is False |
|||
assert isinstance(sampler.samplers["angle"], GaussianSampler) |
|||
|
|||
# Check angle gaussian sampler |
|||
assert sampler.samplers["angle"].mean == config["angle"]["mean"] |
|||
assert sampler.samplers["angle"].st_dev == config["angle"]["st_dev"] |
|||
|
|||
|
|||
def test_empty_samplers(): |
|||
empty_sampler = SamplerManager({}) |
|||
assert empty_sampler.is_empty() |
|||
empty_cur_sample = empty_sampler.sample_all() |
|||
assert empty_cur_sample == {} |
|||
|
|||
none_sampler = SamplerManager(None) |
|||
assert none_sampler.is_empty() |
|||
none_cur_sample = none_sampler.sample_all() |
|||
assert none_cur_sample == {} |
|||
|
|||
|
|||
def incorrect_uniform_sampler(): |
|||
# Do not specify required arguments to uniform sampler |
|||
return {"mass": {"sampler-type": "uniform", "min-value": 10}} |
|||
|
|||
|
|||
def incorrect_sampler_config(): |
|||
# Do not specify 'sampler-type' key |
|||
return {"mass": {"min-value": 2, "max-value": 30}} |
|||
|
|||
|
|||
def test_incorrect_uniform_sampler(): |
|||
config = incorrect_uniform_sampler() |
|||
with pytest.raises(TrainerError): |
|||
SamplerManager(config) |
|||
|
|||
|
|||
def test_incorrect_sampler(): |
|||
config = incorrect_sampler_config() |
|||
with pytest.raises(TrainerError): |
|||
SamplerManager(config) |
|
|||
import numpy as np |
|||
from typing import Union, Optional, Type, List, Dict, Any |
|||
from abc import ABC, abstractmethod |
|||
|
|||
from mlagents.trainers.exception import SamplerException |
|||
|
|||
|
|||
class Sampler(ABC): |
|||
@abstractmethod |
|||
def sample_parameter(self) -> float: |
|||
pass |
|||
|
|||
|
|||
class UniformSampler(Sampler): |
|||
""" |
|||
Uniformly draws a single sample in the range [min_value, max_value). |
|||
""" |
|||
|
|||
def __init__( |
|||
self, |
|||
min_value: Union[int, float], |
|||
max_value: Union[int, float], |
|||
seed: Optional[int] = None, |
|||
): |
|||
""" |
|||
:param min_value: minimum value of the range to be sampled uniformly from |
|||
:param max_value: maximum value of the range to be sampled uniformly from |
|||
:param seed: Random seed used for making draws from the uniform sampler |
|||
""" |
|||
self.min_value = min_value |
|||
self.max_value = max_value |
|||
# Draw from random state to allow for consistent reset parameter draw for a seed |
|||
self.random_state = np.random.RandomState(seed) |
|||
|
|||
def sample_parameter(self) -> float: |
|||
""" |
|||
Draws and returns a sample from the specified interval |
|||
""" |
|||
return self.random_state.uniform(self.min_value, self.max_value) |
|||
|
|||
|
|||
class MultiRangeUniformSampler(Sampler): |
|||
""" |
|||
Draws a single sample uniformly from the intervals provided. The sampler |
|||
first picks an interval based on a weighted selection, with the weights |
|||
assigned to an interval based on its range. After picking the range, |
|||
it proceeds to pick a value uniformly in that range. |
|||
""" |
|||
|
|||
def __init__( |
|||
self, intervals: List[List[Union[int, float]]], seed: Optional[int] = None |
|||
): |
|||
""" |
|||
:param intervals: List of intervals to draw uniform samples from |
|||
:param seed: Random seed used for making uniform draws from the specified intervals |
|||
""" |
|||
self.intervals = intervals |
|||
# Measure the length of the intervals |
|||
interval_lengths = [abs(x[1] - x[0]) for x in self.intervals] |
|||
cum_interval_length = sum(interval_lengths) |
|||
# Assign weights to an interval proportionate to the interval size |
|||
self.interval_weights = [x / cum_interval_length for x in interval_lengths] |
|||
# Draw from random state to allow for consistent reset parameter draw for a seed |
|||
self.random_state = np.random.RandomState(seed) |
|||
|
|||
def sample_parameter(self) -> float: |
|||
""" |
|||
Selects an interval to pick and then draws a uniform sample from the picked interval |
|||
""" |
|||
cur_min, cur_max = self.intervals[ |
|||
self.random_state.choice(len(self.intervals), p=self.interval_weights) |
|||
] |
|||
return self.random_state.uniform(cur_min, cur_max) |
|||
|
|||
|
|||
class GaussianSampler(Sampler): |
|||
""" |
|||
Draw a single sample value from a normal (gaussian) distribution. |
|||
This sampler is characterized by the mean and the standard deviation. |
|||
""" |
|||
|
|||
def __init__( |
|||
self, |
|||
mean: Union[float, int], |
|||
st_dev: Union[float, int], |
|||
seed: Optional[int] = None, |
|||
): |
|||
""" |
|||
:param mean: Specifies the mean of the gaussian distribution to draw from |
|||
:param st_dev: Specifies the standard devation of the gaussian distribution to draw from |
|||
:param seed: Random seed used for making gaussian draws from the sample |
|||
""" |
|||
self.mean = mean |
|||
self.st_dev = st_dev |
|||
# Draw from random state to allow for consistent reset parameter draw for a seed |
|||
self.random_state = np.random.RandomState(seed) |
|||
|
|||
def sample_parameter(self) -> float: |
|||
""" |
|||
Returns a draw from the specified Gaussian distribution |
|||
""" |
|||
return self.random_state.normal(self.mean, self.st_dev) |
|||
|
|||
|
|||
class SamplerFactory: |
|||
""" |
|||
Maintain a directory of all samplers available. |
|||
Add new samplers using the register_sampler method. |
|||
""" |
|||
|
|||
NAME_TO_CLASS = { |
|||
"uniform": UniformSampler, |
|||
"gaussian": GaussianSampler, |
|||
"multirange_uniform": MultiRangeUniformSampler, |
|||
} |
|||
|
|||
@staticmethod |
|||
def register_sampler(name: str, sampler_cls: Type[Sampler]) -> None: |
|||
""" |
|||
Registers the sampe in the Sampler Factory to be used later |
|||
:param name: String name to set as key for the sampler_cls in the factory |
|||
:param sampler_cls: Sampler object to associate to the name in the factory |
|||
""" |
|||
SamplerFactory.NAME_TO_CLASS[name] = sampler_cls |
|||
|
|||
@staticmethod |
|||
def init_sampler_class( |
|||
name: str, params: Dict[str, Any], seed: Optional[int] = None |
|||
) -> Sampler: |
|||
""" |
|||
Initializes the sampler class associated with the name with the params |
|||
:param name: Name of the sampler in the factory to initialize |
|||
:param params: Parameters associated to the sampler attached to the name |
|||
:param seed: Random seed to be used to set deterministic random draws for the sampler |
|||
""" |
|||
if name not in SamplerFactory.NAME_TO_CLASS: |
|||
raise SamplerException( |
|||
name + " sampler is not registered in the SamplerFactory." |
|||
" Use the register_sample method to register the string" |
|||
" associated to your sampler in the SamplerFactory." |
|||
) |
|||
sampler_cls = SamplerFactory.NAME_TO_CLASS[name] |
|||
params["seed"] = seed |
|||
try: |
|||
return sampler_cls(**params) |
|||
except TypeError: |
|||
raise SamplerException( |
|||
"The sampler class associated to the " + name + " key in the factory " |
|||
"was not provided the required arguments. Please ensure that the sampler " |
|||
"config file consists of the appropriate keys for this sampler class." |
|||
) |
|||
|
|||
|
|||
class SamplerManager: |
|||
def __init__( |
|||
self, reset_param_dict: Dict[str, Any], seed: Optional[int] = None |
|||
) -> None: |
|||
""" |
|||
:param reset_param_dict: Arguments needed for initializing the samplers |
|||
:param seed: Random seed to be used for drawing samples from the samplers |
|||
""" |
|||
self.reset_param_dict = reset_param_dict if reset_param_dict else {} |
|||
assert isinstance(self.reset_param_dict, dict) |
|||
self.samplers: Dict[str, Sampler] = {} |
|||
for param_name, cur_param_dict in self.reset_param_dict.items(): |
|||
if "sampler-type" not in cur_param_dict: |
|||
raise SamplerException( |
|||
"'sampler_type' argument hasn't been supplied for the {0} parameter".format( |
|||
param_name |
|||
) |
|||
) |
|||
sampler_name = cur_param_dict.pop("sampler-type") |
|||
param_sampler = SamplerFactory.init_sampler_class( |
|||
sampler_name, cur_param_dict, seed |
|||
) |
|||
|
|||
self.samplers[param_name] = param_sampler |
|||
|
|||
def is_empty(self) -> bool: |
|||
""" |
|||
Check for if sampler_manager is empty. |
|||
""" |
|||
return not bool(self.samplers) |
|||
|
|||
def sample_all(self) -> Dict[str, float]: |
|||
""" |
|||
Loop over all samplers and draw a sample from each one for generating |
|||
next set of reset parameter values. |
|||
""" |
|||
res = {} |
|||
for param_name, param_sampler in list(self.samplers.items()): |
|||
res[param_name] = param_sampler.sample_parameter() |
|||
return res |
撰写
预览
正在加载...
取消
保存
Reference in new issue