您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
29 行
1.1 KiB
29 行
1.1 KiB
import os
|
|
|
|
from mlagents.trainers.demo_loader import load_demonstration, make_demo_buffer
|
|
|
|
|
|
def test_load_demo():
    """Load a single ``.demo`` file and check its contents and derived buffer.

    Reads ``test.demo`` from the directory containing this test module and
    checks:

    * the brain name is ``"Ball3DBrain"``,
    * the vector observation space size is 8,
    * one brain info is returned per expected step,
    * the buffer built by ``make_demo_buffer`` holds one fewer action than
      there are steps.
    """
    # Resolve relative to this file so the test does not depend on the CWD;
    # os.path.join keeps the path separator portable.
    path_prefix = os.path.dirname(os.path.abspath(__file__))
    demo_path = os.path.join(path_prefix, "test.demo")

    brain_parameters, brain_infos, total_expected = load_demonstration(demo_path)

    assert brain_parameters.brain_name == "Ball3DBrain"
    assert brain_parameters.vector_observation_space_size == 8
    assert len(brain_infos) == total_expected

    # NOTE(review): the third argument (1) is presumably the sequence length;
    # confirm against make_demo_buffer's signature.
    demo_buffer = make_demo_buffer(brain_infos, brain_parameters, 1)
    # One fewer action than steps; presumably the final step has no following
    # transition to pair with. TODO: confirm in make_demo_buffer.
    assert len(demo_buffer.update_buffer["actions"]) == total_expected - 1
|
|
|
|
|
|
def test_load_demo_dir():
    """Load demonstrations from a directory and check the aggregated result.

    Passes the ``test_demo_dir`` directory, which sits next to this test
    module, to ``load_demonstration``. It then applies the same checks as the
    single-file test: brain name, vector observation size, brain-info count,
    and the action count of the buffer built by ``make_demo_buffer``.
    """
    # Resolve relative to this file so the test does not depend on the CWD;
    # os.path.join keeps the path separator portable.
    path_prefix = os.path.dirname(os.path.abspath(__file__))
    demo_dir = os.path.join(path_prefix, "test_demo_dir")

    # NOTE(review): presumably load_demonstration reads every .demo file in
    # the directory and concatenates them; confirm in demo_loader.
    brain_parameters, brain_infos, total_expected = load_demonstration(demo_dir)

    assert brain_parameters.brain_name == "Ball3DBrain"
    assert brain_parameters.vector_observation_space_size == 8
    assert len(brain_infos) == total_expected

    # NOTE(review): the third argument (1) is presumably the sequence length;
    # confirm against make_demo_buffer's signature.
    demo_buffer = make_demo_buffer(brain_infos, brain_parameters, 1)
    # One fewer action than steps; presumably the final step has no following
    # transition to pair with. TODO: confirm in make_demo_buffer.
    assert len(demo_buffer.update_buffer["actions"]) == total_expected - 1
|