|
|
|
|
|
|
|
|
|
|
import attr |
|
|
|
import cattr |
|
|
|
from typing import Dict, Optional, List, Any, DefaultDict, Mapping, Tuple, Union |
|
|
|
from typing import ( |
|
|
|
Dict, |
|
|
|
Optional, |
|
|
|
List, |
|
|
|
Any, |
|
|
|
DefaultDict, |
|
|
|
Mapping, |
|
|
|
Tuple, |
|
|
|
Union, |
|
|
|
ClassVar, |
|
|
|
) |
|
|
|
from enum import Enum |
|
|
|
import collections |
|
|
|
import argparse |
|
|
|
|
|
|
import copy |
|
|
|
|
|
|
|
from mlagents.trainers.cli_utils import StoreConfigFile, DetectDefault, parser |
|
|
|
from mlagents.trainers.cli_utils import load_config |
|
|
|
|
|
|
|
|
|
|
def defaultdict_to_dict(d: DefaultDict) -> Dict: |
|
|
|
return {key: cattr.unstructure(val) for key, val in d.items()} |
|
|
|
|
|
|
|
|
|
|
|
def deep_update_dict(d: Dict, update_d: Mapping) -> None: |
|
|
|
""" |
|
|
|
Similar to dict.update(), but works for nested dicts of dicts as well. |
|
|
|
""" |
|
|
|
for key, val in update_d.items(): |
|
|
|
if key in d and isinstance(d[key], Mapping) and isinstance(val, Mapping): |
|
|
|
deep_update_dict(d[key], val) |
|
|
|
else: |
|
|
|
d[key] = val |
|
|
|
|
|
|
|
|
|
|
|
class SerializationSettings: |
|
|
|
|
|
|
|
|
|
|
@attr.s(auto_attribs=True) |
|
|
|
class TrainerSettings(ExportableSettings): |
|
|
|
default_override: ClassVar[Optional["TrainerSettings"]] = None |
|
|
|
trainer_type: TrainerType = TrainerType.PPO |
|
|
|
hyperparameters: HyperparamSettings = attr.ib() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def dict_to_defaultdict(d: Dict, t: type) -> DefaultDict: |
|
|
|
return collections.defaultdict( |
|
|
|
TrainerSettings, cattr.structure(d, Dict[str, TrainerSettings]) |
|
|
|
return TrainerSettings.DefaultTrainerDict( |
|
|
|
cattr.structure(d, Dict[str, TrainerSettings]) |
|
|
|
) |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
|
|
cattr.register_structure_hook() and called with cattr.structure(). |
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
d_copy.update(d) |
|
|
|
|
|
|
|
# Check if a default_settings was specified. If so, used those as the default |
|
|
|
# rather than an empty dict. |
|
|
|
if TrainerSettings.default_override is not None: |
|
|
|
d_copy.update(cattr.unstructure(TrainerSettings.default_override)) |
|
|
|
|
|
|
|
deep_update_dict(d_copy, d) |
|
|
|
|
|
|
|
for key, val in d_copy.items(): |
|
|
|
if attr.has(type(val)): |
|
|
|
|
|
|
d_copy[key] = check_and_structure(key, val, t) |
|
|
|
return t(**d_copy) |
|
|
|
|
|
|
|
class DefaultTrainerDict(collections.defaultdict): |
|
|
|
def __init__(self, *args): |
|
|
|
super().__init__(TrainerSettings, *args) |
|
|
|
|
|
|
|
def __missing__(self, key: Any) -> "TrainerSettings": |
|
|
|
if TrainerSettings.default_override is not None: |
|
|
|
return copy.deepcopy(TrainerSettings.default_override) |
|
|
|
else: |
|
|
|
return TrainerSettings() |
|
|
|
|
|
|
|
|
|
|
|
# COMMAND LINE ######################################################################### |
|
|
|
@attr.s(auto_attribs=True) |
|
|
|
|
|
|
|
|
|
|
@attr.s(auto_attribs=True) |
|
|
|
class RunOptions(ExportableSettings): |
|
|
|
default_settings: Optional[TrainerSettings] = None |
|
|
|
factory=lambda: collections.defaultdict(TrainerSettings) |
|
|
|
factory=TrainerSettings.DefaultTrainerDict |
|
|
|
) |
|
|
|
env_settings: EnvironmentSettings = attr.ib(factory=EnvironmentSettings) |
|
|
|
engine_settings: EngineSettings = attr.ib(factory=EngineSettings) |
|
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def from_dict(options_dict: Dict[str, Any]) -> "RunOptions": |
|
|
|
# If a default settings was specified, set the TrainerSettings class override |
|
|
|
if ( |
|
|
|
"default_settings" in options_dict.keys() |
|
|
|
and options_dict["default_settings"] is not None |
|
|
|
): |
|
|
|
TrainerSettings.default_override = cattr.structure( |
|
|
|
options_dict["default_settings"], TrainerSettings |
|
|
|
) |
|
|
|
return cattr.structure(options_dict, RunOptions) |