浏览代码

[tests] Add tests for core PyTorch files (#4292)

/develop/add-fire
GitHub 4 年前
当前提交
5bcbef8d
共有 7 个文件被更改,包括 482 次插入24 次删除
  1. 30
      ml-agents/mlagents/trainers/torch/distributions.py
  2. 13
      ml-agents/mlagents/trainers/torch/encoders.py
  3. 15
      ml-agents/mlagents/trainers/torch/utils.py
  4. 31
      ml-agents/mlagents/trainers/tests/torch/test_decoders.py
  5. 141
      ml-agents/mlagents/trainers/tests/torch/test_distributions.py
  6. 110
      ml-agents/mlagents/trainers/tests/torch/test_encoders.py
  7. 166
      ml-agents/mlagents/trainers/tests/torch/test_utils.py

30
ml-agents/mlagents/trainers/torch/distributions.py


import abc
from typing import List
import torch
from torch import nn
import numpy as np

class GaussianDistribution(nn.Module):
def __init__(
self,
hidden_size,
num_outputs,
conditional_sigma=False,
tanh_squash=False,
**kwargs
hidden_size: int,
num_outputs: int,
conditional_sigma: bool = False,
tanh_squash: bool = False,
super().__init__(**kwargs)
super().__init__()
self.conditional_sigma = conditional_sigma
self.mu = nn.Linear(hidden_size, num_outputs)
self.tanh_squash = tanh_squash

torch.zeros(1, num_outputs, requires_grad=True)
)
def forward(self, inputs):
def forward(self, inputs: torch.Tensor) -> List[DistInstance]:
mu = self.mu(inputs)
if self.conditional_sigma:
log_sigma = torch.clamp(self.log_sigma(inputs), min=-20, max=2)

class MultiCategoricalDistribution(nn.Module):
def __init__(self, hidden_size, act_sizes):
def __init__(self, hidden_size: int, act_sizes: List[int]):
self.branches = self.create_policy_branches(hidden_size)
self.branches = self._create_policy_branches(hidden_size)
def create_policy_branches(self, hidden_size):
def _create_policy_branches(self, hidden_size: int) -> nn.ModuleList:
branches = []
for size in self.act_sizes:
branch_output_layer = nn.Linear(hidden_size, size)

def mask_branch(self, logits, mask):
def _mask_branch(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
def split_masks(self, masks):
def _split_masks(self, masks: torch.Tensor) -> List[torch.Tensor]:
split_masks = []
for idx, _ in enumerate(self.act_sizes):
start = int(np.sum(self.act_sizes[:idx]))

def forward(self, inputs, masks):
def forward(self, inputs: torch.Tensor, masks: torch.Tensor) -> List[DistInstance]:
masks = self.split_masks(masks)
masks = self._split_masks(masks)
norm_logits = self.mask_branch(logits, masks[idx])
norm_logits = self._mask_branch(logits, masks[idx])
distribution = CategoricalDistInstance(norm_logits)
branch_distributions.append(distribution)
return branch_distributions

13
ml-agents/mlagents/trainers/torch/encoders.py


class Normalizer(nn.Module):
def __init__(self, vec_obs_size: int):
super().__init__()
self.normalization_steps = torch.tensor(1)
self.running_mean = torch.zeros(vec_obs_size)
self.running_variance = torch.ones(vec_obs_size)
self.register_buffer("normalization_steps", torch.tensor(1))
self.register_buffer("running_mean", torch.zeros(vec_obs_size))
self.register_buffer("running_variance", torch.ones(vec_obs_size))
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
normalized_state = torch.clamp(

new_variance = self.running_variance + (
input_to_new_mean * input_to_old_mean
).sum(0)
self.running_mean = new_mean
self.running_variance = new_variance
self.normalization_steps = total_new_steps
# Update in-place
self.running_mean.data.copy_(new_mean.data)
self.running_variance.data.copy_(new_variance.data)
self.normalization_steps.data.copy_(total_new_steps.data)
def copy_from(self, other_normalizer: "Normalizer") -> None:
self.normalization_steps.data.copy_(other_normalizer.normalization_steps.data)

15
ml-agents/mlagents/trainers/torch/utils.py


@staticmethod
def _check_resolution_for_encoder(
vis_in: torch.Tensor, vis_encoder_type: EncoderType
height: int, width: int, vis_encoder_type: EncoderType
height = vis_in.shape[1]
width = vis_in.shape[2]
if height < min_res or width < min_res:
raise UnityTrainerException(
f"Visual observation resolution ({width}x{height}) is too small for"

vector_size = 0
for i, dimension in enumerate(observation_shapes):
if len(dimension) == 3:
ModelUtils._check_resolution_for_encoder(
dimension[0], dimension[1], vis_encode_type
)
visual_encoders.append(
visual_encoder_class(
dimension[0], dimension[1], dimension[2], h_size

def actions_to_onehot(
discrete_actions: torch.Tensor, action_size: List[int]
) -> List[torch.Tensor]:
"""
Takes a tensor of discrete actions and turns it into a List of onehot encoding for each
action.
:param discrete_actions: Actions in integer form.
:param action_size: List of branch sizes. Should be of same size as discrete_actions'
last dimension.
:return: List of one-hot tensors, one representing each branch.
"""
onehot_branches = [
torch.nn.functional.one_hot(_act.T, action_size[i])
for i, _act in enumerate(discrete_actions.T)

31
ml-agents/mlagents/trainers/tests/torch/test_decoders.py


import pytest
import torch
from mlagents.trainers.torch.decoders import ValueHeads
def test_valueheads():
stream_names = [f"reward_signal_{num}" for num in range(5)]
input_size = 5
batch_size = 4
# Test default 1 value per head
value_heads = ValueHeads(stream_names, input_size)
input_data = torch.ones((batch_size, input_size))
value_out, _ = value_heads(input_data) # Note: mean value will be removed shortly
for stream_name in stream_names:
assert value_out[stream_name].shape == (batch_size,)
# Test that inputting the wrong size input will throw an error
with pytest.raises(Exception):
value_out = value_heads(torch.ones((batch_size, input_size + 2)))
# Test multiple values per head (e.g. discrete Q function)
output_size = 4
value_heads = ValueHeads(stream_names, input_size, output_size)
input_data = torch.ones((batch_size, input_size))
value_out, _ = value_heads(input_data)
for stream_name in stream_names:
assert value_out[stream_name].shape == (batch_size, output_size)

141
ml-agents/mlagents/trainers/tests/torch/test_distributions.py


import pytest
import torch
from mlagents.trainers.torch.distributions import (
GaussianDistribution,
MultiCategoricalDistribution,
GaussianDistInstance,
TanhGaussianDistInstance,
CategoricalDistInstance,
)
@pytest.mark.parametrize("tanh_squash", [True, False])
@pytest.mark.parametrize("conditional_sigma", [True, False])
def test_gaussian_distribution(conditional_sigma, tanh_squash):
torch.manual_seed(0)
hidden_size = 16
act_size = 4
sample_embedding = torch.ones((1, 16))
gauss_dist = GaussianDistribution(
hidden_size,
act_size,
conditional_sigma=conditional_sigma,
tanh_squash=tanh_squash,
)
# Make sure backprop works
force_action = torch.zeros((1, act_size))
optimizer = torch.optim.Adam(gauss_dist.parameters(), lr=3e-3)
for _ in range(50):
dist_inst = gauss_dist(sample_embedding)[0]
if tanh_squash:
assert isinstance(dist_inst, TanhGaussianDistInstance)
else:
assert isinstance(dist_inst, GaussianDistInstance)
log_prob = dist_inst.log_prob(force_action)
loss = torch.nn.functional.mse_loss(log_prob, -2 * torch.ones(log_prob.shape))
optimizer.zero_grad()
loss.backward()
optimizer.step()
for prob in log_prob.flatten():
assert prob == pytest.approx(-2, abs=0.1)
def test_multi_categorical_distribution():
torch.manual_seed(0)
hidden_size = 16
act_size = [3, 3, 4]
sample_embedding = torch.ones((1, 16))
gauss_dist = MultiCategoricalDistribution(hidden_size, act_size)
# Make sure backprop works
optimizer = torch.optim.Adam(gauss_dist.parameters(), lr=3e-3)
def create_test_prob(size: int) -> torch.Tensor:
test_prob = torch.tensor(
[[1.0 - 0.01 * (size - 1)] + [0.01] * (size - 1)]
) # High prob for first action
return test_prob.log()
for _ in range(100):
dist_insts = gauss_dist(sample_embedding, masks=torch.ones((1, sum(act_size))))
loss = 0
for i, dist_inst in enumerate(dist_insts):
assert isinstance(dist_inst, CategoricalDistInstance)
log_prob = dist_inst.all_log_prob()
test_log_prob = create_test_prob(act_size[i])
# Force log_probs to match the high probability for the first action generated by
# create_test_prob
loss += torch.nn.functional.mse_loss(log_prob, test_log_prob)
optimizer.zero_grad()
loss.backward()
optimizer.step()
for dist_inst, size in zip(dist_insts, act_size):
# Check that the log probs are close to the fake ones that we generated.
test_log_probs = create_test_prob(size)
for _prob, _test_prob in zip(
dist_inst.all_log_prob().flatten().tolist(),
test_log_probs.flatten().tolist(),
):
assert _prob == pytest.approx(_test_prob, abs=0.1)
# Test masks
masks = []
for branch in act_size:
masks += [0] * (branch - 1) + [1]
masks = torch.tensor([masks])
dist_insts = gauss_dist(sample_embedding, masks=masks)
for dist_inst in dist_insts:
log_prob = dist_inst.all_log_prob()
assert log_prob.flatten()[-1] == pytest.approx(0, abs=0.001)
def test_gaussian_dist_instance():
torch.manual_seed(0)
act_size = 4
dist_instance = GaussianDistInstance(
torch.zeros(1, act_size), torch.ones(1, act_size)
)
action = dist_instance.sample()
assert action.shape == (1, act_size)
for log_prob in dist_instance.log_prob(torch.zeros((1, act_size))).flatten():
# Log prob of standard normal at 0
assert log_prob == pytest.approx(-0.919, abs=0.01)
for ent in dist_instance.entropy().flatten():
# entropy of standard normal at 0
assert ent == pytest.approx(2.83, abs=0.01)
def test_tanh_gaussian_dist_instance():
torch.manual_seed(0)
act_size = 4
dist_instance = GaussianDistInstance(
torch.zeros(1, act_size), torch.ones(1, act_size)
)
for _ in range(10):
action = dist_instance.sample()
assert action.shape == (1, act_size)
assert torch.max(action) < 1.0 and torch.min(action) > -1.0
def test_categorical_dist_instance():
torch.manual_seed(0)
act_size = 4
test_prob = torch.tensor(
[1.0 - 0.1 * (act_size - 1)] + [0.1] * (act_size - 1)
) # High prob for first action
dist_instance = CategoricalDistInstance(test_prob)
for _ in range(10):
action = dist_instance.sample()
assert action.shape == (1,)
assert action < act_size
# Make sure the first action as higher probability than the others.
prob_first_action = dist_instance.log_prob(torch.tensor([0]))
for i in range(1, act_size):
assert dist_instance.log_prob(torch.tensor([i])) < prob_first_action

110
ml-agents/mlagents/trainers/tests/torch/test_encoders.py


import torch
from unittest import mock
import pytest
from mlagents.trainers.torch.encoders import (
VectorEncoder,
VectorAndUnnormalizedInputEncoder,
Normalizer,
SimpleVisualEncoder,
ResNetVisualEncoder,
NatureVisualEncoder,
)
# This test will also reveal issues with states not being saved in the state_dict.
def compare_models(module_1, module_2):
is_same = True
for key_item_1, key_item_2 in zip(
module_1.state_dict().items(), module_2.state_dict().items()
):
# Compare tensors in state_dict and not the keys.
is_same = torch.equal(key_item_1[1], key_item_2[1]) and is_same
return is_same
def test_normalizer():
input_size = 2
norm = Normalizer(input_size)
# These three inputs should mean to 0.5, and variance 2
# with the steps starting at 1
vec_input1 = torch.tensor([[1, 1]])
vec_input2 = torch.tensor([[1, 1]])
vec_input3 = torch.tensor([[0, 0]])
norm.update(vec_input1)
norm.update(vec_input2)
norm.update(vec_input3)
# Test normalization
for val in norm(vec_input1)[0]:
assert val == pytest.approx(0.707, abs=0.001)
# Test copy normalization
norm2 = Normalizer(input_size)
assert not compare_models(norm, norm2)
norm2.copy_from(norm)
assert compare_models(norm, norm2)
for val in norm2(vec_input1)[0]:
assert val == pytest.approx(0.707, abs=0.001)
@mock.patch("mlagents.trainers.torch.encoders.Normalizer")
def test_vector_encoder(mock_normalizer):
mock_normalizer_inst = mock.Mock()
mock_normalizer.return_value = mock_normalizer_inst
input_size = 64
hidden_size = 128
num_layers = 3
normalize = False
vector_encoder = VectorEncoder(input_size, hidden_size, num_layers, normalize)
output = vector_encoder(torch.ones((1, input_size)))
assert output.shape == (1, hidden_size)
normalize = True
vector_encoder = VectorEncoder(input_size, hidden_size, num_layers, normalize)
new_vec = torch.ones((1, input_size))
vector_encoder.update_normalization(new_vec)
mock_normalizer.assert_called_with(input_size)
mock_normalizer_inst.update.assert_called_with(new_vec)
vector_encoder2 = VectorEncoder(input_size, hidden_size, num_layers, normalize)
vector_encoder.copy_normalization(vector_encoder2)
mock_normalizer_inst.copy_from.assert_called_with(mock_normalizer_inst)
@mock.patch("mlagents.trainers.torch.encoders.Normalizer")
def test_vector_and_unnormalized_encoder(mock_normalizer):
mock_normalizer_inst = mock.Mock()
mock_normalizer.return_value = mock_normalizer_inst
input_size = 64
unnormalized_size = 32
hidden_size = 128
num_layers = 3
normalize = True
mock_normalizer_inst.return_value = torch.ones((1, input_size))
vector_encoder = VectorAndUnnormalizedInputEncoder(
input_size, hidden_size, unnormalized_size, num_layers, normalize
)
# Make sure normalizer is only called on input_size
mock_normalizer.assert_called_with(input_size)
normal_input = torch.ones((1, input_size))
unnormalized_input = torch.ones((1, 32))
output = vector_encoder(normal_input, unnormalized_input)
mock_normalizer_inst.assert_called_with(normal_input)
assert output.shape == (1, hidden_size)
@pytest.mark.parametrize("image_size", [(36, 36, 3), (84, 84, 4), (256, 256, 5)])
@pytest.mark.parametrize(
"vis_class", [SimpleVisualEncoder, ResNetVisualEncoder, NatureVisualEncoder]
)
def test_visual_encoder(vis_class, image_size):
num_outputs = 128
enc = vis_class(image_size[0], image_size[1], image_size[2], num_outputs)
# Note: NCHW not NHWC
sample_input = torch.ones((1, image_size[2], image_size[0], image_size[1]))
encoding = enc(sample_input)
assert encoding.shape == (1, num_outputs)

166
ml-agents/mlagents/trainers/tests/torch/test_utils.py


import pytest
import torch
import numpy as np
from mlagents.trainers.settings import EncoderType
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.torch.encoders import (
VectorEncoder,
VectorAndUnnormalizedInputEncoder,
)
from mlagents.trainers.torch.distributions import (
CategoricalDistInstance,
GaussianDistInstance,
)
def test_min_visual_size():
# Make sure each EncoderType has an entry in MIS_RESOLUTION_FOR_ENCODER
assert set(ModelUtils.MIN_RESOLUTION_FOR_ENCODER.keys()) == set(EncoderType)
for encoder_type in EncoderType:
good_size = ModelUtils.MIN_RESOLUTION_FOR_ENCODER[encoder_type]
vis_input = torch.ones((1, 3, good_size, good_size))
ModelUtils._check_resolution_for_encoder(vis_input, encoder_type)
enc_func = ModelUtils.get_encoder_for_type(encoder_type)
enc = enc_func(good_size, good_size, 3, 1)
enc.forward(vis_input)
# Anything under the min size should raise an exception. If not, decrease the min size!
with pytest.raises(Exception):
bad_size = ModelUtils.MIN_RESOLUTION_FOR_ENCODER[encoder_type] - 1
vis_input = torch.ones((1, 3, bad_size, bad_size))
with pytest.raises(UnityTrainerException):
# Make sure we'd hit a friendly error during model setup time.
ModelUtils._check_resolution_for_encoder(vis_input, encoder_type)
enc = enc_func(bad_size, bad_size, 3, 1)
enc.forward(vis_input)
@pytest.mark.parametrize("unnormalized_inputs", [0, 1])
@pytest.mark.parametrize("num_visual", [0, 1, 2])
@pytest.mark.parametrize("num_vector", [0, 1, 2])
@pytest.mark.parametrize("normalize", [True, False])
@pytest.mark.parametrize("encoder_type", [EncoderType.SIMPLE, EncoderType.NATURE_CNN])
def test_create_encoders(
encoder_type, normalize, num_vector, num_visual, unnormalized_inputs
):
vec_obs_shape = (5,)
vis_obs_shape = (84, 84, 3)
obs_shapes = []
for _ in range(num_vector):
obs_shapes.append(vec_obs_shape)
for _ in range(num_visual):
obs_shapes.append(vis_obs_shape)
h_size = 128
num_layers = 3
unnormalized_inputs = 1
vis_enc, vec_enc = ModelUtils.create_encoders(
obs_shapes, h_size, num_layers, encoder_type, unnormalized_inputs, normalize
)
vec_enc = list(vec_enc)
vis_enc = list(vis_enc)
assert len(vec_enc) == (
1 if unnormalized_inputs + num_vector > 0 else 0
) # There's always at most one vector encoder.
assert len(vis_enc) == num_visual
if unnormalized_inputs > 0:
assert isinstance(vec_enc[0], VectorAndUnnormalizedInputEncoder)
elif num_vector > 0:
assert isinstance(vec_enc[0], VectorEncoder)
for enc in vis_enc:
assert isinstance(enc, ModelUtils.get_encoder_for_type(encoder_type))
def test_list_to_tensor():
# Test converting pure list
unconverted_list = [[1, 2], [1, 3], [1, 4]]
tensor = ModelUtils.list_to_tensor(unconverted_list)
# Should be equivalent to torch.tensor conversion
assert torch.equal(tensor, torch.tensor(unconverted_list))
# Test converting pure numpy array
np_list = np.asarray(unconverted_list)
tensor = ModelUtils.list_to_tensor(np_list)
# Should be equivalent to torch.tensor conversion
assert torch.equal(tensor, torch.tensor(unconverted_list))
# Test converting list of numpy arrays
list_of_np = [np.asarray(_el) for _el in unconverted_list]
tensor = ModelUtils.list_to_tensor(list_of_np)
# Should be equivalent to torch.tensor conversion
assert torch.equal(tensor, torch.tensor(unconverted_list))
def test_break_into_branches():
# Test normal multi-branch case
all_actions = torch.tensor([[1, 2, 3, 4, 5, 6]])
action_size = [2, 1, 3]
broken_actions = ModelUtils.break_into_branches(all_actions, action_size)
assert len(action_size) == len(broken_actions)
for i, _action in enumerate(broken_actions):
assert _action.shape == (1, action_size[i])
# Test 1-branch case
action_size = [6]
broken_actions = ModelUtils.break_into_branches(all_actions, action_size)
assert len(broken_actions) == 1
assert broken_actions[0].shape == (1, 6)
def test_actions_to_onehot():
all_actions = torch.tensor([[1, 0, 2], [1, 0, 2]])
action_size = [2, 1, 3]
oh_actions = ModelUtils.actions_to_onehot(all_actions, action_size)
expected_result = [
torch.tensor([[0, 1], [0, 1]]),
torch.tensor([[1], [1]]),
torch.tensor([[0, 0, 1], [0, 0, 1]]),
]
for res, exp in zip(oh_actions, expected_result):
assert torch.equal(res, exp)
def test_get_probs_and_entropy():
# Test continuous
# Add two dists to the list. This isn't done in the code but we'd like to support it.
dist_list = [
GaussianDistInstance(torch.zeros((1, 2)), torch.ones((1, 2))),
GaussianDistInstance(torch.zeros((1, 2)), torch.ones((1, 2))),
]
action_list = [torch.zeros((1, 2)), torch.zeros((1, 2))]
log_probs, entropies, all_probs = ModelUtils.get_probs_and_entropy(
action_list, dist_list
)
assert log_probs.shape == (1, 2, 2)
assert entropies.shape == (1, 2, 2)
assert all_probs is None
for log_prob in log_probs.flatten():
# Log prob of standard normal at 0
assert log_prob == pytest.approx(-0.919, abs=0.01)
for ent in entropies.flatten():
# entropy of standard normal at 0
assert ent == pytest.approx(2.83, abs=0.01)
# Test continuous
# Add two dists to the list.
act_size = 2
test_prob = torch.tensor(
[1.0 - 0.1 * (act_size - 1)] + [0.1] * (act_size - 1)
) # High prob for first action
dist_list = [CategoricalDistInstance(test_prob), CategoricalDistInstance(test_prob)]
action_list = [torch.tensor([0]), torch.tensor([1])]
log_probs, entropies, all_probs = ModelUtils.get_probs_and_entropy(
action_list, dist_list
)
assert all_probs.shape == (len(dist_list * act_size),)
assert entropies.shape == (len(dist_list),)
# Make sure the first action has high probability than the others.
assert log_probs.flatten()[0] > log_probs.flatten()[1]
正在加载...
取消
保存