比较提交

...
此合并请求有变更与目标分支冲突。
/config/ppo/3DBall_randomize.yaml
/com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs
/ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py
/docs/Training-ML-Agents.md
/ml-agents/mlagents/trainers/settings.py
/ml-agents/mlagents/trainers/learn.py
/ml-agents/mlagents/trainers/trainer_controller.py
/ml-agents/mlagents/trainers/subprocess_env_manager.py
/ml-agents/mlagents/trainers/tests/test_settings.py
/ml-agents/mlagents/trainers/tests/test_learn.py
/ml-agents/mlagents/trainers/tests/test_trainer_controller.py
/com.unity.ml-agents/Runtime/Sampler.cs
/ml-agents/mlagents/trainers/tests/test_simple_rl.py

12 次代码提交

共有 16 个文件被更改,包括 429 次插入380 次删除
  1. 17
      config/ppo/3DBall_randomize.yaml
  2. 63
      com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs
  3. 64
      ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py
  4. 28
      ml-agents/mlagents/trainers/trainer_controller.py
  5. 41
      ml-agents/mlagents/trainers/learn.py
  6. 106
      ml-agents/mlagents/trainers/settings.py
  7. 20
      ml-agents/mlagents/trainers/subprocess_env_manager.py
  8. 63
      ml-agents/mlagents/trainers/tests/test_settings.py
  9. 18
      ml-agents/mlagents/trainers/tests/test_learn.py
  10. 3
      ml-agents/mlagents/trainers/tests/test_simple_rl.py
  11. 5
      ml-agents/mlagents/trainers/tests/test_trainer_controller.py
  12. 7
      docs/Training-ML-Agents.md
  13. 11
      com.unity.ml-agents/Runtime/Sampler.cs.meta
  14. 74
      com.unity.ml-agents/Runtime/Sampler.cs
  15. 96
      ml-agents/mlagents/trainers/tests/test_sampler_class.py
  16. 193
      ml-agents/mlagents/trainers/sampler_class.py

17
config/ppo/3DBall_randomize.yaml


threaded: true
parameter_randomization:
resampling-interval: 5000
sampler-type: uniform
min_value: 0.5
max_value: 10
gravity:
sampler-type: uniform
min_value: 7
max_value: 12
uniform:
min_value: 0.5
max_value: 10
sampler-type: uniform
min_value: 0.75
max_value: 3
uniform:
min_value: 0.75
max_value: 3

63
com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs


/// </summary>
internal enum EnvironmentDataTypes
{
Float = 0
Float = 0,
Sampler = 1
}
/// <summary>
/// The types of distributions from which to sample reset parameters.
/// </summary>
internal enum SamplerType
{
/// <summary>
/// Samples a reset parameter from a uniform distribution.
/// </summary>
Uniform = 0,
/// <summary>
/// Samples a reset parameter from a Gaussian distribution.
/// </summary>
Gaussian = 1,
/// <summary>
/// Samples a reset parameter from a Gaussian distribution.
/// </summary>
MultiRangeUniform = 2
}
/// <summary>

internal class EnvironmentParametersChannel : SideChannel
{
Dictionary<string, float> m_Parameters = new Dictionary<string, float>();
Dictionary<string, Func<float>> m_Parameters = new Dictionary<string, Func<float>>();
SamplerFactory m_SamplerFactory = new SamplerFactory();
const string k_EnvParamsId = "534c891e-810f-11ea-a9d0-822485860400";

{
var value = msg.ReadFloat32();
m_Parameters[key] = value;
m_Parameters[key] = () => value;
else if ((int)EnvironmentDataTypes.Sampler == type)
{
int seed = msg.ReadInt32();
int samplerType = msg.ReadInt32();
Func<float> sampler = () => 0.0f;
if ((int)SamplerType.Uniform == samplerType)
{
float min = msg.ReadFloat32();
float max = msg.ReadFloat32();
sampler = m_SamplerFactory.CreateUniformSampler(min, max, seed);
}
else if ((int)SamplerType.Gaussian == samplerType)
{
float mean = msg.ReadFloat32();
float stddev = msg.ReadFloat32();
sampler = m_SamplerFactory.CreateGaussianSampler(mean, stddev, seed);
}
else if ((int)SamplerType.MultiRangeUniform == samplerType)
{
IList<float> intervals = msg.ReadFloatList();
sampler = m_SamplerFactory.CreateMultiRangeUniformSampler(intervals, seed);
}
else{
Debug.LogWarning("EnvironmentParametersChannel received an unknown data type.");
}
m_Parameters[key] = sampler;
}
else
{
Debug.LogWarning("EnvironmentParametersChannel received an unknown data type.");

/// <returns></returns>
public float GetWithDefault(string key, float defaultValue)
{
float valueOut;
Func<float> valueOut;
return hasKey ? valueOut : defaultValue;
return hasKey ? valueOut.Invoke() : defaultValue;
}
/// <summary>

64
ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py


from mlagents_envs.exception import UnityCommunicationException
import uuid
from enum import IntEnum
from typing import List
class EnvironmentParametersChannel(SideChannel):

class EnvironmentDataTypes(IntEnum):
FLOAT = 0
SAMPLER = 1
class SamplerTypes(IntEnum):
UNIFORM = 0
GAUSSIAN = 1
MULTIRANGEUNIFORM = 2
def __init__(self) -> None:
channel_id = uuid.UUID(("534c891e-810f-11ea-a9d0-822485860400"))

msg.write_int32(self.EnvironmentDataTypes.FLOAT)
msg.write_float32(value)
super().queue_message_to_send(msg)
def set_uniform_sampler_parameters(
self, key: str, min_value: float, max_value: float, seed: int
) -> None:
"""
Sets a uniform environment parameter sampler.
:param key: The string identifier of the parameter.
:param min_value: The minimum of the sampling distribution.
:param max_value: The maximum of the sampling distribution.
:param seed: The random seed to initialize the sampler.
"""
msg = OutgoingMessage()
msg.write_string(key)
msg.write_int32(self.EnvironmentDataTypes.SAMPLER)
msg.write_int32(seed)
msg.write_int32(self.SamplerTypes.UNIFORM)
msg.write_float32(min_value)
msg.write_float32(max_value)
super().queue_message_to_send(msg)
def set_gaussian_sampler_parameters(
self, key: str, mean: float, st_dev: float, seed: int
) -> None:
"""
Sets a gaussian environment parameter sampler.
:param key: The string identifier of the parameter.
:param mean: The mean of the sampling distribution.
:param st_dev: The standard deviation of the sampling distribution.
:param seed: The random seed to initialize the sampler.
"""
msg = OutgoingMessage()
msg.write_string(key)
msg.write_int32(self.EnvironmentDataTypes.SAMPLER)
msg.write_int32(seed)
msg.write_int32(self.SamplerTypes.GAUSSIAN)
msg.write_float32(mean)
msg.write_float32(st_dev)
super().queue_message_to_send(msg)
def set_multirangeuniform_sampler_parameters(
self, key: str, intervals: List[float], seed: int
) -> None:
"""
Sets a gaussian environment parameter sampler.
:param key: The string identifier of the parameter.
:param intervals: The min and max that define each uniform distribution.
:param seed: The random seed to initialize the sampler.
"""
msg = OutgoingMessage()
msg.write_string(key)
msg.write_int32(self.EnvironmentDataTypes.SAMPLER)
msg.write_int32(seed)
msg.write_int32(self.SamplerTypes.MULTIRANGEUNIFORM)
msg.write_int32(len(intervals))
for value in intervals:
msg.write_float32(value)
super().queue_message_to_send(msg)

28
ml-agents/mlagents/trainers/trainer_controller.py


UnityCommunicationException,
UnityCommunicatorStoppedException,
)
from mlagents.trainers.sampler_class import SamplerManager
from mlagents_envs.timers import (
hierarchical_timer,
timed,

meta_curriculum: Optional[MetaCurriculum],
train: bool,
training_seed: int,
sampler_manager: SamplerManager,
resampling_interval: Optional[int],
):
"""
:param output_path: Path to save the model.

:param train: Whether to train model, or only run inference.
:param training_seed: Seed to use for Numpy and Tensorflow random number generation.
:param sampler_manager: SamplerManager object handles samplers for resampling the reset parameters.
:param resampling_interval: Specifies number of simulation steps after which reset parameters are resampled.
:param threaded: Whether or not to run trainers in a separate thread. Disable for testing/debugging.
"""
self.trainers: Dict[str, Trainer] = {}

self.run_id = run_id
self.train_model = train
self.meta_curriculum = meta_curriculum
self.sampler_manager = sampler_manager
self.resampling_interval = resampling_interval
self.ghost_controller = self.trainer_factory.ghost_controller
self.trainer_threads: List[threading.Thread] = []

A Data structure corresponding to the initial reset state of the
environment.
"""
sampled_reset_param = self.sampler_manager.sample_all()
sampled_reset_param.update(new_meta_curriculum_config)
env.reset(config=sampled_reset_param)
env.reset(config=new_meta_curriculum_config)
def _not_done_training(self) -> bool:
return (

def start_learning(self, env_manager: EnvManager) -> None:
self._create_output_path(self.output_path)
tf.reset_default_graph()
global_step = 0
last_brain_behavior_ids: Set[str] = set()
try:
# Initial reset

last_brain_behavior_ids = external_brain_behavior_ids
n_steps = self.advance(env_manager)
for _ in range(n_steps):
global_step += 1
self.reset_env_if_ready(env_manager, global_step)
self.reset_env_if_ready(env_manager)
# Stop advancing trainers
self.join_threads()
except (

if changed:
self.trainers[brain_name].reward_buffer.clear()
def reset_env_if_ready(self, env: EnvManager, steps: int) -> None:
def reset_env_if_ready(self, env: EnvManager) -> None:
if self.meta_curriculum:
# Get the sizes of the reward buffers.
reward_buff_sizes = {

# If any lessons were incremented or the environment is
# ready to be reset
meta_curriculum_reset = any(lessons_incremented.values())
# Check if we are performing generalization training and we have finished the
# specified number of steps for the lesson
generalization_reset = (
not self.sampler_manager.is_empty()
and (steps != 0)
and (self.resampling_interval)
and (steps % self.resampling_interval == 0)
)
# If ghost trainer swapped teams
if meta_curriculum_reset or generalization_reset or ghost_controller_reset:
if meta_curriculum_reset or ghost_controller_reset:
self.end_trainer_episodes(env, lessons_incremented)
@timed

41
ml-agents/mlagents/trainers/learn.py


)
from mlagents.trainers.cli_utils import parser
from mlagents_envs.environment import UnityEnvironment
from mlagents.trainers.sampler_class import SamplerManager
from mlagents.trainers.exception import SamplerException
from mlagents.trainers.settings import RunOptions
from mlagents.trainers.training_status import GlobalTrainingStatus
from mlagents_envs.base_env import BaseEnv

maybe_meta_curriculum = try_create_meta_curriculum(
options.curriculum, env_manager, restore=checkpoint_settings.resume
)
sampler_manager, resampling_interval = create_sampler_manager(
options.parameter_randomization, run_seed
)
maybe_add_samplers(options.parameter_randomization, env_manager, run_seed)
trainer_factory = TrainerFactory(
options.behaviors,
checkpoint_settings.run_id,

maybe_meta_curriculum,
not checkpoint_settings.inference,
run_seed,
sampler_manager,
resampling_interval,
)
# Begin training

)
def create_sampler_manager(sampler_config, run_seed=None):
resample_interval = None
def maybe_add_samplers(
sampler_config: Optional[Dict], env: SubprocessEnvManager, run_seed: int
) -> None:
"""
Adds samplers to env if sampler config provided and sets seed if not configured.
:param sampler_config: validated dict of sampler configs. None if not included.
:param env: env manager to pass samplers via reset
:param run_seed: Random seed used for training.
"""
if "resampling-interval" in sampler_config:
# Filter arguments that do not exist in the environment
resample_interval = sampler_config.pop("resampling-interval")
if (resample_interval <= 0) or (not isinstance(resample_interval, int)):
raise SamplerException(
"Specified resampling-interval is not valid. Please provide"
" a positive integer value for resampling-interval"
)
else:
raise SamplerException(
"Resampling interval was not specified in the sampler file."
" Please specify it with the 'resampling-interval' key in the sampler config file."
)
sampler_manager = SamplerManager(sampler_config, run_seed)
return sampler_manager, resample_interval
# If the seed is not specified in yaml, this will grab the run seed
for _, v in sampler_config.items():
if v.seed == -1:
v.seed = run_seed
env.reset(config=sampler_config)
def try_create_meta_curriculum(

106
ml-agents/mlagents/trainers/settings.py


from mlagents.trainers.exception import TrainerConfigError
from mlagents.trainers.models import ScheduleType, EncoderType
from mlagents_envs import logging_util
logger = logging_util.get_logger(__name__)
def check_and_structure(key: str, value: Any, class_type: type) -> Any:
attr_fields_dict = attr.fields_dict(class_type)

learning_rate: float = 3e-4
class ParameterRandomizationType(Enum):
UNIFORM: str = "uniform"
GAUSSIAN: str = "gaussian"
MULTIRANGEUNIFORM: str = "multirangeuniform"
def to_settings(self) -> type:
_mapping = {
ParameterRandomizationType.UNIFORM: UniformSettings,
ParameterRandomizationType.GAUSSIAN: GaussianSettings,
ParameterRandomizationType.MULTIRANGEUNIFORM: MultiRangeUniformSettings,
}
return _mapping[self]
@attr.s(auto_attribs=True)
class ParameterRandomizationSettings:
seed: int = parser.get_default("seed")
@staticmethod
def structure(d: Mapping, t: type) -> Any:
"""
Helper method to structure a Dict of ParameterRandomizationSettings class. Meant to be registered with
cattr.register_structure_hook() and called with cattr.structure(). This is needed to handle
the special Enum selection of ParameterRandomizationSettings classes.
"""
if not isinstance(d, Mapping):
raise TrainerConfigError(
f"Unsupported parameter randomization configuration {d}."
)
d_final: Dict[str, List[float]] = {}
for param, param_config in d.items():
if param == "resampling-interval":
logger.warning(
"The resampling-interval is no longer necessary for parameter randomization. It is being ignored."
)
continue
if not isinstance(param_config, Mapping):
raise TrainerConfigError(
f"Unsupported distribution configuration {param_config}."
)
for key, val in param_config.items():
enum_key = ParameterRandomizationType(key)
t = enum_key.to_settings()
d_final[param] = strict_to_cls(val, t)
return d_final
@attr.s(auto_attribs=True)
class UniformSettings(ParameterRandomizationSettings):
min_value: float = attr.ib()
max_value: float = 1.0
@min_value.default
def _min_value_default(self):
return 1.0
@min_value.validator
def _check_min_value(self, attribute, value):
if self.min_value > self.max_value:
raise TrainerConfigError(
"Minimum value is greater than maximum value in uniform sampler."
)
@attr.s(auto_attribs=True)
class GaussianSettings(ParameterRandomizationSettings):
mean: float = 1.0
st_dev: float = 1.0
@attr.s(auto_attribs=True)
class MultiRangeUniformSettings(ParameterRandomizationSettings):
intervals: List[List[float]] = attr.ib()
@intervals.default
def _intervals_default(self):
return [[1.0, 1.0]]
@intervals.validator
def _check_intervals(self, attribute, value):
for interval in self.intervals:
if len(interval) != 2:
raise TrainerConfigError(
f"The sampling interval {interval} must contain exactly two values."
)
[min_value, max_value] = interval
if min_value > max_value:
raise TrainerConfigError(
f"Minimum value is greater than maximum value in interval {interval}."
)
def to_float_encoding(self) -> List[float]:
"Returns the sampler type followed by a flattened list of the interval values"
return [value for interval in self.intervals for value in interval]
@attr.s(auto_attribs=True)
class SelfPlaySettings:
save_steps: int = 20000

)
env_settings: EnvironmentSettings = attr.ib(factory=EnvironmentSettings)
engine_settings: EngineSettings = attr.ib(factory=EngineSettings)
parameter_randomization: Optional[Dict] = None
parameter_randomization: Optional[Dict[str, ParameterRandomizationSettings]] = None
curriculum: Optional[Dict[str, CurriculumSettings]] = None
checkpoint_settings: CheckpointSettings = attr.ib(factory=CheckpointSettings)

cattr.register_structure_hook(EnvironmentSettings, strict_to_cls)
cattr.register_structure_hook(EngineSettings, strict_to_cls)
cattr.register_structure_hook(CheckpointSettings, strict_to_cls)
cattr.register_structure_hook(
Dict[str, ParameterRandomizationSettings],
ParameterRandomizationSettings.structure,
)
cattr.register_structure_hook(CurriculumSettings, strict_to_cls)
cattr.register_structure_hook(TrainerSettings, TrainerSettings.structure)
cattr.register_structure_hook(

20
ml-agents/mlagents/trainers/subprocess_env_manager.py


get_timer_root,
)
from mlagents.trainers.brain import BrainParameters
from mlagents.trainers.settings import (
UniformSettings,
GaussianSettings,
MultiRangeUniformSettings,
)
from mlagents.trainers.action_info import ActionInfo
from mlagents_envs.side_channel.environment_parameters_channel import (
EnvironmentParametersChannel,

_send_response(EnvironmentCommand.EXTERNAL_BRAINS, external_brains())
elif req.cmd == EnvironmentCommand.RESET:
for k, v in req.payload.items():
env_parameters.set_float_parameter(k, v)
if isinstance(v, float):
env_parameters.set_float_parameter(k, v)
elif isinstance(v, UniformSettings):
env_parameters.set_uniform_sampler_parameters(
k, v.min_value, v.max_value, v.seed
)
elif isinstance(v, GaussianSettings):
env_parameters.set_gaussian_sampler_parameters(
k, v.mean, v.st_dev, v.seed
)
elif isinstance(v, MultiRangeUniformSettings):
env_parameters.set_multirangeuniform_sampler_parameters(
k, v.to_float_encoding(), v.seed
)
env.reset()
all_step_result = _generate_all_results()
_send_response(EnvironmentCommand.RESET, all_step_result)

63
ml-agents/mlagents/trainers/tests/test_settings.py


RewardSignalType,
RewardSignalSettings,
CuriositySettings,
ParameterRandomizationSettings,
UniformSettings,
GaussianSettings,
MultiRangeUniformSettings,
TrainerType,
strict_to_cls,
)

RewardSignalSettings.structure(
"notadict", Dict[RewardSignalType, RewardSignalSettings]
)
def test_parameter_randomization_structure():
"""
Tests the ParameterRandomizationSettings structure method and all validators.
"""
parameter_randomization_dict = {
"mass": {"uniform": {"min_value": 1.0, "max_value": 2.0}},
"scale": {"gaussian": {"mean": 1.0, "st_dev": 2.0}},
"length": {"multirangeuniform": {"intervals": [[1.0, 2.0], [3.0, 4.0]]}},
}
parameter_randomization_distributions = ParameterRandomizationSettings.structure(
parameter_randomization_dict, Dict[str, ParameterRandomizationSettings]
)
assert isinstance(parameter_randomization_distributions["mass"], UniformSettings)
assert isinstance(parameter_randomization_distributions["scale"], GaussianSettings)
assert isinstance(
parameter_randomization_distributions["length"], MultiRangeUniformSettings
)
# Check invalid distribution type
invalid_distribution_dict = {"mass": {"beta": {"alpha": 1.0, "beta": 2.0}}}
with pytest.raises(ValueError):
ParameterRandomizationSettings.structure(
invalid_distribution_dict, Dict[str, ParameterRandomizationSettings]
)
# Check min less than max in uniform
invalid_distribution_dict = {
"mass": {"uniform": {"min_value": 2.0, "max_value": 1.0}}
}
with pytest.raises(TrainerConfigError):
ParameterRandomizationSettings.structure(
invalid_distribution_dict, Dict[str, ParameterRandomizationSettings]
)
# Check min less than max in multirange
invalid_distribution_dict = {
"mass": {"multirangeuniform": {"intervals": [[2.0, 1.0]]}}
}
with pytest.raises(TrainerConfigError):
ParameterRandomizationSettings.structure(
invalid_distribution_dict, Dict[str, ParameterRandomizationSettings]
)
# Check multirange has valid intervals
invalid_distribution_dict = {
"mass": {"multirangeuniform": {"intervals": [[1.0, 2.0], [3.0]]}}
}
with pytest.raises(TrainerConfigError):
ParameterRandomizationSettings.structure(
invalid_distribution_dict, Dict[str, ParameterRandomizationSettings]
)
# Check non-Dict input
with pytest.raises(TrainerConfigError):
ParameterRandomizationSettings.structure(
"notadict", Dict[str, ParameterRandomizationSettings]
)

18
ml-agents/mlagents/trainers/tests/test_learn.py


from mlagents.trainers.cli_utils import DetectDefault
from mlagents_envs.exception import UnityEnvironmentException
from mlagents.trainers.stats import StatsReporter
from mlagents.trainers.settings import UniformSettings
def basic_options(extra_args=None):

MOCK_SAMPLER_CURRICULUM_YAML = """
parameter_randomization:
sampler1: foo
sampler1:
uniform:
min_value: 0.2
curriculum:
behavior1:

@patch("mlagents.trainers.learn.write_run_options")
@patch("mlagents.trainers.learn.handle_existing_directories")
@patch("mlagents.trainers.learn.TrainerFactory")
@patch("mlagents.trainers.learn.SamplerManager")
@patch("mlagents.trainers.learn.SubprocessEnvManager")
@patch("mlagents.trainers.learn.create_environment_factory")
@patch("mlagents.trainers.settings.load_config")

subproc_env_mock,
sampler_manager_mock,
trainer_factory_mock,
handle_dir_mock,
write_run_options_mock,

options = basic_options()
learn.run_training(0, options)
mock_init.assert_called_once_with(
trainer_factory_mock.return_value,
"results/ppo",
"ppo",
None,
True,
0,
sampler_manager_mock.return_value,
None,
trainer_factory_mock.return_value, "results/ppo", "ppo", None, True, 0
)
handle_dir_mock.assert_called_once_with("results/ppo", False, False, None)
write_timing_tree_mock.assert_called_once_with("results/ppo/run_logs")

@patch("builtins.open", new_callable=mock_open, read_data=MOCK_SAMPLER_CURRICULUM_YAML)
def test_sampler_configs(mock_file):
opt = parse_command_line(["mytrainerpath"])
assert opt.parameter_randomization == {"sampler1": "foo"}
assert isinstance(opt.parameter_randomization["sampler1"], UniformSettings)
assert len(opt.curriculum.keys()) == 2

3
ml-agents/mlagents/trainers/tests/test_simple_rl.py


from mlagents.trainers.trainer_controller import TrainerController
from mlagents.trainers.trainer_util import TrainerFactory
from mlagents.trainers.simple_env_manager import SimpleEnvManager
from mlagents.trainers.sampler_class import SamplerManager
from mlagents.trainers.demo_loader import write_demo
from mlagents.trainers.stats import StatsReporter, StatsWriter, StatsSummary
from mlagents.trainers.settings import (

meta_curriculum=meta_curriculum,
train=True,
training_seed=seed,
sampler_manager=SamplerManager(None),
resampling_interval=None,
)
# Begin training

5
ml-agents/mlagents/trainers/tests/test_trainer_controller.py


from mlagents.tf_utils import tf
from mlagents.trainers.trainer_controller import TrainerController
from mlagents.trainers.ghost.controller import GhostController
from mlagents.trainers.sampler_class import SamplerManager
@pytest.fixture

meta_curriculum=None,
train=True,
training_seed=99,
sampler_manager=SamplerManager({}),
resampling_interval=None,
)

meta_curriculum=None,
train=True,
training_seed=seed,
sampler_manager=SamplerManager({}),
resampling_interval=None,
)
numpy_random_seed.assert_called_with(seed)
tensorflow_set_seed.assert_called_with(seed)

7
docs/Training-ML-Agents.md


# < Same as above>
parameter_randomization:
resampling-interval: 5000
sampler-type: "uniform"
min_value: 0.5
max_value: 10
uniform:
min_value: 0.5
max_value: 10
gravity:
sampler-type: "multirange_uniform"

11
com.unity.ml-agents/Runtime/Sampler.cs.meta


fileFormatVersion: 2
guid: 39ce0ea5a8b2e47f696f6efc807029f6
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

74
com.unity.ml-agents/Runtime/Sampler.cs


using System;
using System.Collections.Generic;
using Unity.MLAgents.Inference.Utils;
namespace Unity.MLAgents
{
/// <summary>
/// Takes a list of floats that encode a sampling distribution and returns the sampling function.
/// </summary>
internal sealed class SamplerFactory
{
/// <summary>
/// Constructor.
/// </summary>
internal SamplerFactory()
{
}
public Func<float> CreateUniformSampler(float min, float max, int seed)
{
Random distr = new Random(seed);
return () => min + (float)distr.NextDouble() * (max - min);
}
public Func<float> CreateGaussianSampler(float mean, float stddev, int seed)
{
RandomNormal distr = new RandomNormal(seed, mean, stddev);
return () => (float)distr.NextDouble();
}
public Func<float> CreateMultiRangeUniformSampler(IList<float> intervals, int seed)
{
//RNG
Random distr = new Random(seed);
// Will be used to normalize intervalFuncs
float sumIntervalSizes = 0;
//The number of intervals
int numIntervals = (int)(intervals.Count/2);
// List that will store interval lengths
float[] intervalSizes = new float[numIntervals];
// List that will store uniform distributions
IList<Func<float>> intervalFuncs = new Func<float>[numIntervals];
// Collect all intervals and store as uniform distrus
// Collect all interval sizes
for(int i = 0; i < numIntervals; i++)
{
var min = intervals[2 * i];
var max = intervals[2 * i + 1];
var intervalSize = max - min;
sumIntervalSizes += intervalSize;
intervalSizes[i] = intervalSize;
intervalFuncs[i] = () => min + (float)distr.NextDouble() * intervalSize;
}
// Normalize interval lengths
for(int i = 0; i < numIntervals; i++)
{
intervalSizes[i] = intervalSizes[i] / sumIntervalSizes;
}
// Build cmf for intervals
for(int i = 1; i < numIntervals; i++)
{
intervalSizes[i] += intervalSizes[i - 1];
}
Multinomial intervalDistr = new Multinomial(seed);
float MultiRange()
{
int sampledInterval = intervalDistr.Sample(intervalSizes);
return intervalFuncs[sampledInterval].Invoke();
}
return MultiRange;
}
}
}

96
ml-agents/mlagents/trainers/tests/test_sampler_class.py


import pytest
from mlagents.trainers.sampler_class import SamplerManager
from mlagents.trainers.sampler_class import (
UniformSampler,
MultiRangeUniformSampler,
GaussianSampler,
)
from mlagents.trainers.exception import TrainerError
def sampler_config_1():
return {
"mass": {"sampler-type": "uniform", "min_value": 5, "max_value": 10},
"gravity": {
"sampler-type": "multirange_uniform",
"intervals": [[8, 11], [15, 20]],
},
}
def check_value_in_intervals(val, intervals):
check_in_bounds = [a <= val <= b for a, b in intervals]
return any(check_in_bounds)
def test_sampler_config_1():
config = sampler_config_1()
sampler = SamplerManager(config)
assert sampler.is_empty() is False
assert isinstance(sampler.samplers["mass"], UniformSampler)
assert isinstance(sampler.samplers["gravity"], MultiRangeUniformSampler)
cur_sample = sampler.sample_all()
# Check uniform sampler for mass
assert sampler.samplers["mass"].min_value == config["mass"]["min_value"]
assert sampler.samplers["mass"].max_value == config["mass"]["max_value"]
assert config["mass"]["min_value"] <= cur_sample["mass"]
assert config["mass"]["max_value"] >= cur_sample["mass"]
# Check multirange_uniform sampler for gravity
assert sampler.samplers["gravity"].intervals == config["gravity"]["intervals"]
assert check_value_in_intervals(
cur_sample["gravity"], sampler.samplers["gravity"].intervals
)
def sampler_config_2():
return {"angle": {"sampler-type": "gaussian", "mean": 0, "st_dev": 1}}
def test_sampler_config_2():
config = sampler_config_2()
sampler = SamplerManager(config)
assert sampler.is_empty() is False
assert isinstance(sampler.samplers["angle"], GaussianSampler)
# Check angle gaussian sampler
assert sampler.samplers["angle"].mean == config["angle"]["mean"]
assert sampler.samplers["angle"].st_dev == config["angle"]["st_dev"]
def test_empty_samplers():
empty_sampler = SamplerManager({})
assert empty_sampler.is_empty()
empty_cur_sample = empty_sampler.sample_all()
assert empty_cur_sample == {}
none_sampler = SamplerManager(None)
assert none_sampler.is_empty()
none_cur_sample = none_sampler.sample_all()
assert none_cur_sample == {}
def incorrect_uniform_sampler():
# Do not specify required arguments to uniform sampler
return {"mass": {"sampler-type": "uniform", "min-value": 10}}
def incorrect_sampler_config():
# Do not specify 'sampler-type' key
return {"mass": {"min-value": 2, "max-value": 30}}
def test_incorrect_uniform_sampler():
config = incorrect_uniform_sampler()
with pytest.raises(TrainerError):
SamplerManager(config)
def test_incorrect_sampler():
config = incorrect_sampler_config()
with pytest.raises(TrainerError):
SamplerManager(config)

193
ml-agents/mlagents/trainers/sampler_class.py


import numpy as np
from typing import Union, Optional, Type, List, Dict, Any
from abc import ABC, abstractmethod
from mlagents.trainers.exception import SamplerException
class Sampler(ABC):
@abstractmethod
def sample_parameter(self) -> float:
pass
class UniformSampler(Sampler):
"""
Uniformly draws a single sample in the range [min_value, max_value).
"""
def __init__(
self,
min_value: Union[int, float],
max_value: Union[int, float],
seed: Optional[int] = None,
):
"""
:param min_value: minimum value of the range to be sampled uniformly from
:param max_value: maximum value of the range to be sampled uniformly from
:param seed: Random seed used for making draws from the uniform sampler
"""
self.min_value = min_value
self.max_value = max_value
# Draw from random state to allow for consistent reset parameter draw for a seed
self.random_state = np.random.RandomState(seed)
def sample_parameter(self) -> float:
"""
Draws and returns a sample from the specified interval
"""
return self.random_state.uniform(self.min_value, self.max_value)
class MultiRangeUniformSampler(Sampler):
"""
Draws a single sample uniformly from the intervals provided. The sampler
first picks an interval based on a weighted selection, with the weights
assigned to an interval based on its range. After picking the range,
it proceeds to pick a value uniformly in that range.
"""
def __init__(
self, intervals: List[List[Union[int, float]]], seed: Optional[int] = None
):
"""
:param intervals: List of intervals to draw uniform samples from
:param seed: Random seed used for making uniform draws from the specified intervals
"""
self.intervals = intervals
# Measure the length of the intervals
interval_lengths = [abs(x[1] - x[0]) for x in self.intervals]
cum_interval_length = sum(interval_lengths)
# Assign weights to an interval proportionate to the interval size
self.interval_weights = [x / cum_interval_length for x in interval_lengths]
# Draw from random state to allow for consistent reset parameter draw for a seed
self.random_state = np.random.RandomState(seed)
def sample_parameter(self) -> float:
"""
Selects an interval to pick and then draws a uniform sample from the picked interval
"""
cur_min, cur_max = self.intervals[
self.random_state.choice(len(self.intervals), p=self.interval_weights)
]
return self.random_state.uniform(cur_min, cur_max)
class GaussianSampler(Sampler):
"""
Draw a single sample value from a normal (gaussian) distribution.
This sampler is characterized by the mean and the standard deviation.
"""
def __init__(
self,
mean: Union[float, int],
st_dev: Union[float, int],
seed: Optional[int] = None,
):
"""
:param mean: Specifies the mean of the gaussian distribution to draw from
:param st_dev: Specifies the standard devation of the gaussian distribution to draw from
:param seed: Random seed used for making gaussian draws from the sample
"""
self.mean = mean
self.st_dev = st_dev
# Draw from random state to allow for consistent reset parameter draw for a seed
self.random_state = np.random.RandomState(seed)
def sample_parameter(self) -> float:
"""
Returns a draw from the specified Gaussian distribution
"""
return self.random_state.normal(self.mean, self.st_dev)
class SamplerFactory:
"""
Maintain a directory of all samplers available.
Add new samplers using the register_sampler method.
"""
NAME_TO_CLASS = {
"uniform": UniformSampler,
"gaussian": GaussianSampler,
"multirange_uniform": MultiRangeUniformSampler,
}
@staticmethod
def register_sampler(name: str, sampler_cls: Type[Sampler]) -> None:
"""
Registers the sampe in the Sampler Factory to be used later
:param name: String name to set as key for the sampler_cls in the factory
:param sampler_cls: Sampler object to associate to the name in the factory
"""
SamplerFactory.NAME_TO_CLASS[name] = sampler_cls
@staticmethod
def init_sampler_class(
name: str, params: Dict[str, Any], seed: Optional[int] = None
) -> Sampler:
"""
Initializes the sampler class associated with the name with the params
:param name: Name of the sampler in the factory to initialize
:param params: Parameters associated to the sampler attached to the name
:param seed: Random seed to be used to set deterministic random draws for the sampler
"""
if name not in SamplerFactory.NAME_TO_CLASS:
raise SamplerException(
name + " sampler is not registered in the SamplerFactory."
" Use the register_sample method to register the string"
" associated to your sampler in the SamplerFactory."
)
sampler_cls = SamplerFactory.NAME_TO_CLASS[name]
params["seed"] = seed
try:
return sampler_cls(**params)
except TypeError:
raise SamplerException(
"The sampler class associated to the " + name + " key in the factory "
"was not provided the required arguments. Please ensure that the sampler "
"config file consists of the appropriate keys for this sampler class."
)
class SamplerManager:
def __init__(
self, reset_param_dict: Dict[str, Any], seed: Optional[int] = None
) -> None:
"""
:param reset_param_dict: Arguments needed for initializing the samplers
:param seed: Random seed to be used for drawing samples from the samplers
"""
self.reset_param_dict = reset_param_dict if reset_param_dict else {}
assert isinstance(self.reset_param_dict, dict)
self.samplers: Dict[str, Sampler] = {}
for param_name, cur_param_dict in self.reset_param_dict.items():
if "sampler-type" not in cur_param_dict:
raise SamplerException(
"'sampler_type' argument hasn't been supplied for the {0} parameter".format(
param_name
)
)
sampler_name = cur_param_dict.pop("sampler-type")
param_sampler = SamplerFactory.init_sampler_class(
sampler_name, cur_param_dict, seed
)
self.samplers[param_name] = param_sampler
def is_empty(self) -> bool:
"""
Check for if sampler_manager is empty.
"""
return not bool(self.samplers)
def sample_all(self) -> Dict[str, float]:
"""
Loop over all samplers and draw a sample from each one for generating
next set of reset parameter values.
"""
res = {}
for param_name, param_sampler in list(self.samplers.items()):
res[param_name] = param_sampler.sample_parameter()
return res
正在加载...
取消
保存