您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
143 行
5.2 KiB
143 行
5.2 KiB
from unittest.mock import MagicMock
|
|
import pytest
|
|
import mlagents.trainers.tests.mock_brain as mb
|
|
|
|
import os
|
|
|
|
from mlagents.trainers.policy.torch_policy import TorchPolicy
|
|
from mlagents.trainers.torch.components.bc.module import BCModule
|
|
from mlagents.trainers.settings import (
|
|
TrainerSettings,
|
|
BehavioralCloningSettings,
|
|
NetworkSettings,
|
|
)
|
|
|
|
|
|
def create_bc_module(mock_behavior_specs, bc_settings, use_rnn, tanhresample):
    """Build a BCModule around a freshly constructed TorchPolicy for tests.

    Args:
        mock_behavior_specs: Behavior specs handed to ``TorchPolicy``.
        bc_settings: Passed to ``BCModule`` as its ``settings``.
        use_rnn: When truthy, the network's ``memory`` setting is a default
            ``NetworkSettings.MemorySettings()``; otherwise it is ``None``.
        tanhresample: Forwarded as both the fourth and fifth positional
            arguments of ``TorchPolicy``.
            NOTE(review): presumably the tanh-squash and reparameterize
            flags -- confirm against the ``TorchPolicy`` signature.

    Returns:
        A ``BCModule`` using the default ``TrainerSettings`` learning rate
        and batch size, with ``default_num_epoch`` fixed at 3.
    """
    trainer_config = TrainerSettings()

    # Recurrent memory is opt-in; None leaves the network without memory settings.
    if use_rnn:
        memory_settings = NetworkSettings.MemorySettings()
    else:
        memory_settings = None
    trainer_config.network_settings.memory = memory_settings

    # The leading 0 is the first positional argument of TorchPolicy.
    policy = TorchPolicy(
        0, mock_behavior_specs, trainer_config, tanhresample, tanhresample
    )

    hyperparameters = trainer_config.hyperparameters
    return BCModule(
        policy,
        settings=bc_settings,
        policy_learning_rate=hyperparameters.learning_rate,
        default_batch_size=hyperparameters.batch_size,
        default_num_epoch=3,
    )
|
|
|
|
|
|
def assert_stats_are_float(stats):
    """Assert that every value in the ``stats`` mapping is a ``float``.

    Args:
        stats: Mapping whose values are checked; keys are ignored.

    Raises:
        AssertionError: On the first value that is not a ``float`` instance.
    """
    for stat_value in stats.values():
        assert isinstance(stat_value, float)
|
|
|
|
|
|
# Test default values
|
|
def test_bcmodule_defaults():
    """Check BCModule's fallback values and that explicit settings override them."""
    behavior_specs = mb.create_mock_3dball_behavior_specs()
    demo_file = os.path.dirname(os.path.abspath(__file__)) + "/" + "test.demo"

    # Only demo_path is given: num_epoch should be the default_num_epoch (3)
    # supplied by create_bc_module, and batch_size the TrainerSettings default.
    default_settings = BehavioralCloningSettings(demo_path=demo_file)
    default_module = create_bc_module(behavior_specs, default_settings, False, False)
    assert default_module.num_epoch == 3
    assert default_module.batch_size == TrainerSettings().hyperparameters.batch_size

    # Unusual explicit values must take precedence over the defaults.
    override_settings = BehavioralCloningSettings(
        demo_path=demo_file,
        num_epoch=100,
        batch_size=10000,
    )
    override_module = create_bc_module(
        behavior_specs, override_settings, False, False
    )
    assert override_module.num_epoch == 100
    assert override_module.batch_size == 10000
|
|
|
|
|
|
# Test with continuous control env and vector actions
|
|
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_update(is_sac):
    """Run one BC update on the mock 3DBall specs and check the stats are floats.

    ``is_sac`` is forwarded as the ``tanhresample`` argument of
    ``create_bc_module``; no RNN is used.
    """
    behavior_specs = mb.create_mock_3dball_behavior_specs()
    demo_dir = os.path.dirname(os.path.abspath(__file__))
    settings = BehavioralCloningSettings(demo_path=demo_dir + "/" + "test.demo")
    module = create_bc_module(behavior_specs, settings, False, is_sac)

    update_stats = module.update()
    assert_stats_are_float(update_stats)
|
|
|
|
|
|
# Test with constant pretraining learning rate
|
|
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_constant_lr_update(is_sac):
    """With ``steps=0`` the BC learning rate must be unchanged by a further update.

    ``is_sac`` is forwarded as the ``tanhresample`` argument of
    ``create_bc_module``; no RNN is used.
    """
    behavior_specs = mb.create_mock_3dball_behavior_specs()
    demo_dir = os.path.dirname(os.path.abspath(__file__))
    settings = BehavioralCloningSettings(
        demo_path=demo_dir + "/" + "test.demo",
        steps=0,
    )
    module = create_bc_module(behavior_specs, settings, False, is_sac)

    # First update: only verify that the reported stats are floats.
    first_stats = module.update()
    assert_stats_are_float(first_stats)

    # Second update: the learning rate recorded after the first one must hold.
    lr_after_first_update = module.current_lr
    module.update()
    assert lr_after_first_update == module.current_lr
|
|
|
|
|
|
# Test with linearly decaying pretraining learning rate
|
|
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_linear_lr_update(is_sac):
    """With ``steps=100`` the BC learning rate must decay linearly with the step.

    The policy's ``get_current_step`` is mocked to report step 10 of 100, so
    after one update the learning rate should have dropped by 10% of its
    starting value (0.00003 for a 0.0003 starting rate).

    ``is_sac`` is forwarded as the ``tanhresample`` argument of
    ``create_bc_module``; no RNN is used.
    """
    anneal_steps = 100
    current_step = 10

    mock_specs = mb.create_mock_3dball_behavior_specs()
    bc_settings = BehavioralCloningSettings(
        demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "test.demo",
        steps=anneal_steps,
    )
    bc_module = create_bc_module(mock_specs, bc_settings, False, is_sac)

    # Pin the step the module sees so the expected decay is deterministic.
    bc_module.policy.get_current_step = MagicMock(return_value=current_step)
    old_learning_rate = bc_module.current_lr
    _ = bc_module.update()

    expected_learning_rate = old_learning_rate * (1 - current_step / anneal_steps)
    # The previous tolerance (abs=0.01) was far larger than the learning rate
    # itself, so the check passed even with no decay at all. A relative
    # tolerance is tight enough to detect a missing decay while still leaving
    # room for float error (and a tiny lower bound on the schedule, if the
    # module applies one -- TODO confirm).
    assert bc_module.current_lr == pytest.approx(expected_learning_rate, rel=1e-3)
|
|
|
|
|
|
# Test with RNN
|
|
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_rnn_update(is_sac):
    """Run one BC update with an RNN policy on the mock 3DBall specs.

    ``is_sac`` is forwarded as the ``tanhresample`` argument of
    ``create_bc_module``; ``use_rnn`` is True.
    """
    behavior_specs = mb.create_mock_3dball_behavior_specs()
    demo_dir = os.path.dirname(os.path.abspath(__file__))
    settings = BehavioralCloningSettings(demo_path=demo_dir + "/" + "test.demo")
    module = create_bc_module(behavior_specs, settings, True, is_sac)

    update_stats = module.update()
    assert_stats_are_float(update_stats)
|
|
|
|
|
|
# Test with discrete control and visual observations
|
|
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_dc_visual_update(is_sac):
    """Run one BC update on the mock banana specs using ``testdcvis.demo``.

    ``is_sac`` is forwarded as the ``tanhresample`` argument of
    ``create_bc_module``; no RNN is used.
    """
    behavior_specs = mb.create_mock_banana_behavior_specs()
    demo_dir = os.path.dirname(os.path.abspath(__file__))
    settings = BehavioralCloningSettings(
        demo_path=demo_dir + "/" + "testdcvis.demo"
    )
    module = create_bc_module(behavior_specs, settings, False, is_sac)

    update_stats = module.update()
    assert_stats_are_float(update_stats)
|
|
|
|
|
|
# Test with discrete control, visual observations and RNN
|
|
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_rnn_dc_update(is_sac):
    """Run one BC update with an RNN policy on the mock banana specs.

    Uses ``testdcvis.demo`` as the demonstration file. ``is_sac`` is
    forwarded as the ``tanhresample`` argument of ``create_bc_module``;
    ``use_rnn`` is True.
    """
    behavior_specs = mb.create_mock_banana_behavior_specs()
    demo_dir = os.path.dirname(os.path.abspath(__file__))
    settings = BehavioralCloningSettings(
        demo_path=demo_dir + "/" + "testdcvis.demo"
    )
    module = create_bc_module(behavior_specs, settings, True, is_sac)

    update_stats = module.update()
    assert_stats_are_float(update_stats)
|
|
|
|
|
|
# Allow running this test module directly as a script: hand control to pytest.
if __name__ == "__main__":
    pytest.main()
|