Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多可以选择 25 个主题。主题必须以中文、字母或数字开头,可以包含连字符 (-),并且长度不得超过 35 个字符。
 
 
 
 
 

84 行
2.7 KiB

from mlagents.torch_utils import torch
from unittest import mock
import pytest
from mlagents.trainers.torch.encoders import (
VectorInput,
Normalizer,
SimpleVisualEncoder,
ResNetVisualEncoder,
NatureVisualEncoder,
)
# This test will also reveal issues with states not being saved in the state_dict.
def compare_models(module_1, module_2):
    """Return True when two modules hold identical ``state_dict`` tensors.

    Entries are paired by their position in each module's ``state_dict()``;
    only the tensor values are compared, never the keys.

    :param module_1: First object exposing ``state_dict()``.
    :param module_2: Second object exposing ``state_dict()``.
    :return: True if both state dicts contain the same number of entries and
        every positional pair of tensors satisfies ``torch.equal``.
    """
    tensors_1 = list(module_1.state_dict().values())
    tensors_2 = list(module_2.state_dict().values())
    # zip() stops at the shorter sequence, so without this check a module that
    # is missing (or has extra) state entries would compare equal on the
    # common prefix alone.
    if len(tensors_1) != len(tensors_2):
        return False
    # Compare tensors in state_dict and not the keys.
    return all(torch.equal(t_1, t_2) for t_1, t_2 in zip(tensors_1, tensors_2))
def test_normalizer():
    """Check Normalizer output after three updates, then check copy_from.

    A fresh Normalizer must differ from the updated one (per compare_models)
    until copy_from is called, after which both must match and produce the
    same normalized values.
    """
    size = 2
    normalizer = Normalizer(size)
    # Per the original author's note: these three inputs should bring the
    # mean to 0.5 and the variance to 2, with the steps starting at 1.
    batches = (
        torch.tensor([[1, 1]]),
        torch.tensor([[1, 1]]),
        torch.tensor([[0, 0]]),
    )
    for batch in batches:
        normalizer.update(batch)

    probe = batches[0]
    expected = pytest.approx(0.707, abs=0.001)

    # Test normalization
    normalized = normalizer(probe)[0].tolist()
    for value in normalized:
        assert value == expected

    # Test copy normalization
    clone = Normalizer(size)
    assert not compare_models(normalizer, clone)
    clone.copy_from(normalizer)
    assert compare_models(normalizer, clone)
    cloned_output = clone(probe)[0].tolist()
    for value in cloned_output:
        assert value == expected
@mock.patch("mlagents.trainers.torch.encoders.Normalizer")
def test_vector_encoder(mock_normalizer):
    """Check VectorInput output shape and its calls into a mocked Normalizer.

    :param mock_normalizer: Mock injected by ``mock.patch`` in place of the
        ``Normalizer`` name inside the encoders module.
    """
    # Every Normalizer(...) call returns this one shared mock instance.
    normalizer_instance = mock.Mock()
    mock_normalizer.return_value = normalizer_instance
    size = 64

    # Without normalization: only the output shape is checked.
    plain_encoder = VectorInput(size, False)
    plain_output = plain_encoder(torch.ones((1, size)))
    assert plain_output.shape == (1, size)

    # With normalization: the encoder is expected to build a Normalizer of
    # the input size and pass update calls on to it.
    encoder = VectorInput(size, True)
    observation = torch.ones((1, size))
    encoder.update_normalization(observation)
    mock_normalizer.assert_called_with(size)
    normalizer_instance.update.assert_called_with(observation)

    # Both encoders hold the same mock instance, so this assertion only
    # verifies that copy_from was invoked with a normalizer argument.
    other_encoder = VectorInput(size, True)
    encoder.copy_normalization(other_encoder)
    normalizer_instance.copy_from.assert_called_with(normalizer_instance)
@pytest.mark.parametrize("image_size", [(36, 36, 3), (84, 84, 4), (256, 256, 5)])
@pytest.mark.parametrize(
    "vis_class", [SimpleVisualEncoder, ResNetVisualEncoder, NatureVisualEncoder]
)
def test_visual_encoder(vis_class, image_size):
    """Check that each visual encoder maps one image to (1, num_outputs).

    :param vis_class: Encoder class under test; it is constructed from the
        three ``image_size`` entries followed by the output size.
    :param image_size: 3-tuple passed to the encoder constructor and also
        used, in the same order, as the trailing dimensions of the input.
    """
    num_outputs = 128
    encoder = vis_class(*image_size, num_outputs)
    # NOTE(review): the original comment here read "NCHW not NHWC", yet the
    # tensor keeps the same dimension order as the constructor arguments.
    # Confirm which layout the encoders expect.
    batch = torch.ones((1, *image_size))
    assert encoder(batch).shape == (1, num_outputs)