|
|
|
|
|
|
import pytest |
|
|
|
import mlagents.trainers.tests.mock_brain as mb |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
import os |
|
|
|
|
|
|
|
from mlagents.trainers.policy.torch_policy import TorchPolicy |
|
|
|
|
|
|
default_num_epoch=3, |
|
|
|
) |
|
|
|
return bc_module |
|
|
|
|
|
|
|
|
|
|
|
def assert_stats_are_float(stats): |
|
|
|
for _, item in stats.items(): |
|
|
|
assert isinstance(item, float) |
|
|
|
|
|
|
|
|
|
|
|
# Test default values |
|
|
|
|
|
|
) |
|
|
|
bc_module = create_bc_module(mock_specs, bc_settings, False, is_sac) |
|
|
|
stats = bc_module.update() |
|
|
|
for _, item in stats.items(): |
|
|
|
assert isinstance(item, np.float32) |
|
|
|
assert_stats_are_float(stats) |
|
|
|
|
|
|
|
|
|
|
|
# Test with constant pretraining learning rate |
|
|
|
|
|
|
) |
|
|
|
bc_module = create_bc_module(mock_specs, bc_settings, False, is_sac) |
|
|
|
stats = bc_module.update() |
|
|
|
for _, item in stats.items(): |
|
|
|
assert isinstance(item, np.float32) |
|
|
|
assert_stats_are_float(stats) |
|
|
|
old_learning_rate = bc_module.current_lr |
|
|
|
|
|
|
|
_ = bc_module.update() |
|
|
|
|
|
|
) |
|
|
|
bc_module = create_bc_module(mock_specs, bc_settings, True, is_sac) |
|
|
|
stats = bc_module.update() |
|
|
|
for _, item in stats.items(): |
|
|
|
assert isinstance(item, np.float32) |
|
|
|
assert_stats_are_float(stats) |
|
|
|
|
|
|
|
|
|
|
|
# Test with discrete control and visual observations |
|
|
|
|
|
|
) |
|
|
|
bc_module = create_bc_module(mock_specs, bc_settings, False, is_sac) |
|
|
|
stats = bc_module.update() |
|
|
|
for _, item in stats.items(): |
|
|
|
assert isinstance(item, np.float32) |
|
|
|
assert_stats_are_float(stats) |
|
|
|
|
|
|
|
|
|
|
|
# Test with discrete control, visual observations and RNN |
|
|
|
|
|
|
) |
|
|
|
bc_module = create_bc_module(mock_specs, bc_settings, True, is_sac) |
|
|
|
stats = bc_module.update() |
|
|
|
for _, item in stats.items(): |
|
|
|
assert isinstance(item, np.float32) |
|
|
|
assert_stats_are_float(stats) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |