浏览代码

type checks for parameter randomization settings/enforces float encoding

/sampler-refactor-copy
Andrew Cohen 5 年前
当前提交
5ffd9761
共有 4 个文件被更改,包括 17 次插入7 次删除
  1. 2
      com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs
  2. 4
      ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py
  3. 13
      ml-agents/mlagents/trainers/settings.py
  4. 5
      ml-agents/mlagents/trainers/subprocess_env_manager.py

2
com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs


bool hasKey = m_Parameters.TryGetValue(key, out valueOut);
return hasKey ? valueOut() : defaultValue;
}
/// <summary>
/// Registers a callback for the associated parameter key. Will overwrite any existing
/// actions for this parameter key.

4
ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py


def set_sampler_parameters(self, key: str, values: List[float]) -> None:
"""
Sets a float encoding of an environment parameter sampelr.
Sets a float encoding of an environment parameter sampler.
:param value: The float encoding of the sampler.
:param values: The float encoding of the sampler.
"""
msg = OutgoingMessage()
msg.write_string(key)

13
ml-agents/mlagents/trainers/settings.py


from enum import Enum
import collections
import argparse
import abc
from mlagents.trainers.cli_utils import StoreConfigFile, DetectDefault, parser
from mlagents.trainers.cli_utils import load_config

@attr.s(auto_attribs=True)
class ParameterRandomizationSettings:
class ParameterRandomizationSettings(abc.ABC):
@staticmethod
def structure(d: Mapping, t: type) -> Any:
"""

for key, val in param_config.items():
enum_key = ParameterRandomizationType(key)
t = enum_key.to_settings()
d_final[param] = strict_to_cls(val, t).to_float_encoding()
d_final[param] = strict_to_cls(val, t)
@abc.abstractmethod
def to_float_encoding(self) -> List[float]:
"Returns the float encoding of the sampler"
pass
@attr.s(auto_attribs=True)
class UniformSettings(ParameterRandomizationSettings):

)
def to_float_encoding(self) -> List[float]:
"Returns the sampler type followed by the min and max values"
return [0.0, self.min_value, self.max_value]

st_dev: float = 1.0
def to_float_encoding(self) -> List[float]:
"Returns the sampler type followed by the mean and standard deviation"
return [1.0, self.mean, self.st_dev]

)
def to_float_encoding(self) -> List[float]:
"Returns the sampler type followed by a flattened list of the interval values"
floats: List[float] = []
for interval in self.intervals:
floats += interval

5
ml-agents/mlagents/trainers/subprocess_env_manager.py


get_timer_root,
)
from mlagents.trainers.brain import BrainParameters
from mlagents.trainers.settings import ParameterRandomizationSettings
from mlagents.trainers.action_info import ActionInfo
from mlagents_envs.side_channel.environment_parameters_channel import (
EnvironmentParametersChannel,

for k, v in req.payload.items():
if isinstance(v, float):
env_parameters.set_float_parameter(k, v)
elif isinstance(v, list):
env_parameters.set_sampler_parameters(k, v)
elif isinstance(v, ParameterRandomizationSettings):
env_parameters.set_sampler_parameters(k, v.to_float_encoding())
env.reset()
all_step_result = _generate_all_results()
_send_response(EnvironmentCommand.RESET, all_step_result)

正在加载...
取消
保存