|
|
|
|
|
|
from mlagents.trainers.tests import mock_brain as mb |
|
|
|
from mlagents.trainers.tests.mock_brain import make_brain_parameters |
|
|
|
from mlagents.trainers.tests.test_trajectory import make_fake_trajectory |
|
|
|
from mlagents.trainers.exception import UnityTrainerException |
|
|
|
from mlagents.trainers.tests.test_reward_signals import ( # noqa: F401; pylint: disable=unused-variable |
|
|
|
curiosity_dummy_config, |
|
|
|
gail_dummy_config, |
|
|
|
|
|
|
num_epoch: 5 |
|
|
|
num_layers: 2 |
|
|
|
time_horizon: 64 |
|
|
|
sequence_length: 64 |
|
|
|
sequence_length: 16 |
|
|
|
summary_freq: 1000 |
|
|
|
use_recurrent: false |
|
|
|
normalize: true |
|
|
|
|
|
|
update_buffer["extrinsic_value_estimates"] = update_buffer["environment_rewards"] |
|
|
|
optimizer.update( |
|
|
|
update_buffer, |
|
|
|
num_sequences=update_buffer.num_experiences // dummy_config["sequence_length"], |
|
|
|
num_sequences=update_buffer.num_experiences // optimizer.policy.sequence_length, |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
trainer.update_buffer = buffer |
|
|
|
trainer._update_policy() |
|
|
|
# Make batch length a larger multiple of sequence length |
|
|
|
trainer.trainer_parameters["batch_size"] = 128 |
|
|
|
trainer._update_policy() |
|
|
|
# Make batch length a larger non-multiple of sequence length |
|
|
|
trainer.trainer_parameters["batch_size"] = 100 |
|
|
|
trainer._update_policy() |
|
|
|
|
|
|
|
|
|
|
|
def test_process_trajectory(dummy_config): |
|
|
|
|
|
|
policy = mock.Mock() |
|
|
|
with pytest.raises(RuntimeError): |
|
|
|
trainer.add_policy(brain_params, policy) |
|
|
|
|
|
|
|
|
|
|
|
def test_bad_config(dummy_config): |
|
|
|
brain_params = make_brain_parameters( |
|
|
|
discrete_action=False, visual_inputs=0, vec_obs_size=6 |
|
|
|
) |
|
|
|
# Test that we throw an error if we have sequence length greater than batch size |
|
|
|
dummy_config["sequence_length"] = 64 |
|
|
|
dummy_config["batch_size"] = 32 |
|
|
|
dummy_config["use_recurrent"] = True |
|
|
|
with pytest.raises(UnityTrainerException): |
|
|
|
_ = PPOTrainer(brain_params, 0, dummy_config, True, False, 0, "0") |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |