浏览代码

Change samplers to use random state to allow consistency in reset par… (#2398)

* Change samplers to use random state to allow consistency in reset parameter draws for a specified seed
/hotfix-v0.9.2a
Ervin T 5 年前
当前提交
184b5d5a
共有 2 个文件被更改,包括 79 次插入13 次删除
  1. 86
      ml-agents-envs/mlagents/envs/sampler_class.py
  2. 6
      ml-agents/mlagents/trainers/learn.py

86
ml-agents-envs/mlagents/envs/sampler_class.py


"""
def __init__(
self, min_value: Union[int, float], max_value: Union[int, float], **kwargs
self,
min_value: Union[int, float],
max_value: Union[int, float],
seed: Optional[int] = None,
**kwargs
"""
:param min_value: minimum value of the range to be sampled uniformly from
:param max_value: maximum value of the range to be sampled uniformly from
:param seed: Random seed used for making draws from the uniform sampler
"""
# Draw from random state to allow for consistent reset parameter draw for a seed
self.random_state = np.random.RandomState(seed)
return np.random.uniform(self.min_value, self.max_value)
"""
Draws and returns a sample from the specified interval
"""
return self.random_state.uniform(self.min_value, self.max_value)
class MultiRangeUniformSampler(Sampler):

it proceeds to pick a value uniformly in that range.
"""
def __init__(self, intervals: List[List[Union[int, float]]], **kwargs) -> None:
def __init__(
self,
intervals: List[List[Union[int, float]]],
seed: Optional[int] = None,
**kwargs
) -> None:
"""
:param intervals: List of intervals to draw uniform samples from
:param seed: Random seed used for making uniform draws from the specified intervals
"""
self.intervals = intervals
# Measure the length of the intervals
interval_lengths = [abs(x[1] - x[0]) for x in self.intervals]

# Draw from random state to allow for consistent reset parameter draw for a seed
self.random_state = np.random.RandomState(seed)
"""
Selects an interval to pick and then draws a uniform sample from the picked interval
"""
np.random.choice(len(self.intervals), p=self.interval_weights)
self.random_state.choice(len(self.intervals), p=self.interval_weights)
return np.random.uniform(cur_min, cur_max)
return self.random_state.uniform(cur_min, cur_max)
class GaussianSampler(Sampler):

"""
def __init__(
self, mean: Union[float, int], st_dev: Union[float, int], **kwargs
self,
mean: Union[float, int],
st_dev: Union[float, int],
seed: Optional[int] = None,
**kwargs
"""
:param mean: Specifies the mean of the gaussian distribution to draw from
:param st_dev: Specifies the standard devation of the gaussian distribution to draw from
:param seed: Random seed used for making gaussian draws from the sample
"""
# Draw from random state to allow for consistent reset parameter draw for a seed
self.random_state = np.random.RandomState(seed)
return np.random.normal(self.mean, self.st_dev)
"""
Returns a draw from the specified Gaussian distribution
"""
return self.random_state.normal(self.mean, self.st_dev)
class SamplerFactory:

@staticmethod
def register_sampler(name: str, sampler_cls: Type[Sampler]) -> None:
"""
Registers the sampe in the Sampler Factory to be used later
:param name: String name to set as key for the sampler_cls in the factory
:param sampler_cls: Sampler object to associate to the name in the factory
"""
def init_sampler_class(name: str, params: Dict[str, Any]):
def init_sampler_class(
name: str, params: Dict[str, Any], seed: Optional[int] = None
) -> Sampler:
"""
Initializes the sampler class associated with the name with the params
:param name: Name of the sampler in the factory to initialize
:param params: Parameters associated to the sampler attached to the name
:param seed: Random seed to be used to set deterministic random draws for the sampler
"""
if name not in SamplerFactory.NAME_TO_CLASS:
raise SamplerException(
name + " sampler is not registered in the SamplerFactory."

sampler_cls = SamplerFactory.NAME_TO_CLASS[name]
params["seed"] = seed
try:
return sampler_cls(**params)
except TypeError:

class SamplerManager:
def __init__(self, reset_param_dict: Dict[str, Any]) -> None:
def __init__(
self, reset_param_dict: Dict[str, Any], seed: Optional[int] = None
) -> None:
"""
:param reset_param_dict: Arguments needed for initializing the samplers
:param seed: Random seed to be used for drawing samples from the samplers
"""
self.reset_param_dict = reset_param_dict if reset_param_dict else {}
assert isinstance(self.reset_param_dict, dict)
self.samplers: Dict[str, Sampler] = {}

)
sampler_name = cur_param_dict.pop("sampler-type")
param_sampler = SamplerFactory.init_sampler_class(
sampler_name, cur_param_dict
sampler_name, cur_param_dict, seed
)
self.samplers[param_name] = param_sampler

return not bool(self.samplers)
def sample_all(self) -> Dict[str, float]:
"""
Loop over all samplers and draw a sample from each one for generating
next set of reset parameter values.
"""
res = {}
for param_name, param_sampler in list(self.samplers.items()):
res[param_name] = param_sampler.sample_parameter()

6
ml-agents/mlagents/trainers/learn.py


env = SubprocessEnvManager(env_factory, num_envs)
maybe_meta_curriculum = try_create_meta_curriculum(curriculum_folder, env)
sampler_manager, resampling_interval = create_sampler_manager(
sampler_file_path, env.reset_parameters
sampler_file_path, env.reset_parameters, run_seed
)
# Create controller and begin training.

tc.start_learning(env, trainer_config)
def create_sampler_manager(sampler_file_path, env_reset_params):
def create_sampler_manager(sampler_file_path, env_reset_params, run_seed=None):
sampler_config = None
resample_interval = None
if sampler_file_path is not None:

"Resampling interval was not specified in the sampler file."
" Please specify it with the 'resampling-interval' key in the sampler config file."
)
sampler_manager = SamplerManager(sampler_config)
sampler_manager = SamplerManager(sampler_config, run_seed)
return sampler_manager, resample_interval

正在加载...
取消
保存