浏览代码

error properly when a keyword is not followed by a valid config in yaml

/sampler-refactor-copy
Andrew Cohen 4 年前
当前提交
b790ce76
共有 2 个文件被更改,包括 14 次插入9 次删除
  1. 10
      ml-agents/mlagents/trainers/learn.py
  2. 13
      ml-agents/mlagents/trainers/settings.py

10
ml-agents/mlagents/trainers/learn.py


def maybe_add_samplers(
sampler_config: Optional[Dict], env: SubprocessEnvManager
) -> None:
restructured_sampler_config: Dict[str, List[float]] = {}
if "resampling-interval" in sampler_config:
logger.warning(
"The resampling-interval is no longer necessary for parameter randomization. It is being ignored."
)
sampler_config.pop("resampling-interval")
for param, config in sampler_config.items():
restructured_sampler_config[param] = config
env.reset(config=restructured_sampler_config)
env.reset(config=sampler_config)
def try_create_meta_curriculum(

13
ml-agents/mlagents/trainers/settings.py


from mlagents.trainers.exception import TrainerConfigError
from mlagents.trainers.models import ScheduleType, EncoderType
from mlagents_envs import logging_util
logger = logging_util.get_logger(__name__)
def check_and_structure(key: str, value: Any, class_type: type) -> Any:
attr_fields_dict = attr.fields_dict(class_type)

)
d_final: Dict[str, List[float]] = {}
for param, param_config in d.items():
if param == "resampling-interval":
logger.warning(
"The resampling-interval is no longer necessary for parameter randomization. It is being ignored."
)
continue
if not isinstance(param_config, Mapping):
raise TrainerConfigError(
f"Unsupported distribution configuration {param_config}."
)
for key, val in param_config.items():
enum_key = ParameterRandomizationType(key)
t = enum_key.to_settings()

正在加载...
取消
保存