|
|
|
|
|
|
import attr |
|
|
|
import pytest |
|
|
|
import yaml |
|
|
|
from typing import Dict |
|
|
|
from typing import Dict, List, Optional |
|
|
|
|
|
|
|
from mlagents.trainers.settings import ( |
|
|
|
RunOptions, |
|
|
|
|
|
|
if isinstance(val, dict) or isinstance(val, list) or attr.has(val): |
|
|
|
# Note: this check doesn't check the contents of mutables. |
|
|
|
check_if_different(val, attr.asdict(testobj2, recurse=False)[key]) |
|
|
|
|
|
|
|
|
|
|
|
def check_dict_is_at_least( |
|
|
|
testdict1: Dict, testdict2: Dict, exceptions: Optional[List[str]] = None |
|
|
|
) -> None: |
|
|
|
""" |
|
|
|
Check if everything present in the 1st dict is the same in the second dict. |
|
|
|
Excludes things that the second dict has but is not present in the heirarchy of the |
|
|
|
1st dict. Used to compare an underspecified config dict structure (e.g. as |
|
|
|
would be provided by a user) with a complete one (e.g. as exported by RunOptions). |
|
|
|
""" |
|
|
|
for key, val in testdict1.items(): |
|
|
|
if exceptions is not None and key in exceptions: |
|
|
|
continue |
|
|
|
assert key in testdict2 |
|
|
|
if isinstance(val, dict): |
|
|
|
check_dict_is_at_least(val, testdict2[key]) |
|
|
|
elif isinstance(val, list): |
|
|
|
assert isinstance(testdict2[key], list) |
|
|
|
for _el0, _el1 in zip(val, testdict2[key]): |
|
|
|
if isinstance(_el0, dict): |
|
|
|
check_dict_is_at_least(_el0, _el1) |
|
|
|
else: |
|
|
|
assert val == testdict2[key] |
|
|
|
else: # If not a dict, don't recurse into it |
|
|
|
assert val == testdict2[key] |
|
|
|
|
|
|
|
|
|
|
|
def test_is_new_instance(): |
|
|
|
|
|
|
ParameterRandomizationSettings.structure( |
|
|
|
"notadict", Dict[str, ParameterRandomizationSettings] |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("use_defaults", [True, False]) |
|
|
|
def test_exportable_settings(use_defaults): |
|
|
|
""" |
|
|
|
Test that structuring and unstructuring a RunOptions object results in the same |
|
|
|
configuration representation. |
|
|
|
""" |
|
|
|
# Try to enable as many features as possible in this test YAML to hit all the |
|
|
|
# edge cases. Set as much as possible as non-default values to ensure no flukes. |
|
|
|
test_yaml = """ |
|
|
|
behaviors: |
|
|
|
3DBall: |
|
|
|
trainer_type: sac |
|
|
|
hyperparameters: |
|
|
|
learning_rate: 0.0004 |
|
|
|
learning_rate_schedule: constant |
|
|
|
batch_size: 64 |
|
|
|
buffer_size: 200000 |
|
|
|
buffer_init_steps: 100 |
|
|
|
tau: 0.006 |
|
|
|
steps_per_update: 10.0 |
|
|
|
save_replay_buffer: true |
|
|
|
init_entcoef: 0.5 |
|
|
|
reward_signal_steps_per_update: 10.0 |
|
|
|
network_settings: |
|
|
|
normalize: false |
|
|
|
hidden_units: 256 |
|
|
|
num_layers: 3 |
|
|
|
vis_encode_type: nature_cnn |
|
|
|
memory: |
|
|
|
memory_size: 1288 |
|
|
|
sequence_length: 12 |
|
|
|
reward_signals: |
|
|
|
extrinsic: |
|
|
|
gamma: 0.999 |
|
|
|
strength: 1.0 |
|
|
|
curiosity: |
|
|
|
gamma: 0.999 |
|
|
|
strength: 1.0 |
|
|
|
keep_checkpoints: 5 |
|
|
|
max_steps: 500000 |
|
|
|
time_horizon: 1000 |
|
|
|
summary_freq: 12000 |
|
|
|
checkpoint_interval: 1 |
|
|
|
threaded: true |
|
|
|
env_settings: |
|
|
|
env_path: test_env_path |
|
|
|
env_args: |
|
|
|
- test_env_args1 |
|
|
|
- test_env_args2 |
|
|
|
base_port: 12345 |
|
|
|
num_envs: 8 |
|
|
|
seed: 12345 |
|
|
|
engine_settings: |
|
|
|
width: 12345 |
|
|
|
height: 12345 |
|
|
|
quality_level: 12345 |
|
|
|
time_scale: 12345 |
|
|
|
target_frame_rate: 12345 |
|
|
|
capture_frame_rate: 12345 |
|
|
|
no_graphics: true |
|
|
|
checkpoint_settings: |
|
|
|
run_id: test_run_id |
|
|
|
initialize_from: test_directory |
|
|
|
load_model: false |
|
|
|
resume: true |
|
|
|
force: true |
|
|
|
train_model: false |
|
|
|
inference: false |
|
|
|
debug: true |
|
|
|
""" |
|
|
|
if not use_defaults: |
|
|
|
loaded_yaml = yaml.safe_load(test_yaml) |
|
|
|
run_options = RunOptions.from_dict(yaml.safe_load(test_yaml)) |
|
|
|
else: |
|
|
|
run_options = RunOptions() |
|
|
|
dict_export = run_options.as_dict() |
|
|
|
|
|
|
|
if not use_defaults: # Don't need to check if no yaml |
|
|
|
check_dict_is_at_least(loaded_yaml, dict_export) |
|
|
|
|
|
|
|
# Re-import and verify has same elements |
|
|
|
run_options2 = RunOptions.from_dict(dict_export) |
|
|
|
second_export = run_options2.as_dict() |
|
|
|
|
|
|
|
# Check that the two exports are the same |
|
|
|
assert dict_export == second_export |