浏览代码

Removed check_key and replaced with **param_dict for implicit type checks

/develop-generalizationTraining-TrainerController
sankalp04 5 年前
当前提交
7f96b47c
共有 1 个文件被更改,包括 45 次插入0 次删除
  1. 45
      ml-agents-envs/mlagents/envs/sampler_class.py

45
ml-agents-envs/mlagents/envs/sampler_class.py


import numpy as np
<<<<<<< HEAD
=======
from functools import *
>>>>>>> Removed check_key and replaced with **param_dict for implicit type checks
<<<<<<< HEAD
=======
class SamplerException(Exception):
pass
class Sampler(ABC):
>>>>>>> Removed check_key and replaced with **param_dict for implicit type checks
class Sampler(ABC):
@abstractmethod

class UniformSampler(Sampler):
<<<<<<< HEAD
"""
Uniformly draws a single sample in the range [min_value, max_value).
"""

self.max_value = max_value
def sample_parameter(self) -> float:
=======
# kwargs acts as a sink for extra unneeded args
def __init__(self, min_value, max_value, **kwargs):
self.min_value = min_value
self.max_value = max_value
def sample_parameter(self):
>>>>>>> Removed check_key and replaced with **param_dict for implicit type checks
<<<<<<< HEAD
class MultiRangeUniformSampler(Sampler):
"""

cur_min, cur_max = self.intervals[
np.random.choice(len(self.intervals), p=self.interval_weights)
]
=======
class MultiRangeUniformSampler(Sampler):
def __init__(self, intervals, **kwargs):
self.intervals = intervals
# Measure the length of the intervals
self.interval_lengths = list(map(lambda x: abs(x[1] - x[0]), self.intervals))
# Cumulative size of the intervals
self.cum_interval_length = reduce(lambda x,y: x + y, self.interval_lengths, 0)
# Assign weights to an interval proportionate to the interval size
self.interval_weights = list(map(lambda x: x/self.cum_interval_length, self.interval_lengths))
def sample_parameter(self):
cur_min, cur_max = self.intervals[np.random.choice(len(self.intervals), p=self.interval_weights)]
>>>>>>> Removed check_key and replaced with **param_dict for implicit type checks
<<<<<<< HEAD
"""
Draw a single sample value from a normal (gaussian) distribution.
This sampler is characterized by the mean and the standard deviation.

def sample_parameter(self) -> float:
return np.random.normal(self.mean, self.st_dev)
=======
def __init__(self, mean, var, **kwargs):
self.mean = mean
self.var = var
def sample_parameter(self):
return np.random.normal(self.mean, self.var)
>>>>>>> Removed check_key and replaced with **param_dict for implicit type checks
class SamplerFactory:

正在加载...
取消
保存