|
|
|
|
|
|
check_environment_trains(env, {BRAIN_NAME: config}) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("action_sizes", [(0, 1), (1, 0)]) |
|
|
|
@pytest.mark.parametrize("num_var_len", [1, 2]) |
|
|
|
@pytest.mark.parametrize("num_visual", [0, 1]) |
|
|
|
def test_var_len_obs_ppo(num_visual, num_var_len, action_sizes): |
|
|
|
env = SimpleEnvironment( |
|
|
|
[BRAIN_NAME], |
|
|
|
action_sizes=action_sizes, |
|
|
|
num_visual=num_visual, |
|
|
|
num_vector=0, |
|
|
|
num_var_len=num_var_len, |
|
|
|
step_size=0.2, |
|
|
|
) |
|
|
|
new_hyperparams = attr.evolve( |
|
|
|
PPO_TORCH_CONFIG.hyperparameters, learning_rate=3.0e-4 |
|
|
|
) |
|
|
|
config = attr.evolve(PPO_TORCH_CONFIG, hyperparameters=new_hyperparams) |
|
|
|
check_environment_trains(env, {BRAIN_NAME: config}) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("num_visual", [1, 2]) |
|
|
|
@pytest.mark.parametrize("vis_encode_type", ["resnet", "nature_cnn", "match3"]) |
|
|
|
def test_visual_advanced_ppo(vis_encode_type, num_visual): |
|
|
|
|
|
|
[BRAIN_NAME], |
|
|
|
action_sizes=action_sizes, |
|
|
|
num_visual=num_visual, |
|
|
|
num_vector=0, |
|
|
|
step_size=0.2, |
|
|
|
) |
|
|
|
new_hyperparams = attr.evolve( |
|
|
|
SAC_TORCH_CONFIG.hyperparameters, batch_size=16, learning_rate=3e-4 |
|
|
|
) |
|
|
|
config = attr.evolve(SAC_TORCH_CONFIG, hyperparameters=new_hyperparams) |
|
|
|
check_environment_trains(env, {BRAIN_NAME: config}) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("action_sizes", [(0, 1), (1, 0)]) |
|
|
|
@pytest.mark.parametrize("num_var_len", [1, 2]) |
|
|
|
def test_var_len_obs_sac(num_var_len, action_sizes): |
|
|
|
env = SimpleEnvironment( |
|
|
|
[BRAIN_NAME], |
|
|
|
action_sizes=action_sizes, |
|
|
|
num_visual=0, |
|
|
|
num_var_len=num_var_len, |
|
|
|
num_vector=0, |
|
|
|
step_size=0.2, |
|
|
|
) |
|
|
|