您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
84 行
2.5 KiB
84 行
2.5 KiB
import os
|
|
import tempfile
|
|
import pytest
|
|
import yaml
|
|
|
|
import mlagents.trainers.tensorflow_to_barracuda as tf2bc
|
|
from mlagents.trainers.tests.test_nn_policy import create_policy_mock
|
|
from mlagents.tf_utils import tf
|
|
from mlagents.model_serialization import SerializationSettings, export_policy_model
|
|
|
|
|
|
def test_barracuda_converter():
    """Convert the bundled ``BasicLearning.pb`` graph and sanity-check the result.

    The frozen graph that sits next to this test file is passed to
    ``tf2bc.convert`` and the resulting ``.nn`` file is checked for existence
    and a minimum size.
    """
    source_model = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "BasicLearning.pb"
    )

    # A private, freshly created directory guarantees there is no left-over
    # output file, relies only on the public tempfile API, and is removed on
    # exit even when one of the assertions below fails.
    with tempfile.TemporaryDirectory() as scratch_dir:
        converted = os.path.join(scratch_dir, "BasicLearning.nn")

        tf2bc.convert(source_model, converted)

        # test if file exists after conversion
        assert os.path.isfile(converted)
        # currently converter produces small output file even if input file is empty
        # 100 bytes is high enough to prove that conversion was successful
        assert os.path.getsize(converted) > 100
|
|
|
|
|
|
@pytest.fixture
def dummy_config():
    """Return a small PPO trainer configuration parsed from inline YAML.

    The fixture has the default (per-test) scope, so every test receives its
    own freshly parsed dict and may modify it (e.g. ``output_path``) freely.
    """
    # Each key is declared exactly once; YAML mapping keys are expected to be
    # unique and a repeated key would only be resolved implicitly by the parser.
    # NOTE(review): `max_steps: 5.0e4` has no sign in its exponent; PyYAML's
    # YAML 1.1 float resolver may load it as a string rather than a float --
    # confirm that consumers convert it before doing arithmetic.
    return yaml.safe_load(
        """
        trainer: ppo
        batch_size: 32
        beta: 5.0e-3
        buffer_size: 512
        epsilon: 0.2
        hidden_units: 128
        lambd: 0.95
        learning_rate: 3.0e-4
        max_steps: 5.0e4
        normalize: true
        num_epoch: 5
        num_layers: 2
        time_horizon: 64
        sequence_length: 64
        summary_freq: 1000
        use_recurrent: false
        memory_size: 8
        curiosity_strength: 0.0
        curiosity_enc_size: 1
        output_path: test
        reward_signals:
          extrinsic:
            strength: 1.0
            gamma: 0.99
        """
    )
|
|
|
|
|
|
@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
def test_policy_conversion(dummy_config, tmpdir, rnn, visual, discrete):
    """Save and export a mock policy, then sanity-check the ``test.nn`` output.

    Runs once for every combination of recurrent / non-recurrent, visual /
    vector and discrete / continuous settings. After exporting, the test
    asserts that ``test.nn`` exists inside ``tmpdir`` and is larger than
    100 bytes.
    """
    # Start every parametrized case from an empty default TensorFlow graph.
    tf.reset_default_graph()
    dummy_config["output_path"] = os.path.join(tmpdir, "test")

    mock_policy = create_policy_mock(
        dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual
    )
    mock_policy.save_model(1000)

    export_policy_model(
        SerializationSettings(
            mock_policy.model_path,
            os.path.join(tmpdir, mock_policy.brain.brain_name),
        ),
        mock_policy.graph,
        mock_policy.sess,
    )

    # These checks taken from test_barracuda_converter
    exported_model = os.path.join(tmpdir, "test.nn")
    assert os.path.isfile(exported_model)
    assert os.path.getsize(exported_model) > 100
|