|
|
|
|
|
|
import os |
|
|
|
|
|
|
|
from mlagents.trainers.common.nn_policy import NNPolicy |
|
|
|
from mlagents.trainers.sac.policy import SACPolicy |
|
|
|
from mlagents.trainers.components.bc.module import BCModule |
|
|
|
|
|
|
|
|
|
|
|
def ppo_dummy_config(): |
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def sac_dummy_config(): |
|
|
|
return yaml.safe_load( |
|
|
|
""" |
|
|
|
trainer: sac |
|
|
|
batch_size: 128 |
|
|
|
buffer_size: 50000 |
|
|
|
buffer_init_steps: 0 |
|
|
|
hidden_units: 128 |
|
|
|
init_entcoef: 1.0 |
|
|
|
learning_rate: 3.0e-4 |
|
|
|
max_steps: 5.0e4 |
|
|
|
memory_size: 256 |
|
|
|
normalize: false |
|
|
|
num_update: 1 |
|
|
|
train_interval: 1 |
|
|
|
num_layers: 2 |
|
|
|
time_horizon: 64 |
|
|
|
sequence_length: 64 |
|
|
|
summary_freq: 1000 |
|
|
|
tau: 0.005 |
|
|
|
use_recurrent: false |
|
|
|
vis_encode_type: simple |
|
|
|
behavioral_cloning: |
|
|
|
demo_path: ./Project/Assets/ML-Agents/Examples/Pyramids/Demos/ExpertPyramid.demo |
|
|
|
strength: 1.0 |
|
|
|
steps: 10000000 |
|
|
|
reward_signals: |
|
|
|
extrinsic: |
|
|
|
strength: 1.0 |
|
|
|
gamma: 0.99 |
|
|
|
""" |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def create_policy_with_bc_mock(mock_brain, trainer_config, use_rnn, demo_file): |
|
|
|
def create_bc_module(mock_brain, trainer_config, use_rnn, demo_file, tanhresample): |
|
|
|
# model_path = env.external_brain_names[0] |
|
|
|
trainer_config["model_path"] = "testpath" |
|
|
|
trainer_config["keep_checkpoints"] = 3 |
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
policy = ( |
|
|
|
NNPolicy(0, mock_brain, trainer_config, False, False) |
|
|
|
if trainer_config["trainer"] == "ppo" |
|
|
|
else SACPolicy(0, mock_brain, trainer_config, False, False) |
|
|
|
policy = NNPolicy( |
|
|
|
0, mock_brain, trainer_config, False, False, tanhresample, tanhresample |
|
|
|
return policy |
|
|
|
with policy.graph.as_default(): |
|
|
|
bc_module = BCModule( |
|
|
|
policy, |
|
|
|
policy_learning_rate=trainer_config["learning_rate"], |
|
|
|
default_batch_size=trainer_config["batch_size"], |
|
|
|
default_num_epoch=3, |
|
|
|
**trainer_config["behavioral_cloning"], |
|
|
|
) |
|
|
|
policy.initialize_or_load() |
|
|
|
return bc_module |
|
|
|
|
|
|
|
|
|
|
|
# Test default values |
|
|
|
|
|
|
trainer_config = ppo_dummy_config() |
|
|
|
policy = create_policy_with_bc_mock(mock_brain, trainer_config, False, "test.demo") |
|
|
|
assert policy.bc_module.num_epoch == 3 |
|
|
|
assert policy.bc_module.batch_size == trainer_config["batch_size"] |
|
|
|
bc_module = create_bc_module(mock_brain, trainer_config, False, "test.demo", False) |
|
|
|
assert bc_module.num_epoch == 3 |
|
|
|
assert bc_module.batch_size == trainer_config["batch_size"] |
|
|
|
policy = create_policy_with_bc_mock(mock_brain, trainer_config, False, "test.demo") |
|
|
|
assert policy.bc_module.num_epoch == 100 |
|
|
|
assert policy.bc_module.batch_size == 10000 |
|
|
|
bc_module = create_bc_module(mock_brain, trainer_config, False, "test.demo", False) |
|
|
|
assert bc_module.num_epoch == 100 |
|
|
|
assert bc_module.batch_size == 10000 |
|
|
|
@pytest.mark.parametrize( |
|
|
|
"trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"] |
|
|
|
) |
|
|
|
def test_bcmodule_update(trainer_config): |
|
|
|
@pytest.mark.parametrize("is_sac", [True, False], ids=["ppo", "sac"]) |
|
|
|
def test_bcmodule_update(is_sac): |
|
|
|
policy = create_policy_with_bc_mock(mock_brain, trainer_config, False, "test.demo") |
|
|
|
stats = policy.bc_module.update() |
|
|
|
bc_module = create_bc_module( |
|
|
|
mock_brain, ppo_dummy_config(), False, "test.demo", is_sac |
|
|
|
) |
|
|
|
stats = bc_module.update() |
|
|
|
@pytest.mark.parametrize( |
|
|
|
"trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"] |
|
|
|
) |
|
|
|
def test_bcmodule_constant_lr_update(trainer_config): |
|
|
|
@pytest.mark.parametrize("is_sac", [True, False], ids=["ppo", "sac"]) |
|
|
|
def test_bcmodule_constant_lr_update(is_sac): |
|
|
|
trainer_config = ppo_dummy_config() |
|
|
|
policy = create_policy_with_bc_mock(mock_brain, trainer_config, False, "test.demo") |
|
|
|
stats = policy.bc_module.update() |
|
|
|
bc_module = create_bc_module(mock_brain, trainer_config, False, "test.demo", is_sac) |
|
|
|
stats = bc_module.update() |
|
|
|
old_learning_rate = policy.bc_module.current_lr |
|
|
|
old_learning_rate = bc_module.current_lr |
|
|
|
stats = policy.bc_module.update() |
|
|
|
assert old_learning_rate == policy.bc_module.current_lr |
|
|
|
stats = bc_module.update() |
|
|
|
assert old_learning_rate == bc_module.current_lr |
|
|
|
@pytest.mark.parametrize( |
|
|
|
"trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"] |
|
|
|
) |
|
|
|
def test_bcmodule_rnn_update(trainer_config): |
|
|
|
@pytest.mark.parametrize("is_sac", [True, False], ids=["ppo", "sac"]) |
|
|
|
def test_bcmodule_rnn_update(is_sac): |
|
|
|
policy = create_policy_with_bc_mock(mock_brain, trainer_config, True, "test.demo") |
|
|
|
stats = policy.bc_module.update() |
|
|
|
bc_module = create_bc_module( |
|
|
|
mock_brain, ppo_dummy_config(), True, "test.demo", is_sac |
|
|
|
) |
|
|
|
stats = bc_module.update() |
|
|
|
@pytest.mark.parametrize( |
|
|
|
"trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"] |
|
|
|
) |
|
|
|
def test_bcmodule_dc_visual_update(trainer_config): |
|
|
|
@pytest.mark.parametrize("is_sac", [True, False], ids=["ppo", "sac"]) |
|
|
|
def test_bcmodule_dc_visual_update(is_sac): |
|
|
|
policy = create_policy_with_bc_mock( |
|
|
|
mock_brain, trainer_config, False, "testdcvis.demo" |
|
|
|
bc_module = create_bc_module( |
|
|
|
mock_brain, ppo_dummy_config(), False, "testdcvis.demo", is_sac |
|
|
|
stats = policy.bc_module.update() |
|
|
|
stats = bc_module.update() |
|
|
|
@pytest.mark.parametrize( |
|
|
|
"trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"] |
|
|
|
) |
|
|
|
def test_bcmodule_rnn_dc_update(trainer_config): |
|
|
|
@pytest.mark.parametrize("is_sac", [True, False], ids=["ppo", "sac"]) |
|
|
|
def test_bcmodule_rnn_dc_update(is_sac): |
|
|
|
policy = create_policy_with_bc_mock( |
|
|
|
mock_brain, trainer_config, True, "testdcvis.demo" |
|
|
|
bc_module = create_bc_module( |
|
|
|
mock_brain, ppo_dummy_config(), True, "testdcvis.demo", is_sac |
|
|
|
stats = policy.bc_module.update() |
|
|
|
stats = bc_module.update() |
|
|
|
for _, item in stats.items(): |
|
|
|
assert isinstance(item, np.float32) |
|
|
|
|
|
|
|