Ruo-Ping Dong
4 年前
当前提交
79d89158
共有 15 个文件被更改,包括 1100 次插入 和 168 次删除
-
6ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
-
31ml-agents/mlagents/trainers/policy/torch_policy.py
-
1ml-agents/mlagents/trainers/ppo/trainer.py
-
4ml-agents/mlagents/trainers/tests/test_reward_signals.py
-
11ml-agents/mlagents/trainers/torch/decoders.py
-
69ml-agents/mlagents/trainers/torch/distributions.py
-
41ml-agents/mlagents/trainers/torch/encoders.py
-
389ml-agents/mlagents/trainers/torch/networks.py
-
58ml-agents/mlagents/trainers/torch/utils.py
-
210ml-agents/mlagents/trainers/tests/torch/test_networks.py
-
31ml-agents/mlagents/trainers/tests/torch/test_decoders.py
-
141ml-agents/mlagents/trainers/tests/torch/test_distributions.py
-
110ml-agents/mlagents/trainers/tests/torch/test_encoders.py
-
166ml-agents/mlagents/trainers/tests/torch/test_utils.py
|
|||
import pytest |
|||
|
|||
import torch |
|||
from mlagents.trainers.torch.networks import ( |
|||
NetworkBody, |
|||
ValueNetwork, |
|||
SimpleActor, |
|||
SharedActorCritic, |
|||
SeparateActorCritic, |
|||
) |
|||
from mlagents.trainers.settings import NetworkSettings |
|||
from mlagents_envs.base_env import ActionType |
|||
from mlagents.trainers.torch.distributions import ( |
|||
GaussianDistInstance, |
|||
CategoricalDistInstance, |
|||
) |
|||
|
|||
|
|||
def test_networkbody_vector():
    """NetworkBody with a single vector observation (plus encoded actions)
    should be trainable toward an all-ones output under an MSE objective."""
    obs_dim = 4
    settings = NetworkSettings()
    shapes = [(obs_dim,)]

    body = NetworkBody(shapes, settings, encoded_act_size=2)
    opt = torch.optim.Adam(body.parameters(), lr=3e-3)
    vec_obs = torch.ones((1, obs_dim))
    act = torch.ones((1, 2))

    encoded = None
    for _ in range(100):
        encoded, _ = body([vec_obs], [], act)
        assert encoded.shape == (1, settings.hidden_units)
        # Try to force output to 1
        loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
        opt.zero_grad()
        loss.backward()
        opt.step()
    # In the last step, values should be close to 1
    for element in encoded.flatten():
        assert element == pytest.approx(1.0, abs=0.1)
|||
|
|||
|
|||
def test_networkbody_lstm():
    """NetworkBody with LSTM memory should be trainable toward an all-ones
    output over a full sequence."""
    obs_dim = 4
    seq_len = 16
    settings = NetworkSettings(
        memory=NetworkSettings.MemorySettings(sequence_length=seq_len, memory_size=4)
    )

    body = NetworkBody([(obs_dim,)], settings)
    opt = torch.optim.Adam(body.parameters(), lr=3e-3)
    seq_obs = torch.ones((1, seq_len, obs_dim))

    encoded = None
    for _ in range(100):
        encoded, _ = body([seq_obs], [], memories=torch.ones(1, seq_len, 4))
        # Try to force output to 1
        loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
        opt.zero_grad()
        loss.backward()
        opt.step()
    # In the last step, values should be close to 1
    for element in encoded.flatten():
        assert element == pytest.approx(1.0, abs=0.1)
|||
|
|||
|
|||
def test_networkbody_visual():
    """NetworkBody with mixed vector + visual observations should be
    trainable toward an all-ones output."""
    vec_obs_dim = 4
    img_shape = (84, 84, 3)
    settings = NetworkSettings()
    shapes = [(vec_obs_dim,), img_shape]
    torch.random.manual_seed(0)

    body = NetworkBody(shapes, settings)
    opt = torch.optim.Adam(body.parameters(), lr=3e-3)
    vis_obs = torch.ones((1, 84, 84, 3))
    vec_obs = torch.ones((1, vec_obs_dim))

    encoded = None
    for _ in range(100):
        encoded, _ = body([vec_obs], [vis_obs])
        assert encoded.shape == (1, settings.hidden_units)
        # Try to force output to 1
        loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
        opt.zero_grad()
        loss.backward()
        opt.step()
    # In the last step, values should be close to 1
    for element in encoded.flatten():
        assert element == pytest.approx(1.0, abs=0.1)
|||
|
|||
|
|||
def test_valuenetwork():
    """ValueNetwork should emit one (batch, num_outputs) tensor per reward
    stream, and every head should be trainable toward a constant target."""
    obs_dim = 4
    num_outputs = 2
    settings = NetworkSettings()

    stream_names = [f"stream_name{n}" for n in range(4)]
    value_net = ValueNetwork(
        stream_names, [(obs_dim,)], settings, outputs_per_stream=num_outputs
    )
    opt = torch.optim.Adam(value_net.parameters(), lr=3e-3)

    values = None
    for _ in range(50):
        obs = torch.ones((1, obs_dim))
        values, _ = value_net([obs], [])
        loss = 0
        for name in stream_names:
            assert values[name].shape == (1, num_outputs)
            # Try to force output to 1
            loss = loss + torch.nn.functional.mse_loss(
                values[name], torch.ones((1, num_outputs))
            )

        opt.zero_grad()
        loss.backward()
        opt.step()
    # In the last step, values should be close to 1
    for value in values.values():
        for row in value:
            assert row[0] == pytest.approx(1.0, abs=0.1)
|||
|
|||
|
|||
@pytest.mark.parametrize("action_type", [ActionType.DISCRETE, ActionType.CONTINUOUS])
def test_simple_actor(action_type):
    """Exercise SimpleActor's get_dists, sample_action, and the ONNX-style
    forward pass for both discrete and continuous action spaces."""
    obs_dim = 4
    settings = NetworkSettings()
    act_size = [2]
    continuous = action_type == ActionType.CONTINUOUS
    masks = None if continuous else torch.ones((1, 1))
    actor = SimpleActor([(obs_dim,)], settings, action_type, act_size)

    # get_dists should return the distribution type matching the action space.
    sample_obs = torch.ones((1, obs_dim))
    dists, _ = actor.get_dists([sample_obs], [], masks=masks)
    expected_cls = GaussianDistInstance if continuous else CategoricalDistInstance
    for dist in dists:
        assert isinstance(dist, expected_cls)

    # sample_action: (1, act_size) for continuous, (1, 1) for discrete.
    for act in actor.sample_action(dists):
        assert act.shape == ((1, act_size[0]) if continuous else (1, 1))

    # forward is the export entry point; continuous actions come out
    # transposed relative to sample_action.
    actions, probs, ver_num, mem_size, is_cont, act_size_vec = actor.forward(
        [sample_obs], [], masks=masks
    )
    for act in actions:
        if continuous:
            # This is different from above for ONNX export
            assert act.shape == (act_size[0], 1)
        else:
            assert act.shape == (1, 1)

    # TODO: Once export works properly. fix the shapes here.
    assert mem_size == 0
    assert is_cont == int(continuous)
    assert act_size_vec == torch.tensor(act_size)
|||
|
|||
|
|||
@pytest.mark.parametrize("ac_type", [SharedActorCritic, SeparateActorCritic])
@pytest.mark.parametrize("lstm", [True, False])
def test_actor_critic(ac_type, lstm):
    """Check critic_pass and get_dist_and_value for both actor-critic
    variants, with and without LSTM memory."""
    obs_dim = 4
    settings = NetworkSettings(
        memory=NetworkSettings.MemorySettings() if lstm else None
    )
    act_size = [2]
    stream_names = [f"stream_name{n}" for n in range(4)]
    actor = ac_type(
        [(obs_dim,)], settings, ActionType.CONTINUOUS, act_size, stream_names
    )
    if lstm:
        seq_len = settings.memory.sequence_length
        sample_obs = torch.ones((1, seq_len, obs_dim))
        memories = torch.ones((1, seq_len, settings.memory.memory_size))
    else:
        sample_obs = torch.ones((1, obs_dim))
        # memories isn't always set to None, the network should be able to
        # deal with that.
        memories = torch.tensor([])

    # Per-stream value shape depends on whether we run over a sequence.
    expected_shape = (settings.memory.sequence_length,) if lstm else (1,)

    # Test critic pass
    value_out = actor.critic_pass([sample_obs], [], memories=memories)
    for stream in stream_names:
        assert value_out[stream].shape == expected_shape

    # Test get_dist_and_value
    dists, value_out, _ = actor.get_dist_and_value([sample_obs], [], memories=memories)
    for dist in dists:
        assert isinstance(dist, GaussianDistInstance)
    for stream in stream_names:
        assert value_out[stream].shape == expected_shape
|
|||
import pytest |
|||
import torch |
|||
|
|||
from mlagents.trainers.torch.decoders import ValueHeads |
|||
|
|||
|
|||
def test_valueheads():
    """ValueHeads should emit one value tensor per stream, reject wrongly
    sized input, and support multiple outputs per head."""
    stream_names = [f"reward_signal_{num}" for num in range(5)]
    input_size = 5
    batch_size = 4
    sample = torch.ones((batch_size, input_size))

    # Default: a single scalar value per head.
    heads = ValueHeads(stream_names, input_size)
    value_out, _ = heads(sample)  # Note: mean value will be removed shortly

    for name in stream_names:
        assert value_out[name].shape == (batch_size,)

    # A mismatched input width must raise.
    with pytest.raises(Exception):
        value_out = heads(torch.ones((batch_size, input_size + 2)))

    # Multiple values per head (e.g. discrete Q function).
    output_size = 4
    heads = ValueHeads(stream_names, input_size, output_size)
    value_out, _ = heads(torch.ones((batch_size, input_size)))

    for name in stream_names:
        assert value_out[name].shape == (batch_size, output_size)
|
|||
import pytest |
|||
import torch |
|||
|
|||
from mlagents.trainers.torch.distributions import ( |
|||
GaussianDistribution, |
|||
MultiCategoricalDistribution, |
|||
GaussianDistInstance, |
|||
TanhGaussianDistInstance, |
|||
CategoricalDistInstance, |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("tanh_squash", [True, False])
@pytest.mark.parametrize("conditional_sigma", [True, False])
def test_gaussian_distribution(conditional_sigma, tanh_squash):
    """GaussianDistribution should produce the right instance type and its
    log-probs should be trainable toward a target value."""
    torch.manual_seed(0)
    hidden_size = 16
    act_size = 4
    embedding = torch.ones((1, 16))
    gauss_dist = GaussianDistribution(
        hidden_size,
        act_size,
        conditional_sigma=conditional_sigma,
        tanh_squash=tanh_squash,
    )

    # Make sure backprop works
    target_action = torch.zeros((1, act_size))
    opt = torch.optim.Adam(gauss_dist.parameters(), lr=3e-3)

    expected_cls = TanhGaussianDistInstance if tanh_squash else GaussianDistInstance
    log_prob = None
    for _ in range(50):
        dist_inst = gauss_dist(embedding)[0]
        assert isinstance(dist_inst, expected_cls)
        log_prob = dist_inst.log_prob(target_action)
        # Push every log-prob toward -2.
        loss = torch.nn.functional.mse_loss(log_prob, -2 * torch.ones(log_prob.shape))
        opt.zero_grad()
        loss.backward()
        opt.step()
    for prob in log_prob.flatten():
        assert prob == pytest.approx(-2, abs=0.1)
|||
|
|||
|
|||
def test_multi_categorical_distribution():
    """MultiCategoricalDistribution should be trainable toward target
    per-branch probabilities, and masked branches should concentrate all
    probability on the allowed action.

    Fix: the local was misnamed ``gauss_dist`` (copy-paste from the Gaussian
    test) even though this is a categorical distribution — renamed to
    ``multi_cat_dist``.
    """
    torch.manual_seed(0)
    hidden_size = 16
    act_size = [3, 3, 4]
    sample_embedding = torch.ones((1, 16))
    multi_cat_dist = MultiCategoricalDistribution(hidden_size, act_size)

    # Make sure backprop works
    optimizer = torch.optim.Adam(multi_cat_dist.parameters(), lr=3e-3)

    def create_test_prob(size: int) -> torch.Tensor:
        """Target log-probs putting high probability on the first action."""
        test_prob = torch.tensor(
            [[1.0 - 0.01 * (size - 1)] + [0.01] * (size - 1)]
        )  # High prob for first action
        return test_prob.log()

    for _ in range(100):
        dist_insts = multi_cat_dist(
            sample_embedding, masks=torch.ones((1, sum(act_size)))
        )
        loss = 0
        for i, dist_inst in enumerate(dist_insts):
            assert isinstance(dist_inst, CategoricalDistInstance)
            log_prob = dist_inst.all_log_prob()
            test_log_prob = create_test_prob(act_size[i])
            # Force log_probs to match the high probability for the first action
            # generated by create_test_prob
            loss += torch.nn.functional.mse_loss(log_prob, test_log_prob)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    for dist_inst, size in zip(dist_insts, act_size):
        # Check that the log probs are close to the fake ones that we generated.
        test_log_probs = create_test_prob(size)
        for _prob, _test_prob in zip(
            dist_inst.all_log_prob().flatten().tolist(),
            test_log_probs.flatten().tolist(),
        ):
            assert _prob == pytest.approx(_test_prob, abs=0.1)

    # Test masks: only the last action of each branch is allowed, so it
    # should carry essentially all the probability (log-prob near 0).
    masks = []
    for branch in act_size:
        masks += [0] * (branch - 1) + [1]
    masks = torch.tensor([masks])
    dist_insts = multi_cat_dist(sample_embedding, masks=masks)
    for dist_inst in dist_insts:
        log_prob = dist_inst.all_log_prob()
        assert log_prob.flatten()[-1] == pytest.approx(0, abs=0.001)
|||
|
|||
|
|||
def test_gaussian_dist_instance():
    """Check sample shape, log-prob, and entropy of a standard normal."""
    torch.manual_seed(0)
    act_size = 4
    dist = GaussianDistInstance(torch.zeros(1, act_size), torch.ones(1, act_size))

    sampled = dist.sample()
    assert sampled.shape == (1, act_size)

    for log_prob in dist.log_prob(torch.zeros((1, act_size))).flatten():
        # Log prob of standard normal at 0
        assert log_prob == pytest.approx(-0.919, abs=0.01)

    for ent in dist.entropy().flatten():
        # entropy of standard normal at 0
        assert ent == pytest.approx(2.83, abs=0.01)
|||
|
|||
|
|||
def test_tanh_gaussian_dist_instance():
    """Samples from a tanh-squashed Gaussian must lie strictly in (-1, 1).

    Bug fix: this test previously constructed a plain GaussianDistInstance,
    which is unbounded, so the bounds assertion never exercised the tanh
    squash the test is named for (it only passed by luck of the seed). Use
    TanhGaussianDistInstance, which is already imported at the top of the
    file.
    """
    torch.manual_seed(0)
    act_size = 4
    dist_instance = TanhGaussianDistInstance(
        torch.zeros(1, act_size), torch.ones(1, act_size)
    )
    for _ in range(10):
        action = dist_instance.sample()
        assert action.shape == (1, act_size)
        # tanh squashing bounds every component to the open interval (-1, 1).
        assert torch.max(action) < 1.0 and torch.min(action) > -1.0
|||
|
|||
|
|||
def test_categorical_dist_instance():
    """Check sampling range and log-prob ordering for a skewed categorical."""
    torch.manual_seed(0)
    act_size = 4
    # High prob for first action
    test_prob = torch.tensor([1.0 - 0.1 * (act_size - 1)] + [0.1] * (act_size - 1))
    dist = CategoricalDistInstance(test_prob)

    for _ in range(10):
        sampled = dist.sample()
        assert sampled.shape == (1,)
        assert sampled < act_size

    # The first action should have a higher log-prob than any other.
    prob_first_action = dist.log_prob(torch.tensor([0]))
    for i in range(1, act_size):
        assert dist.log_prob(torch.tensor([i])) < prob_first_action
|
|||
import torch |
|||
from unittest import mock |
|||
import pytest |
|||
|
|||
from mlagents.trainers.torch.encoders import ( |
|||
VectorEncoder, |
|||
VectorAndUnnormalizedInputEncoder, |
|||
Normalizer, |
|||
SimpleVisualEncoder, |
|||
ResNetVisualEncoder, |
|||
NatureVisualEncoder, |
|||
) |
|||
|
|||
|
|||
# This test will also reveal issues with states not being saved in the state_dict. |
|||
def compare_models(module_1, module_2):
    """Return True iff the two modules' state_dicts hold equal tensors.

    Tensors are compared positionally — keys are deliberately ignored,
    matching the original intent of comparing values rather than names.
    This will also reveal issues with states not being saved in the
    state_dict.

    Fix: modules whose state_dicts have different lengths are now never
    equal. Previously ``zip`` silently truncated the longer state_dict, so
    a module with extra state could falsely compare equal to a smaller one.
    """
    state_1 = list(module_1.state_dict().items())
    state_2 = list(module_2.state_dict().items())
    if len(state_1) != len(state_2):
        return False
    return all(
        # Compare tensors in state_dict and not the keys.
        torch.equal(item_1[1], item_2[1])
        for item_1, item_2 in zip(state_1, state_2)
    )
|||
|
|||
|
|||
def test_normalizer():
    """Normalizer should whiten inputs from running statistics and support
    copying its state to another instance."""
    input_size = 2
    norm = Normalizer(input_size)

    # These three inputs should mean to 0.5, and variance 2
    # with the steps starting at 1
    for sample in (
        torch.tensor([[1, 1]]),
        torch.tensor([[1, 1]]),
        torch.tensor([[0, 0]]),
    ):
        norm.update(sample)

    # Test normalization
    query = torch.tensor([[1, 1]])
    for val in norm(query)[0]:
        assert val == pytest.approx(0.707, abs=0.001)

    # Test copy normalization
    norm2 = Normalizer(input_size)
    assert not compare_models(norm, norm2)
    norm2.copy_from(norm)
    assert compare_models(norm, norm2)
    for val in norm2(query)[0]:
        assert val == pytest.approx(0.707, abs=0.001)
|||
|
|||
|
|||
@mock.patch("mlagents.trainers.torch.encoders.Normalizer")
def test_vector_encoder(mock_normalizer):
    """VectorEncoder output shape plus normalization update/copy plumbing
    (the Normalizer itself is mocked out)."""
    mock_normalizer_inst = mock.Mock()
    mock_normalizer.return_value = mock_normalizer_inst
    input_size = 64
    hidden_size = 128
    num_layers = 3

    # Without normalization: just check the output width.
    encoder = VectorEncoder(input_size, hidden_size, num_layers, False)
    assert encoder(torch.ones((1, input_size))).shape == (1, hidden_size)

    # With normalization: updates must be forwarded to the Normalizer.
    encoder = VectorEncoder(input_size, hidden_size, num_layers, True)
    new_vec = torch.ones((1, input_size))
    encoder.update_normalization(new_vec)

    mock_normalizer.assert_called_with(input_size)
    mock_normalizer_inst.update.assert_called_with(new_vec)

    # Copying normalization stats delegates to Normalizer.copy_from.
    other_encoder = VectorEncoder(input_size, hidden_size, num_layers, True)
    encoder.copy_normalization(other_encoder)
    mock_normalizer_inst.copy_from.assert_called_with(mock_normalizer_inst)
|||
|
|||
|
|||
@mock.patch("mlagents.trainers.torch.encoders.Normalizer")
def test_vector_and_unnormalized_encoder(mock_normalizer):
    """VectorAndUnnormalizedInputEncoder should normalize only the normal
    input and produce a hidden_size-wide encoding.

    Fix: the unnormalized input previously hard-coded width 32 instead of
    using ``unnormalized_size``, so changing that constant would silently
    desynchronize the test from the encoder's configuration.
    """
    mock_normalizer_inst = mock.Mock()
    mock_normalizer.return_value = mock_normalizer_inst
    input_size = 64
    unnormalized_size = 32
    hidden_size = 128
    num_layers = 3
    normalize = True
    mock_normalizer_inst.return_value = torch.ones((1, input_size))
    vector_encoder = VectorAndUnnormalizedInputEncoder(
        input_size, hidden_size, unnormalized_size, num_layers, normalize
    )
    # Make sure normalizer is only called on input_size
    mock_normalizer.assert_called_with(input_size)
    normal_input = torch.ones((1, input_size))

    unnormalized_input = torch.ones((1, unnormalized_size))
    output = vector_encoder(normal_input, unnormalized_input)
    mock_normalizer_inst.assert_called_with(normal_input)
    assert output.shape == (1, hidden_size)
|||
|
|||
|
|||
@pytest.mark.parametrize("image_size", [(36, 36, 3), (84, 84, 4), (256, 256, 5)])
@pytest.mark.parametrize(
    "vis_class", [SimpleVisualEncoder, ResNetVisualEncoder, NatureVisualEncoder]
)
def test_visual_encoder(vis_class, image_size):
    """Each visual encoder should map an image batch to (1, num_outputs)."""
    num_outputs = 128
    height, width, channels = image_size
    enc = vis_class(height, width, channels, num_outputs)
    # Note: NCHW not NHWC
    sample_input = torch.ones((1, channels, height, width))
    assert enc(sample_input).shape == (1, num_outputs)
|
|||
import pytest |
|||
import torch |
|||
import numpy as np |
|||
|
|||
from mlagents.trainers.settings import EncoderType |
|||
from mlagents.trainers.torch.utils import ModelUtils |
|||
from mlagents.trainers.exception import UnityTrainerException |
|||
from mlagents.trainers.torch.encoders import ( |
|||
VectorEncoder, |
|||
VectorAndUnnormalizedInputEncoder, |
|||
) |
|||
from mlagents.trainers.torch.distributions import ( |
|||
CategoricalDistInstance, |
|||
GaussianDistInstance, |
|||
) |
|||
|
|||
|
|||
def test_min_visual_size():
    """Every EncoderType must declare a minimum resolution; encoders should
    work at exactly that size and the resolution check should reject
    anything smaller."""
    # Make sure each EncoderType has an entry in MIN_RESOLUTION_FOR_ENCODER
    assert set(ModelUtils.MIN_RESOLUTION_FOR_ENCODER.keys()) == set(EncoderType)

    for encoder_type in EncoderType:
        good_size = ModelUtils.MIN_RESOLUTION_FOR_ENCODER[encoder_type]
        vis_input = torch.ones((1, 3, good_size, good_size))
        ModelUtils._check_resolution_for_encoder(vis_input, encoder_type)
        enc_func = ModelUtils.get_encoder_for_type(encoder_type)
        enc = enc_func(good_size, good_size, 3, 1)
        enc.forward(vis_input)

        # Anything under the min size should raise an exception. If not, decrease the min size!
        with pytest.raises(Exception):
            bad_size = good_size - 1
            vis_input = torch.ones((1, 3, bad_size, bad_size))

            with pytest.raises(UnityTrainerException):
                # Make sure we'd hit a friendly error during model setup time.
                ModelUtils._check_resolution_for_encoder(vis_input, encoder_type)

            enc = enc_func(bad_size, bad_size, 3, 1)
            enc.forward(vis_input)
|||
|
|||
|
|||
@pytest.mark.parametrize("unnormalized_inputs", [0, 1])
@pytest.mark.parametrize("num_visual", [0, 1, 2])
@pytest.mark.parametrize("num_vector", [0, 1, 2])
@pytest.mark.parametrize("normalize", [True, False])
@pytest.mark.parametrize("encoder_type", [EncoderType.SIMPLE, EncoderType.NATURE_CNN])
def test_create_encoders(
    encoder_type, normalize, num_vector, num_visual, unnormalized_inputs
):
    """ModelUtils.create_encoders should build one visual encoder per visual
    observation and at most one (combined) vector encoder.

    Fix: the body previously overwrote the parametrized ``unnormalized_inputs``
    with a hard-coded 1, which turned that parametrization into a no-op and
    left the plain-``VectorEncoder`` branch below unreachable.
    """
    vec_obs_shape = (5,)
    vis_obs_shape = (84, 84, 3)
    obs_shapes = []
    for _ in range(num_vector):
        obs_shapes.append(vec_obs_shape)
    for _ in range(num_visual):
        obs_shapes.append(vis_obs_shape)
    h_size = 128
    num_layers = 3
    vis_enc, vec_enc = ModelUtils.create_encoders(
        obs_shapes, h_size, num_layers, encoder_type, unnormalized_inputs, normalize
    )
    vec_enc = list(vec_enc)
    vis_enc = list(vis_enc)
    assert len(vec_enc) == (
        1 if unnormalized_inputs + num_vector > 0 else 0
    )  # There's always at most one vector encoder.
    assert len(vis_enc) == num_visual

    # With unnormalized inputs we get the combined encoder; with only normal
    # vector obs we get a plain VectorEncoder.
    if unnormalized_inputs > 0:
        assert isinstance(vec_enc[0], VectorAndUnnormalizedInputEncoder)
    elif num_vector > 0:
        assert isinstance(vec_enc[0], VectorEncoder)

    for enc in vis_enc:
        assert isinstance(enc, ModelUtils.get_encoder_for_type(encoder_type))
|||
|
|||
|
|||
def test_list_to_tensor():
    """list_to_tensor should match torch.tensor for plain lists, numpy
    arrays, and lists of numpy arrays."""
    raw = [[1, 2], [1, 3], [1, 4]]
    expected = torch.tensor(raw)

    # Pure Python list
    assert torch.equal(ModelUtils.list_to_tensor(raw), expected)

    # Single numpy array
    assert torch.equal(ModelUtils.list_to_tensor(np.asarray(raw)), expected)

    # List of numpy arrays
    as_np_rows = [np.asarray(row) for row in raw]
    assert torch.equal(ModelUtils.list_to_tensor(as_np_rows), expected)
|||
|
|||
|
|||
def test_break_into_branches():
    """break_into_branches should split a flat action tensor per branch size."""
    all_actions = torch.tensor([[1, 2, 3, 4, 5, 6]])

    # Test normal multi-branch case
    action_size = [2, 1, 3]
    branches = ModelUtils.break_into_branches(all_actions, action_size)
    assert len(branches) == len(action_size)
    for branch, size in zip(branches, action_size):
        assert branch.shape == (1, size)

    # Test 1-branch case
    branches = ModelUtils.break_into_branches(all_actions, [6])
    assert len(branches) == 1
    assert branches[0].shape == (1, 6)
|||
|
|||
|
|||
def test_actions_to_onehot():
    """actions_to_onehot should one-hot encode each action branch separately."""
    all_actions = torch.tensor([[1, 0, 2], [1, 0, 2]])
    action_size = [2, 1, 3]
    oh_actions = ModelUtils.actions_to_onehot(all_actions, action_size)
    expected = [
        torch.tensor([[0, 1], [0, 1]]),
        torch.tensor([[1], [1]]),
        torch.tensor([[0, 0, 1], [0, 0, 1]]),
    ]
    for result, target in zip(oh_actions, expected):
        assert torch.equal(result, target)
|||
|
|||
|
|||
def test_get_probs_and_entropy():
    """Check get_probs_and_entropy for continuous and discrete dist lists.

    Fixes: the discrete all_probs shape assertion used
    ``len(dist_list * act_size)`` — replicating the whole list just to take
    its length — where the intended expression is
    ``len(dist_list) * act_size`` (same value, wrong construct); and the
    discrete section was mislabeled "Test continuous".
    """
    # Test continuous
    # Add two dists to the list. This isn't done in the code but we'd like to support it.
    dist_list = [
        GaussianDistInstance(torch.zeros((1, 2)), torch.ones((1, 2))),
        GaussianDistInstance(torch.zeros((1, 2)), torch.ones((1, 2))),
    ]
    action_list = [torch.zeros((1, 2)), torch.zeros((1, 2))]
    log_probs, entropies, all_probs = ModelUtils.get_probs_and_entropy(
        action_list, dist_list
    )
    assert log_probs.shape == (1, 2, 2)
    assert entropies.shape == (1, 2, 2)
    # Continuous distributions don't report per-action probabilities.
    assert all_probs is None

    for log_prob in log_probs.flatten():
        # Log prob of standard normal at 0
        assert log_prob == pytest.approx(-0.919, abs=0.01)

    for ent in entropies.flatten():
        # entropy of standard normal at 0
        assert ent == pytest.approx(2.83, abs=0.01)

    # Test discrete
    # Add two dists to the list.
    act_size = 2
    test_prob = torch.tensor(
        [1.0 - 0.1 * (act_size - 1)] + [0.1] * (act_size - 1)
    )  # High prob for first action
    dist_list = [CategoricalDistInstance(test_prob), CategoricalDistInstance(test_prob)]
    action_list = [torch.tensor([0]), torch.tensor([1])]
    log_probs, entropies, all_probs = ModelUtils.get_probs_and_entropy(
        action_list, dist_list
    )
    assert all_probs.shape == (len(dist_list) * act_size,)
    assert entropies.shape == (len(dist_list),)
    # Make sure the first action has higher probability than the others.
    assert log_probs.flatten()[0] > log_probs.flatten()[1]
撰写
预览
正在加载...
取消
保存
Reference in new issue