浏览代码

ignoring commit checks

/sampler-refactor-copy
Andrew Cohen 5 年前
当前提交
4464ca46
共有 12 个文件被更改,包括 223 次插入78 次删除
  1. 11
      Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs
  2. 5
      com.unity.ml-agents/Runtime/EnvironmentParameters.cs
  3. 25
      com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs
  4. 5
      config/ppo/3DBall_randomize.yaml
  5. 2
      ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py
  6. 48
      ml-agents/mlagents/trainers/learn.py
  7. 52
      ml-agents/mlagents/trainers/settings.py
  8. 1
      ml-agents/mlagents/trainers/subprocess_env_manager.py
  9. 26
      ml-agents/mlagents/trainers/trainer_controller.py
  10. 74
      com.unity.ml-agents/Runtime/Sampler.cs
  11. 11
      com.unity.ml-agents/Runtime/Sampler.cs.meta
  12. 41
      ml-agents/mlagents/trainers/sampler_utils.py

11
Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs


{
m_BallRb = ball.GetComponent<Rigidbody>();
m_ResetParams = Academy.Instance.EnvironmentParameters;
var samplerType = m_ResetParams.GetWithDefault("mass-sampler-type", -1.0f);
var min = m_ResetParams.GetWithDefault("mass-min", -1.0f);
var max = m_ResetParams.GetWithDefault("mass-max", -1.0f);
Debug.Log(samplerType);
Debug.Log(min);
Debug.Log(max);
SetResetParameters();
}

public void SetBall()
{
//Set the attributes of the ball by fetching the information from the academy
m_BallRb.mass = m_ResetParams.GetWithDefault("mass", 1.0f);
var scale = m_ResetParams.GetWithDefault("scale", 1.0f);
//m_BallRb.mass = m_ResetParams.GetWithDefault("mass", 1.0f);
m_BallRb.mass = m_ResetParams.Sample("mass", 1.0f);
var scale = m_ResetParams.Sample("scale", 1.0f);
ball.transform.localScale = new Vector3(scale, scale, scale);
}

5
com.unity.ml-agents/Runtime/EnvironmentParameters.cs


return m_Channel.GetWithDefault(key, defaultValue);
}
public float Sample(string key, float defaultValue)
{
return m_Channel.Sample(key, defaultValue);
}
/// <summary>
/// Registers a callback action for the provided parameter key. Will overwrite any
/// existing action for that parameter. The callback will be called whenever the parameter

25
com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs


internal class EnvironmentParametersChannel : SideChannel
{
Dictionary<string, float> m_Parameters = new Dictionary<string, float>();
Dictionary<string, Func<float>> m_Samplers = new Dictionary<string, Func<float>>();
SamplerFactory m_SamplerFactory = new SamplerFactory(1);
const string k_EnvParamsId = "534c891e-810f-11ea-a9d0-822485860400";

}
else if ((int)EnvironmentDataTypes.Sampler == type)
{
var samplerType = msg.ReadFloat32();
var statOne = msg.ReadFloat32();
var statTwo = msg.ReadFloat32();
m_Parameters[key+"-sampler-type"] = samplerType;
m_Parameters[key+"-min"] = statOne;
m_Parameters[key+"-max"] = statTwo;
var encoding = msg.ReadFloatList();
m_Samplers[key] = m_SamplerFactory.CreateSampler(encoding);
//var samplerType = msg.ReadFloat32();
//var statOne = msg.ReadFloat32();
//var statTwo = msg.ReadFloat32();
//m_Parameters[key+"-sampler-type"] = samplerType;
//m_Parameters[key+"-min"] = statOne;
//m_Parameters[key+"-max"] = statTwo;
}
else
{

float valueOut;
bool hasKey = m_Parameters.TryGetValue(key, out valueOut);
return hasKey ? valueOut : defaultValue;
}
public float Sample(string key, float defaultValue)
{
Func<float> valueOut;
bool hasKey = m_Samplers.TryGetValue(key, out valueOut);
return hasKey ? valueOut() : defaultValue;
}
/// <summary>

5
config/ppo/3DBall_randomize.yaml


threaded: true
parameter_randomization:
resampling-interval: 5000
gravity:
sampler-type: uniform
min_value: 7
max_value: 12
scale:
sampler-type: uniform
min_value: 0.75

2
ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py


msg = OutgoingMessage()
msg.write_string(key)
msg.write_int32(self.EnvironmentDataTypes.SAMPLER)
# length of list
msg.write_int32(len(values))
for value in values:
msg.write_float32(value)
super().queue_message_to_send(msg)

48
ml-agents/mlagents/trainers/learn.py


)
from mlagents.trainers.cli_utils import parser
from mlagents_envs.environment import UnityEnvironment
from mlagents.trainers.sampler_class import SamplerManager
from mlagents.trainers.exception import SamplerException
from mlagents.trainers.sampler_utils import SamplerUtils
from mlagents.trainers.settings import RunOptions
from mlagents_envs.base_env import BaseEnv
from mlagents.trainers.subprocess_env_manager import SubprocessEnvManager

)
maybe_add_samplers(options.parameter_randomization, env_manager)
sampler_manager, resampling_interval = create_sampler_manager(
options.parameter_randomization, run_seed
)
trainer_factory = TrainerFactory(
options.behaviors,
checkpoint_settings.run_id,

maybe_meta_curriculum,
not checkpoint_settings.inference,
run_seed,
sampler_manager,
resampling_interval,
)
# Begin training

def maybe_add_samplers(sampler_config, env):
restructured_sampler_config: Dict[str, List[float]] = {}
if sampler_config is not None:
for v, config in sampler_config.items():
if v != "resampling-interval":
sampler_type = 0.0 if config["sampler-type"] == "uniform" else 1.0
restructured_sampler_config[v] = [
sampler_type,
config["min_value"],
config["max_value"],
]
env.reset(config=restructured_sampler_config)
def create_sampler_manager(sampler_config, run_seed=None):
resample_interval = None
# TODO send seed
# Filter arguments that do not exist in the environment
resample_interval = sampler_config.pop("resampling-interval")
if (resample_interval <= 0) or (not isinstance(resample_interval, int)):
raise SamplerException(
"Specified resampling-interval is not valid. Please provide"
" a positive integer value for resampling-interval"
)
else:
raise SamplerException(
"Resampling interval was not specified in the sampler file."
" Please specify it with the 'resampling-interval' key in the sampler config file."
logger.warning(
"The resampling-interval is no longer necessary to specify for parameter randomization and is being ignored."
)
sampler_config.pop("resampling-interval")
for param, config in sampler_config.items():
list_of_config_floats = SamplerUtils.validate_and_structure_config(
param, config
sampler_manager = SamplerManager(sampler_config, run_seed)
return sampler_manager, resample_interval
restructured_sampler_config[param] = list_of_config_floats
env.reset(config=restructured_sampler_config)
def try_create_meta_curriculum(

52
ml-agents/mlagents/trainers/settings.py


learning_rate: float = 3e-4
class ParameterRandomizationType(Enum):
UNIFORM: str = "uniform"
GAUSSIAN: str = "gaussian"
MULTIRANGEUNIFORM: str = "multirangeuniform"
def to_settings(self) -> type:
_mapping = {
ParameterRandomizationType.UNIFORM: UniformSettings,
ParameterRandomizationType.GAUSSIAN: GaussianSettings,
ParameterRandomizationType.MULTIRANGEUNIFORM: MultiRangeUniformSettings,
}
return _mapping[self]
@attr.s(auto_attribs=True)
class ParameterRandomizationSettings:
@staticmethod
def structure(d: Mapping, t: type) -> Any:
"""
Helper method to structure a Dict of ParameterRandomizationSettings class. Meant to be registered with
cattr.register_structure_hook() and called with cattr.structure(). This is needed to handle
the special Enum selection of ParameterRandomizationSettings classes.
"""
if not isinstance(d, Mapping):
raise TrainerConfigError(
f"Unsupported parameter randomization configuration {d}."
)
d_final: Dict[ParameterRandomizationType, ParameterRandomizationSettings] = {}
for key, val in d.items():
enum_key = ParameterRandomizationType(key)
t = enum_key.to_settings()
d_final[enum_key] = strict_to_cls(val, t)
return d_final
@attr.s(auto_attribs=True)
class UniformSettings(ParameterRandomizationSettings):
min_value: float = 1.0
max_value: float = 1.0
@attr.s(auto_attribs=True)
class GaussianSettings(ParameterRandomizationSettings):
mean: float = 1.0
st_dev: float = 1.0
@attr.s(auto_attribs=True)
class MultiRangeUniformSettings(ParameterRandomizationSettings):
intervals: List[List[float]] = [[1.0, 1.0]]
@attr.s(auto_attribs=True)
class SelfPlaySettings:
save_steps: int = 20000

1
ml-agents/mlagents/trainers/subprocess_env_manager.py


_send_response(EnvironmentCommand.EXTERNAL_BRAINS, external_brains())
elif req.cmd == EnvironmentCommand.RESET:
for k, v in req.payload.items():
print(k, v)
if isinstance(v, float):
env_parameters.set_float_parameter(k, v)
elif isinstance(v, list):

26
ml-agents/mlagents/trainers/trainer_controller.py


UnityCommunicationException,
UnityCommunicatorStoppedException,
)
from mlagents.trainers.sampler_class import SamplerManager
from mlagents_envs.timers import (
hierarchical_timer,
timed,

meta_curriculum: Optional[MetaCurriculum],
train: bool,
training_seed: int,
sampler_manager: SamplerManager,
resampling_interval: Optional[int],
):
"""
:param output_path: Path to save the model.

:param meta_curriculum: MetaCurriculum object which stores information about all curricula.
:param train: Whether to train model, or only run inference.
:param training_seed: Seed to use for Numpy and Tensorflow random number generation.
:param sampler_manager: SamplerManager object handles samplers for resampling the reset parameters.
:param resampling_interval: Specifies number of simulation steps after which reset parameters are resampled.
:param threaded: Whether or not to run trainers in a separate thread. Disable for testing/debugging.
"""
self.trainers: Dict[str, Trainer] = {}

self.save_freq = save_freq
self.train_model = train
self.meta_curriculum = meta_curriculum
self.sampler_manager = sampler_manager
self.resampling_interval = resampling_interval
self.ghost_controller = self.trainer_factory.ghost_controller
self.trainer_threads: List[threading.Thread] = []

A Data structure corresponding to the initial reset state of the
environment.
"""
sampled_reset_param = self.sampler_manager.sample_all()
sampled_reset_param.update(new_meta_curriculum_config)
env.reset(config=sampled_reset_param)
env.reset(config=new_meta_curriculum_config)
def _should_save_model(self, global_step: int) -> bool:
return (

n_steps = self.advance(env_manager)
for _ in range(n_steps):
global_step += 1
self.reset_env_if_ready(env_manager, global_step)
self.reset_env_if_ready(env_manager)
if self._should_save_model(global_step):
self._save_model()
# Stop advancing trainers

if changed:
self.trainers[brain_name].reward_buffer.clear()
def reset_env_if_ready(self, env: EnvManager, steps: int) -> None:
def reset_env_if_ready(self, env: EnvManager) -> None:
if self.meta_curriculum:
# Get the sizes of the reward buffers.
reward_buff_sizes = {

# If any lessons were incremented or the environment is
# ready to be reset
meta_curriculum_reset = any(lessons_incremented.values())
# Check if we are performing generalization training and we have finished the
# specified number of steps for the lesson
generalization_reset = (
not self.sampler_manager.is_empty()
and (steps != 0)
and (self.resampling_interval)
and (steps % self.resampling_interval == 0)
)
# If ghost trainer swapped teams
if meta_curriculum_reset or generalization_reset or ghost_controller_reset:
if meta_curriculum_reset or ghost_controller_reset:
self.end_trainer_episodes(env, lessons_incremented)
@timed

74
com.unity.ml-agents/Runtime/Sampler.cs


using System;
using System.Collections.Generic;
using Unity.MLAgents;
using Unity.MLAgents.Inference.Utils;
using UnityEngine;
using Random=UnityEngine.Random;
namespace Unity.MLAgents
{
/// <summary>
/// The types of distributions from which to sample reset parameters.
/// </summary>
public enum SamplerType
{
/// <summary>
/// Samples a reset parameter from a uniform distribution.
/// </summary>
Uniform = 0,
/// <summary>
/// Samples a reset parameter from a Gaussian distribution.
/// </summary>
Gaussian = 1
}
/// <summary>
/// Takes a list of floats that encode a sampling distribution and returns the sampling function.
/// </summary>
public sealed class SamplerFactory
{
int m_Seed;
/// <summary>
/// Constructor.
/// </summary>
internal SamplerFactory(int seed)
{
m_Seed = seed;
}
/// <summary>
/// Create the sampling distribution described by the encoding.
/// </summary>
/// <param name="encoding"> List of floats the describe sampling destribution.</param>
public Func<float> CreateSampler(IList<float> encoding)
{
if ((int)encoding[0] == (int)SamplerType.Uniform)
{
return CreateUniformSampler(encoding[1], encoding[2]);
}
else if ((int)encoding[0] == (int)SamplerType.Gaussian)
{
return CreateGaussianSampler(encoding[1], encoding[2]);
}
else{
Debug.LogWarning("EnvironmentParametersChannel received an unknown data type.");
return () => 0;
}
}
public Func<float> CreateUniformSampler(float min, float max)
{
return () => Random.Range(min, max);
}
public Func<float> CreateGaussianSampler(float mean, float stddev)
{
RandomNormal distr = new RandomNormal(m_Seed, mean, stddev);
return () => (float)distr.NextDouble();
}
}
}

11
com.unity.ml-agents/Runtime/Sampler.cs.meta


fileFormatVersion: 2
guid: 39ce0ea5a8b2e47f696f6efc807029f6
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

41
ml-agents/mlagents/trainers/sampler_utils.py


import numpy as np
from enum import Enum
from typing import Dict, List
from mlagents.trainers.exception import SamplerException
class SamplerUtils:
"""
Maintain a directory of available samplers and their configs.
Validates sampler configs are correct.
"""
NAME_TO_ARGS = {
"uniform": ["min_value", "max_value"],
"gaussian": ["mean", "st_dev"],
"multirangeuniform": ["intervals"],
}
NAME_TO_FLOAT_REPR = {"uniform": 0.0, "gaussian": 1.0, "multirangeuniform": 2.0}
@staticmethod
def validate_and_structure_config(
param: str, config: Dict[str, List[float]]
) -> List[float]:
# Config must have a valid type
if (
"sampler-type" not in config
or config["sampler-type"] not in SamplerUtils.NAME_TO_ARGS
):
raise SamplerException(
f"The sampler config for environment parameter {param} does not contain a sampler-type or the sampler-type is invalid."
)
# Check args are correct
sampler_type = config.pop("sampler-type")
if list(config.keys()) != SamplerUtils.NAME_TO_ARGS[sampler_type]:
raise SamplerException(
"The sampler config for environment parameter {} does not contain the correct arguments. Please specify {}.".format(
param, SamplerUtils.NAME_TO_ARGS[config["sampler-type"]]
)
)
return [SamplerUtils.NAME_TO_FLOAT_REPR[sampler_type]] + list(config.values())
正在加载...
取消
保存