|
|
|
|
|
|
import json |
|
|
|
import yaml |
|
|
|
import unittest.mock as mock |
|
|
|
import pytest |
|
|
|
|
|
|
from unitytrainers.models import * |
|
|
|
from unitytrainers.ppo.trainer import PPOTrainer |
|
|
|
from unitytrainers.bc.trainer import BehavioralCloningTrainer |
|
|
|
from unitytrainers import Curriculum |
|
|
|
from unityagents import UnityEnvironmentException |
|
|
|
from .mock_communicator import MockCommunicator |
|
|
|
|
|
|
|
|
|
|
memory_size: 8 |
|
|
|
''') |
|
|
|
|
|
|
|
dummy_curriculum = json.loads('''{ |
|
|
|
"measure" : "reward", |
|
|
|
"thresholds" : [10, 20, 50], |
|
|
|
"min_lesson_length" : 3, |
|
|
|
"signal_smoothing" : true, |
|
|
|
"parameters" : |
|
|
|
{ |
|
|
|
"param1" : [0.7, 0.5, 0.3, 0.1], |
|
|
|
"param2" : [100, 50, 20, 15], |
|
|
|
"param3" : [0.2, 0.3, 0.7, 0.9] |
|
|
|
} |
|
|
|
}''') |
|
|
|
bad_curriculum = json.loads('''{ |
|
|
|
"measure" : "reward", |
|
|
|
"thresholds" : [10, 20, 50], |
|
|
|
"min_lesson_length" : 3, |
|
|
|
"signal_smoothing" : false, |
|
|
|
"parameters" : |
|
|
|
{ |
|
|
|
"param1" : [0.7, 0.5, 0.3, 0.1], |
|
|
|
"param2" : [100, 50, 20], |
|
|
|
"param3" : [0.2, 0.3, 0.7, 0.9] |
|
|
|
} |
|
|
|
}''') |
|
|
|
|
|
|
|
|
|
|
|
@mock.patch('unityagents.UnityEnvironment.executable_launcher') |
|
|
|
@mock.patch('unityagents.UnityEnvironment.get_communicator') |
|
|
|
|
|
|
batch_size=None, training_length=2) |
|
|
|
assert len(b.update_buffer['action']) == 10 |
|
|
|
assert np.array(b.update_buffer['action']).shape == (10, 2, 2) |
|
|
|
|
|
|
|
|
|
|
|
def test_curriculum(): |
|
|
|
open_name = '%s.open' % __name__ |
|
|
|
with mock.patch('json.load') as mock_load: |
|
|
|
with mock.patch(open_name, create=True) as mock_open: |
|
|
|
mock_open.return_value = 0 |
|
|
|
mock_load.return_value = bad_curriculum |
|
|
|
with pytest.raises(UnityEnvironmentException): |
|
|
|
Curriculum('tests/test_unityagents.py', {"param1": 1, "param2": 1, "param3": 1}) |
|
|
|
mock_load.return_value = dummy_curriculum |
|
|
|
with pytest.raises(UnityEnvironmentException): |
|
|
|
Curriculum('tests/test_unityagents.py', {"param1": 1, "param2": 1}) |
|
|
|
curriculum = Curriculum('tests/test_unityagents.py', {"param1": 1, "param2": 1, "param3": 1}) |
|
|
|
assert curriculum.get_lesson_number == 0 |
|
|
|
curriculum.set_lesson_number(1) |
|
|
|
assert curriculum.get_lesson_number == 1 |
|
|
|
curriculum.increment_lesson(10) |
|
|
|
assert curriculum.get_lesson_number == 1 |
|
|
|
curriculum.increment_lesson(30) |
|
|
|
curriculum.increment_lesson(30) |
|
|
|
assert curriculum.get_lesson_number == 1 |
|
|
|
assert curriculum.lesson_length == 3 |
|
|
|
curriculum.increment_lesson(30) |
|
|
|
assert curriculum.get_config() == {'param1': 0.3, 'param2': 20, 'param3': 0.7} |
|
|
|
assert curriculum.get_config(0) == {"param1": 0.7, "param2": 100, "param3": 0.2} |
|
|
|
assert curriculum.lesson_length == 0 |
|
|
|
assert curriculum.get_lesson_number == 2 |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |