浏览代码

added test in test_unityagents.py for curriculum class

/tag-0.2.0
vincentpierre 7 年前
当前提交
1bbaf0dd
共有 1 个文件被更改,包括 57 次插入1 次删除
  1. 58
      python/test_unityagents.py

58
python/test_unityagents.py


import socket
import mock
import struct
import json
from unityagents import UnityEnvironment, UnityEnvironmentException, UnityActionException, BrainInfo, BrainParameters
from unityagents import UnityEnvironment, UnityEnvironmentException, UnityActionException, BrainInfo, BrainParameters, Curriculum
def append_length(input):

env.close()
assert not env._loaded
mock_socket.close.assert_called_once()
dummy_curriculum= json.loads('''{
"measure" : "reward",
"thresholds" : [10, 20, 50],
"min_lesson_length" : 3,
"signal_smoothing" : true,
"parameters" :
{
"param1" : [0.7, 0.5, 0.3, 0.1],
"param2" : [100, 50, 20, 15],
"param3" : [0.2, 0.3, 0.7, 0.9]
}
}''')
bad_curriculum= json.loads('''{
"measure" : "reward",
"thresholds" : [10, 20, 50],
"min_lesson_length" : 3,
"signal_smoothing" : false,
"parameters" :
{
"param1" : [0.7, 0.5, 0.3, 0.1],
"param2" : [100, 50, 20],
"param3" : [0.2, 0.3, 0.7, 0.9]
}
}''')
def test_curriculum():
open_name = '%s.open' % __name__
with mock.patch('json.load') as mock_load:
with mock.patch(open_name, create=True) as mock_open:
mock_open.return_value = 0
mock_load.return_value = bad_curriculum
with pytest.raises(UnityEnvironmentException):
curriculum = Curriculum('test_unityagents.py', {"param1":1,"param2":1,"param3":1})
mock_load.return_value = dummy_curriculum
with pytest.raises(UnityEnvironmentException):
curriculum = Curriculum('test_unityagents.py', {"param1":1,"param2":1})
curriculum = Curriculum('test_unityagents.py', {"param1":1,"param2":1,"param3":1})
assert curriculum.get_lesson_number() == 0
curriculum.set_lesson_number(1)
assert curriculum.get_lesson_number() == 1
curriculum.get_lesson(10)
assert curriculum.get_lesson_number() == 1
curriculum.get_lesson(30)
curriculum.get_lesson(30)
assert curriculum.get_lesson_number() == 1
assert curriculum.lesson_length == 3
assert curriculum.get_lesson(30) == {'param1': 0.3, 'param2': 20, 'param3': 0.7}
assert curriculum.lesson_length == 0
assert curriculum.get_lesson_number() == 2
正在加载...
取消
保存