|
|
|
|
|
|
for key, val in param_config.items(): |
|
|
|
enum_key = ParameterRandomizationType(key) |
|
|
|
t = enum_key.to_settings() |
|
|
|
d_final[param] = strict_to_cls(val, t).to_float() |
|
|
|
d_final[param] = strict_to_cls(val, t).to_float_encoding() |
|
|
|
min_value: float = 1.0 |
|
|
|
min_value: float = attr.ib() |
|
|
|
def to_float(self) -> List[float]: |
|
|
|
@min_value.default |
|
|
|
def _min_value_default(self): |
|
|
|
return 1.0 |
|
|
|
|
|
|
|
@min_value.validator |
|
|
|
def _check_intervals(self, attribute, value): |
|
|
|
|
|
|
|
def to_float_encoding(self) -> List[float]: |
|
|
|
return [0.0, self.min_value, self.max_value] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
st_dev: float = 1.0 |
|
|
|
|
|
|
|
def to_float(self) -> List[float]: |
|
|
|
def to_float_encoding(self) -> List[float]: |
|
|
|
intervals: List[List[float]] = [[1.0, 1.0]] |
|
|
|
intervals: List[List[float]] = attr.ib() |
|
|
|
def to_float(self) -> List[float]: |
|
|
|
floats: List[float] = [] |
|
|
|
@intervals.default |
|
|
|
def _intervals_default(self): |
|
|
|
return [[1.0, 1.0]] |
|
|
|
|
|
|
|
@intervals.validator |
|
|
|
def _check_intervals(self, attribute, value): |
|
|
|
for interval in self.intervals: |
|
|
|
if len(interval) != 2: |
|
|
|
raise TrainerConfigError( |
|
|
|
|
|
|
raise TrainerConfigError( |
|
|
|
f"Minimum value is greater than maximum value in interval {interval}." |
|
|
|
) |
|
|
|
|
|
|
|
def to_float_encoding(self) -> List[float]: |
|
|
|
floats: List[float] = [] |
|
|
|
for interval in self.intervals: |
|
|
|
floats += interval |
|
|
|
return [2.0] + floats |
|
|
|
|
|
|
|