import pytest

from mlagents.trainers.tests.simple_test_envs import (
    Simple1DEnvironment,
    Memory1DEnvironment,
    Record1DEnvironment,
)
from mlagents.trainers.demo_loader import write_demo
from mlagents_envs.communicator_objects.demonstration_meta_pb2 import (
    DemonstrationMetaProto,
)
from mlagents_envs.communicator_objects.brain_parameters_pb2 import BrainParametersProto
from mlagents_envs.communicator_objects.space_type_pb2 import discrete, continuous
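
# NOTE (assumption): PPO_CONFIG, SAC_CONFIG, generate_config, and
# _check_environment_trains are used below but not imported here; they are
# assumed to be defined earlier in this test module.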
BRAIN_NAME = "1D" |
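

# Expect a mix of results: at least one processed reward above the success
# threshold and at least one below it.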
assert any(reward > success_threshold for reward in processed_rewards) and any(
    reward < success_threshold for reward in processed_rewards
)
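

# Session-scoped fixture: records demonstrations by executing the environment's
# optimal policy and writes them to a .demo file for the GAIL/BC tests below.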
@pytest.fixture(scope="session")
def simple_record(tmpdir_factory):
    def record_demo(use_discrete, num_visual=0, num_vector=1):
        env = Record1DEnvironment(
            [BRAIN_NAME],
            use_discrete=use_discrete,
            num_visual=num_visual,
            num_vector=num_vector,
            n_demos=100,
        )
        # If we wanted true demos, we could solve the env in the usual way by
        # training an agent; here, calling solve() just executes the optimal
        # policy directly.
        env.solve()
        agent_info_protos = env.demonstration_protos[BRAIN_NAME]
        meta_data_proto = DemonstrationMetaProto()
        brain_param_proto = BrainParametersProto(
            vector_action_size=[1],
            vector_action_descriptions=[""],
            vector_action_space_type=discrete if use_discrete else continuous,
            brain_name=BRAIN_NAME,
            is_training=True,
        )
        action_type = "Discrete" if use_discrete else "Continuous"
        demo_path_name = "1DTest" + action_type + ".demo"
        demo_path = str(tmpdir_factory.mktemp("tmp_demo").join(demo_path_name))
        write_demo(demo_path, meta_data_proto, brain_param_proto, agent_info_protos)
        return demo_path

    return record_demo
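

# Trains on the 1D environment with a GAIL reward signal plus behavioral-cloning
# pretraining from the recorded demos, for both PPO and SAC.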
@pytest.mark.parametrize("use_discrete", [True, False])
@pytest.mark.parametrize("trainer_config", [PPO_CONFIG, SAC_CONFIG])
def test_gail(simple_record, use_discrete, trainer_config):
    demo_path = simple_record(use_discrete)
    env = Simple1DEnvironment([BRAIN_NAME], use_discrete=use_discrete, step_size=0.2)
    override_vals = {
        "max_steps": 500,
        "behavioral_cloning": {"demo_path": demo_path, "strength": 1.0, "steps": 1000},
        "reward_signals": {
            "gail": {
                "strength": 1.0,
                "gamma": 0.99,
                "encoding_size": 32,
                "demo_path": demo_path,
            }
        },
    }
    config = generate_config(trainer_config, override_vals)
    _check_environment_trains(env, config, success_threshold=0.9)
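

# Same GAIL + behavioral-cloning setup, but with a single visual observation
# (no vector observations), trained with PPO.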
@pytest.mark.parametrize("use_discrete", [True, False])
def test_gail_visual_ppo(simple_record, use_discrete):
    demo_path = simple_record(use_discrete, num_visual=1, num_vector=0)
    env = Simple1DEnvironment(
        [BRAIN_NAME],
        num_visual=1,
        num_vector=0,
        use_discrete=use_discrete,
        step_size=0.2,
    )
    override_vals = {
        "max_steps": 1000,
        "learning_rate": 3.0e-4,
        "behavioral_cloning": {"demo_path": demo_path, "strength": 1.0, "steps": 1000},
        "reward_signals": {
            "gail": {
                "strength": 1.0,
                "gamma": 0.99,
                "encoding_size": 32,
                "demo_path": demo_path,
            }
        },
    }
    config = generate_config(PPO_CONFIG, override_vals)
    _check_environment_trains(env, config, success_threshold=0.9)
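

# The visual-observation variant trained with SAC; uses a smaller batch size
# and an explicit learning rate.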
@pytest.mark.parametrize("use_discrete", [True, False])
def test_gail_visual_sac(simple_record, use_discrete):
    demo_path = simple_record(use_discrete, num_visual=1, num_vector=0)
    env = Simple1DEnvironment(
        [BRAIN_NAME],
        num_visual=1,
        num_vector=0,
        use_discrete=use_discrete,
        step_size=0.2,
    )
    override_vals = {
        "max_steps": 500,
        "batch_size": 16,
        "learning_rate": 3.0e-4,
        "behavioral_cloning": {"demo_path": demo_path, "strength": 1.0, "steps": 1000},
        "reward_signals": {
            "gail": {
                "strength": 1.0,
                "gamma": 0.99,
                "encoding_size": 32,
                "demo_path": demo_path,
            }
        },
    }
    config = generate_config(SAC_CONFIG, override_vals)
    _check_environment_trains(env, config, success_threshold=0.9)