Andrew Cohen
5 年前
当前提交
4464ca46
共有 12 个文件被更改,包括 223 次插入 和 78 次删除
-
11Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs
-
5com.unity.ml-agents/Runtime/EnvironmentParameters.cs
-
25com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs
-
5config/ppo/3DBall_randomize.yaml
-
2ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py
-
48ml-agents/mlagents/trainers/learn.py
-
52ml-agents/mlagents/trainers/settings.py
-
1ml-agents/mlagents/trainers/subprocess_env_manager.py
-
26ml-agents/mlagents/trainers/trainer_controller.py
-
74com.unity.ml-agents/Runtime/Sampler.cs
-
11com.unity.ml-agents/Runtime/Sampler.cs.meta
-
41ml-agents/mlagents/trainers/sampler_utils.py
|
|||
using System; |
|||
using System.Collections.Generic; |
|||
using Unity.MLAgents; |
|||
using Unity.MLAgents.Inference.Utils; |
|||
using UnityEngine; |
|||
using Random=UnityEngine.Random; |
|||
|
|||
namespace Unity.MLAgents |
|||
{ |
|||
/// <summary>
/// Identifies the kind of distribution a reset parameter is sampled from.
/// The numeric value of each member is what <see cref="SamplerFactory.CreateSampler"/>
/// compares against the first element of an encoding, so the values are explicit.
/// </summary>
public enum SamplerType
{
    /// <summary>The reset parameter is drawn from a uniform distribution.</summary>
    Uniform = 0,

    /// <summary>The reset parameter is drawn from a Gaussian distribution.</summary>
    Gaussian = 1
}
|||
|
|||
/// <summary>
/// Takes a list of floats that encodes a sampling distribution and returns the sampling function.
/// </summary>
public sealed class SamplerFactory
{
    // Seed handed to every Gaussian sampler created by this factory.
    // Only assigned in the constructor, hence readonly.
    readonly int m_Seed;

    /// <summary>
    /// Constructor.
    /// </summary>
    /// <param name="seed">Seed passed to the generator behind Gaussian samplers.</param>
    internal SamplerFactory(int seed)
    {
        m_Seed = seed;
    }

    /// <summary>
    /// Create the sampling distribution described by the encoding.
    /// </summary>
    /// <param name="encoding">
    /// List of floats that describes the sampling distribution. Element 0 is the
    /// <see cref="SamplerType"/> value (truncated to an int); elements 1 and 2 are
    /// (min, max) for a uniform sampler or (mean, stddev) for a Gaussian sampler.
    /// </param>
    /// <returns>
    /// A function producing one sample per call. When the encoding is empty, too short
    /// for its sampler type, or names an unknown type, a warning is logged and the
    /// returned function always yields 0.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="encoding"/> is null.</exception>
    public Func<float> CreateSampler(IList<float> encoding)
    {
        if (encoding == null)
        {
            throw new ArgumentNullException(nameof(encoding));
        }

        if (encoding.Count == 0)
        {
            Debug.LogWarning("SamplerFactory received an empty sampler encoding.");
            return () => 0;
        }

        var samplerType = (int)encoding[0];
        var isUniform = samplerType == (int)SamplerType.Uniform;
        var isGaussian = samplerType == (int)SamplerType.Gaussian;

        if (!isUniform && !isGaussian)
        {
            Debug.LogWarning("EnvironmentParametersChannel received an unknown data type.");
            return () => 0;
        }

        // Both known sampler types read two arguments after the type code.
        // Fall back the same way as for an unknown type instead of letting the
        // indexer throw on a truncated encoding.
        if (encoding.Count < 3)
        {
            Debug.LogWarning("SamplerFactory received a sampler encoding with too few arguments.");
            return () => 0;
        }

        return isUniform
            ? CreateUniformSampler(encoding[1], encoding[2])
            : CreateGaussianSampler(encoding[1], encoding[2]);
    }

    /// <summary>
    /// Create a sampler that draws from <c>Random.Range(min, max)</c>.
    /// </summary>
    /// <remarks>
    /// <c>Random</c> is aliased to <c>UnityEngine.Random</c> in this file, so this sampler
    /// draws from Unity's shared generator and does not use the factory seed.
    /// </remarks>
    /// <param name="min">Lower bound passed to <c>Random.Range</c>.</param>
    /// <param name="max">Upper bound passed to <c>Random.Range</c>.</param>
    /// <returns>A function producing one sample per call.</returns>
    public Func<float> CreateUniformSampler(float min, float max)
    {
        return () => Random.Range(min, max);
    }

    /// <summary>
    /// Create a sampler backed by a <c>RandomNormal</c> constructed with the factory seed.
    /// </summary>
    /// <remarks>
    /// Every Gaussian sampler made by the same factory is constructed with the same seed.
    /// </remarks>
    /// <param name="mean">Mean passed to <c>RandomNormal</c>.</param>
    /// <param name="stddev">Standard deviation passed to <c>RandomNormal</c>.</param>
    /// <returns>A function producing one sample per call, narrowed from double to float.</returns>
    public Func<float> CreateGaussianSampler(float mean, float stddev)
    {
        RandomNormal distr = new RandomNormal(m_Seed, mean, stddev);
        return () => (float)distr.NextDouble();
    }
}
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 39ce0ea5a8b2e47f696f6efc807029f6 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
import numpy as np |
|||
from enum import Enum |
|||
from typing import Dict, List |
|||
|
|||
from mlagents.trainers.exception import SamplerException |
|||
|
|||
|
|||
class SamplerUtils:
    """
    Maintain a directory of available samplers and their configs.
    Validates sampler configs are correct.
    """

    # Argument names each sampler type requires. The order here is the order in
    # which the argument values are emitted by validate_and_structure_config.
    NAME_TO_ARGS = {
        "uniform": ["min_value", "max_value"],
        "gaussian": ["mean", "st_dev"],
        "multirangeuniform": ["intervals"],
    }
    # Float code placed first in the structured list to identify the sampler type.
    NAME_TO_FLOAT_REPR = {"uniform": 0.0, "gaussian": 1.0, "multirangeuniform": 2.0}

    @staticmethod
    def validate_and_structure_config(
        param: str, config: Dict[str, List[float]]
    ) -> List[float]:
        """
        Validate a sampler config and flatten it into a list.

        :param param: Name of the environment parameter the config belongs to;
            only used in error messages.
        :param config: Mapping holding a "sampler-type" entry plus exactly the
            arguments listed in NAME_TO_ARGS for that type, in any key order.
            The mapping is not modified.
        :return: A list whose first element is the NAME_TO_FLOAT_REPR code of the
            sampler type, followed by the argument values in NAME_TO_ARGS order.
        :raises SamplerException: If "sampler-type" is missing or unknown, or if
            the remaining keys are not exactly the arguments the type requires.
        """
        # Config must have a valid type
        if (
            "sampler-type" not in config
            or config["sampler-type"] not in SamplerUtils.NAME_TO_ARGS
        ):
            raise SamplerException(
                f"The sampler config for environment parameter {param} does not contain a sampler-type or the sampler-type is invalid."
            )
        # Check args are correct. Read the type instead of popping it, so the
        # caller's dict is left intact and can be validated again.
        sampler_type = config["sampler-type"]
        expected_args = SamplerUtils.NAME_TO_ARGS[sampler_type]
        provided_args = [key for key in config if key != "sampler-type"]
        # Compare as sets: the key order of the config carries no meaning.
        if set(provided_args) != set(expected_args):
            raise SamplerException(
                "The sampler config for environment parameter {} does not contain the correct arguments. Please specify {}.".format(
                    param, expected_args
                )
            )
        # Emit values in the canonical argument order, not in config order.
        return [SamplerUtils.NAME_TO_FLOAT_REPR[sampler_type]] + [
            config[arg] for arg in expected_args
        ]
撰写
预览
正在加载...
取消
保存
Reference in new issue