|
|
|
|
|
|
class ParameterRandomizationSettings(abc.ABC): |
|
|
|
seed: int = parser.get_default("seed") |
|
|
|
|
|
|
|
def __str__(self) -> str: |
|
|
|
""" |
|
|
|
Helper method to output sampler stats to console. |
|
|
|
""" |
|
|
|
raise TrainerConfigError(f"__str__ not implemented for type {self.__class__}.") |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def structure( |
|
|
|
d: Union[Mapping, float], t: type |
|
|
|
|
|
|
class ConstantSettings(ParameterRandomizationSettings): |
|
|
|
value: float = 0.0 |
|
|
|
|
|
|
|
def __str__(self) -> str: |
|
|
|
""" |
|
|
|
Helper method to output sampler stats to console. |
|
|
|
""" |
|
|
|
return f"Float: value={self.value}" |
|
|
|
|
|
|
|
def apply(self, key: str, env_channel: EnvironmentParametersChannel) -> None: |
|
|
|
""" |
|
|
|
Helper method to send sampler settings over EnvironmentParametersChannel |
|
|
|
|
|
|
min_value: float = attr.ib() |
|
|
|
max_value: float = 1.0 |
|
|
|
|
|
|
|
def __str__(self) -> str: |
|
|
|
""" |
|
|
|
Helper method to output sampler stats to console. |
|
|
|
""" |
|
|
|
return f"Uniform sampler: min={self.min_value}, max={self.max_value}" |
|
|
|
|
|
|
|
@min_value.default |
|
|
|
def _min_value_default(self): |
|
|
|
return 0.0 |
|
|
|
|
|
|
mean: float = 1.0 |
|
|
|
st_dev: float = 1.0 |
|
|
|
|
|
|
|
def __str__(self) -> str: |
|
|
|
""" |
|
|
|
Helper method to output sampler stats to console. |
|
|
|
""" |
|
|
|
return f"Gaussian sampler: mean={self.mean}, stddev={self.st_dev}" |
|
|
|
|
|
|
|
def apply(self, key: str, env_channel: EnvironmentParametersChannel) -> None: |
|
|
|
""" |
|
|
|
Helper method to send sampler settings over EnvironmentParametersChannel |
|
|
|
|
|
|
@attr.s(auto_attribs=True) |
|
|
|
class MultiRangeUniformSettings(ParameterRandomizationSettings): |
|
|
|
intervals: List[Tuple[float, float]] = attr.ib() |
|
|
|
|
|
|
|
def __str__(self) -> str: |
|
|
|
""" |
|
|
|
Helper method to output sampler stats to console. |
|
|
|
""" |
|
|
|
return f"MultiRangeUniform sampler: intervals={self.intervals}" |
|
|
|
|
|
|
|
@intervals.default |
|
|
|
def _intervals_default(self): |
|
|
|