浏览代码

[bug-fix] Initialize-from being incorrectly loaded as "None" rather than None (#4175)

/MLA-1734-demo-provider
GitHub 4 年前
当前提交
3de1e660
共有 4 个文件被更改,包括 120 次插入3 次删除
  1. 2
      com.unity.ml-agents/CHANGELOG.md
  2. 2
      ml-agents/mlagents/trainers/learn.py
  3. 2
      ml-agents/mlagents/trainers/settings.py
  4. 117
      ml-agents/mlagents/trainers/tests/test_settings.py

2
com.unity.ml-agents/CHANGELOG.md


empty string). (#4155)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- Fixed an error when setting `initialize_from` in the trainer confiiguration YAML to
`null`. (#4175)
## [1.1.0-preview] - 2020-06-10
### Major Changes

2
ml-agents/mlagents/trainers/learn.py


write_path = os.path.join(base_path, checkpoint_settings.run_id)
maybe_init_path = (
os.path.join(base_path, checkpoint_settings.initialize_from)
if checkpoint_settings.initialize_from
if checkpoint_settings.initialize_from is not None
else None
)
run_logs_dir = os.path.join(write_path, "run_logs")

2
ml-agents/mlagents/trainers/settings.py


@attr.s(auto_attribs=True)
class CheckpointSettings:
run_id: str = parser.get_default("run_id")
initialize_from: str = parser.get_default("initialize_from")
initialize_from: Optional[str] = parser.get_default("initialize_from")
load_model: bool = parser.get_default("load_model")
resume: bool = parser.get_default("resume")
force: bool = parser.get_default("force")

117
ml-agents/mlagents/trainers/tests/test_settings.py


import attr
import pytest
import yaml
from typing import Dict
from typing import Dict, List, Optional
from mlagents.trainers.settings import (
RunOptions,

if isinstance(val, dict) or isinstance(val, list) or attr.has(val):
# Note: this check doesn't check the contents of mutables.
check_if_different(val, attr.asdict(testobj2, recurse=False)[key])
def check_dict_is_at_least(
testdict1: Dict, testdict2: Dict, exceptions: Optional[List[str]] = None
) -> None:
"""
Check if everything present in the 1st dict is the same in the second dict.
Excludes things that the second dict has but is not present in the heirarchy of the
1st dict. Used to compare an underspecified config dict structure (e.g. as
would be provided by a user) with a complete one (e.g. as exported by RunOptions).
"""
for key, val in testdict1.items():
if exceptions is not None and key in exceptions:
continue
assert key in testdict2
if isinstance(val, dict):
check_dict_is_at_least(val, testdict2[key])
elif isinstance(val, list):
assert isinstance(testdict2[key], list)
for _el0, _el1 in zip(val, testdict2[key]):
if isinstance(_el0, dict):
check_dict_is_at_least(_el0, _el1)
else:
assert val == testdict2[key]
else: # If not a dict, don't recurse into it
assert val == testdict2[key]
def test_is_new_instance():

ParameterRandomizationSettings.structure(
"notadict", Dict[str, ParameterRandomizationSettings]
)
@pytest.mark.parametrize("use_defaults", [True, False])
def test_exportable_settings(use_defaults):
"""
Test that structuring and unstructuring a RunOptions object results in the same
configuration representation.
"""
# Try to enable as many features as possible in this test YAML to hit all the
# edge cases. Set as much as possible as non-default values to ensure no flukes.
test_yaml = """
behaviors:
3DBall:
trainer_type: sac
hyperparameters:
learning_rate: 0.0004
learning_rate_schedule: constant
batch_size: 64
buffer_size: 200000
buffer_init_steps: 100
tau: 0.006
steps_per_update: 10.0
save_replay_buffer: true
init_entcoef: 0.5
reward_signal_steps_per_update: 10.0
network_settings:
normalize: false
hidden_units: 256
num_layers: 3
vis_encode_type: nature_cnn
memory:
memory_size: 1288
sequence_length: 12
reward_signals:
extrinsic:
gamma: 0.999
strength: 1.0
curiosity:
gamma: 0.999
strength: 1.0
keep_checkpoints: 5
max_steps: 500000
time_horizon: 1000
summary_freq: 12000
checkpoint_interval: 1
threaded: true
env_settings:
env_path: test_env_path
env_args:
- test_env_args1
- test_env_args2
base_port: 12345
num_envs: 8
seed: 12345
engine_settings:
width: 12345
height: 12345
quality_level: 12345
time_scale: 12345
target_frame_rate: 12345
capture_frame_rate: 12345
no_graphics: true
checkpoint_settings:
run_id: test_run_id
initialize_from: test_directory
load_model: false
resume: true
force: true
train_model: false
inference: false
debug: true
"""
if not use_defaults:
loaded_yaml = yaml.safe_load(test_yaml)
run_options = RunOptions.from_dict(yaml.safe_load(test_yaml))
else:
run_options = RunOptions()
dict_export = run_options.as_dict()
if not use_defaults: # Don't need to check if no yaml
check_dict_is_at_least(loaded_yaml, dict_export)
# Re-import and verify has same elements
run_options2 = RunOptions.from_dict(dict_export)
second_export = run_options2.as_dict()
# Check that the two exports are the same
assert dict_export == second_export
正在加载...
取消
保存