[pytorch] Add decoders, distributions, encoders, layers, networks, and utils (#4349)
/MLA-1734-demo-provider
GitHub, 4 years ago
Current commit: e3bc3352
17 files changed, with 2165 insertions and 1 deletion
- .circleci/config.yml (2 changed lines)
- test_requirements.txt (3 changed lines)
- ml-agents/mlagents/trainers/tests/torch/test_decoders.py (31 changed lines)
- ml-agents/mlagents/trainers/tests/torch/test_distributions.py (141 changed lines)
- ml-agents/mlagents/trainers/tests/torch/test_encoders.py (110 changed lines)
- ml-agents/mlagents/trainers/tests/torch/test_layers.py (40 changed lines)
- ml-agents/mlagents/trainers/tests/torch/test_networks.py (219 changed lines)
- ml-agents/mlagents/trainers/tests/torch/test_utils.py (216 changed lines)
- ml-agents/mlagents/trainers/torch/__init__.py (0 changed lines)
- ml-agents/mlagents/trainers/torch/decoders.py (23 changed lines)
- ml-agents/mlagents/trainers/torch/distributions.py (206 changed lines)
- ml-agents/mlagents/trainers/torch/encoders.py (302 changed lines)
- ml-agents/mlagents/trainers/torch/layers.py (84 changed lines)
- ml-agents/mlagents/trainers/torch/networks.py (493 changed lines)
- ml-agents/mlagents/trainers/torch/utils.py (296 changed lines)
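
For orientation, a minimal usage sketch of how the new modules compose (an editor's illustration based on the code below, not part of the commit itself):

# Editor's sketch: build a distribution and value heads from a hidden embedding.
import torch
from mlagents.trainers.torch.distributions import GaussianDistribution
from mlagents.trainers.torch.decoders import ValueHeads

hidden = torch.ones((1, 16))                    # embedding, e.g. from a NetworkBody/encoder
dist = GaussianDistribution(16, 2)(hidden)[0]   # one GaussianDistInstance over 2 actions
action = dist.sample()                          # shape (1, 2)
log_prob = dist.log_prob(action)                # per-dimension log probabilities
values = ValueHeads(["extrinsic"], 16)(hidden)  # {"extrinsic": tensor of shape (1,)}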

import pytest
import torch

from mlagents.trainers.torch.decoders import ValueHeads


def test_valueheads():
    stream_names = [f"reward_signal_{num}" for num in range(5)]
    input_size = 5
    batch_size = 4

    # Test default 1 value per head
    value_heads = ValueHeads(stream_names, input_size)
    input_data = torch.ones((batch_size, input_size))
    value_out = value_heads(input_data)  # Note: mean value will be removed shortly

    for stream_name in stream_names:
        assert value_out[stream_name].shape == (batch_size,)

    # Test that inputting the wrong size input will throw an error
    with pytest.raises(Exception):
        value_out = value_heads(torch.ones((batch_size, input_size + 2)))

    # Test multiple values per head (e.g. discrete Q function)
    output_size = 4
    value_heads = ValueHeads(stream_names, input_size, output_size)
    input_data = torch.ones((batch_size, input_size))
    value_out = value_heads(input_data)

    for stream_name in stream_names:
        assert value_out[stream_name].shape == (batch_size, output_size)

import pytest
import torch

from mlagents.trainers.torch.distributions import (
    GaussianDistribution,
    MultiCategoricalDistribution,
    GaussianDistInstance,
    TanhGaussianDistInstance,
    CategoricalDistInstance,
)


@pytest.mark.parametrize("tanh_squash", [True, False])
@pytest.mark.parametrize("conditional_sigma", [True, False])
def test_gaussian_distribution(conditional_sigma, tanh_squash):
    torch.manual_seed(0)
    hidden_size = 16
    act_size = 4
    sample_embedding = torch.ones((1, 16))
    gauss_dist = GaussianDistribution(
        hidden_size,
        act_size,
        conditional_sigma=conditional_sigma,
        tanh_squash=tanh_squash,
    )

    # Make sure backprop works
    force_action = torch.zeros((1, act_size))
    optimizer = torch.optim.Adam(gauss_dist.parameters(), lr=3e-3)

    for _ in range(50):
        dist_inst = gauss_dist(sample_embedding)[0]
        if tanh_squash:
            assert isinstance(dist_inst, TanhGaussianDistInstance)
        else:
            assert isinstance(dist_inst, GaussianDistInstance)
        log_prob = dist_inst.log_prob(force_action)
        loss = torch.nn.functional.mse_loss(log_prob, -2 * torch.ones(log_prob.shape))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    for prob in log_prob.flatten():
        assert prob == pytest.approx(-2, abs=0.1)


def test_multi_categorical_distribution():
    torch.manual_seed(0)
    hidden_size = 16
    act_size = [3, 3, 4]
    sample_embedding = torch.ones((1, 16))
    gauss_dist = MultiCategoricalDistribution(hidden_size, act_size)

    # Make sure backprop works
    optimizer = torch.optim.Adam(gauss_dist.parameters(), lr=3e-3)

    def create_test_prob(size: int) -> torch.Tensor:
        test_prob = torch.tensor(
            [[1.0 - 0.01 * (size - 1)] + [0.01] * (size - 1)]
        )  # High prob for first action
        return test_prob.log()

    for _ in range(100):
        dist_insts = gauss_dist(sample_embedding, masks=torch.ones((1, sum(act_size))))
        loss = 0
        for i, dist_inst in enumerate(dist_insts):
            assert isinstance(dist_inst, CategoricalDistInstance)
            log_prob = dist_inst.all_log_prob()
            test_log_prob = create_test_prob(act_size[i])
            # Force log_probs to match the high probability for the first action generated by
            # create_test_prob
            loss += torch.nn.functional.mse_loss(log_prob, test_log_prob)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    for dist_inst, size in zip(dist_insts, act_size):
        # Check that the log probs are close to the fake ones that we generated.
        test_log_probs = create_test_prob(size)
        for _prob, _test_prob in zip(
            dist_inst.all_log_prob().flatten().tolist(),
            test_log_probs.flatten().tolist(),
        ):
            assert _prob == pytest.approx(_test_prob, abs=0.1)

    # Test masks
    masks = []
    for branch in act_size:
        masks += [0] * (branch - 1) + [1]
    masks = torch.tensor([masks])
    dist_insts = gauss_dist(sample_embedding, masks=masks)
    for dist_inst in dist_insts:
        log_prob = dist_inst.all_log_prob()
        assert log_prob.flatten()[-1] == pytest.approx(0, abs=0.001)


def test_gaussian_dist_instance():
    torch.manual_seed(0)
    act_size = 4
    dist_instance = GaussianDistInstance(
        torch.zeros(1, act_size), torch.ones(1, act_size)
    )
    action = dist_instance.sample()
    assert action.shape == (1, act_size)
    for log_prob in dist_instance.log_prob(torch.zeros((1, act_size))).flatten():
        # Log prob of standard normal at 0
        assert log_prob == pytest.approx(-0.919, abs=0.01)

    for ent in dist_instance.entropy().flatten():
        # entropy of standard normal at 0, based on 1/2 + ln(sqrt(2pi)sigma)
        assert ent == pytest.approx(1.42, abs=0.01)


def test_tanh_gaussian_dist_instance():
    torch.manual_seed(0)
    act_size = 4
    dist_instance = TanhGaussianDistInstance(
        torch.zeros(1, act_size), torch.ones(1, act_size)
    )
    for _ in range(10):
        action = dist_instance.sample()
        assert action.shape == (1, act_size)
        assert torch.max(action) < 1.0 and torch.min(action) > -1.0


def test_categorical_dist_instance():
    torch.manual_seed(0)
    act_size = 4
    test_prob = torch.tensor(
        [[1.0 - 0.1 * (act_size - 1)] + [0.1] * (act_size - 1)]
    )  # High prob for first action
    dist_instance = CategoricalDistInstance(test_prob)

    for _ in range(10):
        action = dist_instance.sample()
        assert action.shape == (1, 1)
        assert action < act_size

    # Make sure the first action has higher probability than the others.
    prob_first_action = dist_instance.log_prob(torch.tensor([0]))

    for i in range(1, act_size):
        assert dist_instance.log_prob(torch.tensor([i])) < prob_first_action

import torch
from unittest import mock
import pytest

from mlagents.trainers.torch.encoders import (
    VectorEncoder,
    VectorAndUnnormalizedInputEncoder,
    Normalizer,
    SimpleVisualEncoder,
    ResNetVisualEncoder,
    NatureVisualEncoder,
)


# This test will also reveal issues with states not being saved in the state_dict.
def compare_models(module_1, module_2):
    is_same = True
    for key_item_1, key_item_2 in zip(
        module_1.state_dict().items(), module_2.state_dict().items()
    ):
        # Compare tensors in state_dict and not the keys.
        is_same = torch.equal(key_item_1[1], key_item_2[1]) and is_same
    return is_same


def test_normalizer():
    input_size = 2
    norm = Normalizer(input_size)

    # These three inputs should bring the running mean to 0.5 and the running
    # variance to 2, with the step count starting at 1.
    vec_input1 = torch.tensor([[1, 1]])
    vec_input2 = torch.tensor([[1, 1]])
    vec_input3 = torch.tensor([[0, 0]])
    norm.update(vec_input1)
    norm.update(vec_input2)
    norm.update(vec_input3)

    # Test normalization
    for val in norm(vec_input1)[0]:
        assert val == pytest.approx(0.707, abs=0.001)

    # Test copy normalization
    norm2 = Normalizer(input_size)
    assert not compare_models(norm, norm2)
    norm2.copy_from(norm)
    assert compare_models(norm, norm2)
    for val in norm2(vec_input1)[0]:
        assert val == pytest.approx(0.707, abs=0.001)


@mock.patch("mlagents.trainers.torch.encoders.Normalizer")
def test_vector_encoder(mock_normalizer):
    mock_normalizer_inst = mock.Mock()
    mock_normalizer.return_value = mock_normalizer_inst
    input_size = 64
    hidden_size = 128
    num_layers = 3
    normalize = False
    vector_encoder = VectorEncoder(input_size, hidden_size, num_layers, normalize)
    output = vector_encoder(torch.ones((1, input_size)))
    assert output.shape == (1, hidden_size)

    normalize = True
    vector_encoder = VectorEncoder(input_size, hidden_size, num_layers, normalize)
    new_vec = torch.ones((1, input_size))
    vector_encoder.update_normalization(new_vec)

    mock_normalizer.assert_called_with(input_size)
    mock_normalizer_inst.update.assert_called_with(new_vec)

    vector_encoder2 = VectorEncoder(input_size, hidden_size, num_layers, normalize)
    vector_encoder.copy_normalization(vector_encoder2)
    mock_normalizer_inst.copy_from.assert_called_with(mock_normalizer_inst)


@mock.patch("mlagents.trainers.torch.encoders.Normalizer")
def test_vector_and_unnormalized_encoder(mock_normalizer):
    mock_normalizer_inst = mock.Mock()
    mock_normalizer.return_value = mock_normalizer_inst
    input_size = 64
    unnormalized_size = 32
    hidden_size = 128
    num_layers = 3
    normalize = True
    mock_normalizer_inst.return_value = torch.ones((1, input_size))
    vector_encoder = VectorAndUnnormalizedInputEncoder(
        input_size, hidden_size, unnormalized_size, num_layers, normalize
    )
    # Make sure normalizer is only called on input_size
    mock_normalizer.assert_called_with(input_size)
    normal_input = torch.ones((1, input_size))

    unnormalized_input = torch.ones((1, 32))
    output = vector_encoder(normal_input, unnormalized_input)
    mock_normalizer_inst.assert_called_with(normal_input)
    assert output.shape == (1, hidden_size)


@pytest.mark.parametrize("image_size", [(36, 36, 3), (84, 84, 4), (256, 256, 5)])
@pytest.mark.parametrize(
    "vis_class", [SimpleVisualEncoder, ResNetVisualEncoder, NatureVisualEncoder]
)
def test_visual_encoder(vis_class, image_size):
    num_outputs = 128
    enc = vis_class(image_size[0], image_size[1], image_size[2], num_outputs)
    # Note: NCHW not NHWC
    sample_input = torch.ones((1, image_size[2], image_size[0], image_size[1]))
    encoding = enc(sample_input)
    assert encoding.shape == (1, num_outputs)

import torch

from mlagents.trainers.torch.layers import (
    Swish,
    linear_layer,
    lstm_layer,
    Initialization,
)


def test_swish():
    layer = Swish()
    input_tensor = torch.Tensor([[1, 2, 3], [4, 5, 6]])
    target_tensor = torch.mul(input_tensor, torch.sigmoid(input_tensor))
    assert torch.all(torch.eq(layer(input_tensor), target_tensor))


def test_initialization_layer():
    torch.manual_seed(0)
    # Test Zero
    layer = linear_layer(
        3, 4, kernel_init=Initialization.Zero, bias_init=Initialization.Zero
    )
    assert torch.all(torch.eq(layer.weight.data, torch.zeros_like(layer.weight.data)))
    assert torch.all(torch.eq(layer.bias.data, torch.zeros_like(layer.bias.data)))


def test_lstm_layer():
    torch.manual_seed(0)
    # Test zero for LSTM
    layer = lstm_layer(
        4, 4, kernel_init=Initialization.Zero, bias_init=Initialization.Zero
    )
    for name, param in layer.named_parameters():
        if "weight" in name:
            assert torch.all(torch.eq(param.data, torch.zeros_like(param.data)))
        elif "bias" in name:
            assert torch.all(
                torch.eq(param.data[4:8], torch.ones_like(param.data[4:8]))
            )

import pytest

import torch
from mlagents.trainers.torch.networks import (
    NetworkBody,
    ValueNetwork,
    SimpleActor,
    SharedActorCritic,
    SeparateActorCritic,
)
from mlagents.trainers.settings import NetworkSettings
from mlagents_envs.base_env import ActionType
from mlagents.trainers.torch.distributions import (
    GaussianDistInstance,
    CategoricalDistInstance,
)


def test_networkbody_vector():
    torch.manual_seed(0)
    obs_size = 4
    network_settings = NetworkSettings()
    obs_shapes = [(obs_size,)]

    networkbody = NetworkBody(obs_shapes, network_settings, encoded_act_size=2)
    optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3)
    sample_obs = 0.1 * torch.ones((1, obs_size))
    sample_act = 0.1 * torch.ones((1, 2))

    for _ in range(300):
        encoded, _ = networkbody([sample_obs], [], sample_act)
        assert encoded.shape == (1, network_settings.hidden_units)
        # Try to force output to 1
        loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # In the last step, values should be close to 1
    for _enc in encoded.flatten():
        assert _enc == pytest.approx(1.0, abs=0.1)


def test_networkbody_lstm():
    torch.manual_seed(0)
    obs_size = 4
    seq_len = 16
    network_settings = NetworkSettings(
        memory=NetworkSettings.MemorySettings(sequence_length=seq_len, memory_size=12)
    )
    obs_shapes = [(obs_size,)]

    networkbody = NetworkBody(obs_shapes, network_settings)
    optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-4)
    sample_obs = torch.ones((1, seq_len, obs_size))

    for _ in range(200):
        encoded, _ = networkbody([sample_obs], [], memories=torch.ones(1, seq_len, 12))
        # Try to force output to 1
        loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # In the last step, values should be close to 1
    for _enc in encoded.flatten():
        assert _enc == pytest.approx(1.0, abs=0.1)


def test_networkbody_visual():
    torch.manual_seed(0)
    vec_obs_size = 4
    obs_size = (84, 84, 3)
    network_settings = NetworkSettings()
    obs_shapes = [(vec_obs_size,), obs_size]
    torch.random.manual_seed(0)

    networkbody = NetworkBody(obs_shapes, network_settings)
    optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3)
    sample_obs = torch.ones((1, 84, 84, 3))
    sample_vec_obs = torch.ones((1, vec_obs_size))

    for _ in range(150):
        encoded, _ = networkbody([sample_vec_obs], [sample_obs])
        assert encoded.shape == (1, network_settings.hidden_units)
        # Try to force output to 1
        loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # In the last step, values should be close to 1
    for _enc in encoded.flatten():
        assert _enc == pytest.approx(1.0, abs=0.1)


def test_valuenetwork():
    torch.manual_seed(0)
    obs_size = 4
    num_outputs = 2
    network_settings = NetworkSettings()
    obs_shapes = [(obs_size,)]

    stream_names = [f"stream_name{n}" for n in range(4)]
    value_net = ValueNetwork(
        stream_names, obs_shapes, network_settings, outputs_per_stream=num_outputs
    )
    optimizer = torch.optim.Adam(value_net.parameters(), lr=3e-3)

    for _ in range(50):
        sample_obs = torch.ones((1, obs_size))
        values, _ = value_net([sample_obs], [])
        loss = 0
        for s_name in stream_names:
            assert values[s_name].shape == (1, num_outputs)
            # Try to force output to 1
            loss += torch.nn.functional.mse_loss(
                values[s_name], torch.ones((1, num_outputs))
            )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # In the last step, values should be close to 1
    for value in values.values():
        for _out in value:
            assert _out[0] == pytest.approx(1.0, abs=0.1)


@pytest.mark.parametrize("action_type", [ActionType.DISCRETE, ActionType.CONTINUOUS])
def test_simple_actor(action_type):
    obs_size = 4
    network_settings = NetworkSettings()
    obs_shapes = [(obs_size,)]
    act_size = [2]
    masks = None if action_type == ActionType.CONTINUOUS else torch.ones((1, 1))
    actor = SimpleActor(obs_shapes, network_settings, action_type, act_size)
    # Test get_dist
    sample_obs = torch.ones((1, obs_size))
    dists, _ = actor.get_dists([sample_obs], [], masks=masks)
    for dist in dists:
        if action_type == ActionType.CONTINUOUS:
            assert isinstance(dist, GaussianDistInstance)
        else:
            assert isinstance(dist, CategoricalDistInstance)

    # Test sample_actions
    actions = actor.sample_action(dists)
    for act in actions:
        if action_type == ActionType.CONTINUOUS:
            assert act.shape == (1, act_size[0])
        else:
            assert act.shape == (1, 1)

    # Test forward
    actions, probs, ver_num, mem_size, is_cont, act_size_vec = actor.forward(
        [sample_obs], [], masks=masks
    )
    for act in actions:
        if action_type == ActionType.CONTINUOUS:
            assert act.shape == (
                act_size[0],
                1,
            )  # This is different from above for ONNX export
        else:
            assert act.shape == (1, 1)

    # TODO: Once export works properly, fix the shapes here.
    assert mem_size == 0
    assert is_cont == int(action_type == ActionType.CONTINUOUS)
    assert act_size_vec == torch.tensor(act_size)


@pytest.mark.parametrize("ac_type", [SharedActorCritic, SeparateActorCritic])
@pytest.mark.parametrize("lstm", [True, False])
def test_actor_critic(ac_type, lstm):
    obs_size = 4
    network_settings = NetworkSettings(
        memory=NetworkSettings.MemorySettings() if lstm else None
    )
    obs_shapes = [(obs_size,)]
    act_size = [2]
    stream_names = [f"stream_name{n}" for n in range(4)]
    actor = ac_type(
        obs_shapes, network_settings, ActionType.CONTINUOUS, act_size, stream_names
    )
    if lstm:
        sample_obs = torch.ones((1, network_settings.memory.sequence_length, obs_size))
        memories = torch.ones(
            (
                1,
                network_settings.memory.sequence_length,
                network_settings.memory.memory_size,
            )
        )
    else:
        sample_obs = torch.ones((1, obs_size))
        memories = torch.tensor([])
        # memories isn't always set to None; the network should be able to
        # deal with that.
    # Test critic pass
    value_out, memories_out = actor.critic_pass([sample_obs], [], memories=memories)
    for stream in stream_names:
        if lstm:
            assert value_out[stream].shape == (network_settings.memory.sequence_length,)
            assert memories_out.shape == memories.shape
        else:
            assert value_out[stream].shape == (1,)

    # Test get_dist_and_value
    dists, value_out, mem_out = actor.get_dist_and_value(
        [sample_obs], [], memories=memories
    )
    if mem_out is not None:
        assert mem_out.shape == memories.shape
    for dist in dists:
        assert isinstance(dist, GaussianDistInstance)
    for stream in stream_names:
        if lstm:
            assert value_out[stream].shape == (network_settings.memory.sequence_length,)
        else:
            assert value_out[stream].shape == (1,)

import pytest
import torch
import numpy as np

from mlagents.trainers.settings import EncoderType, ScheduleType
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.torch.encoders import (
    VectorEncoder,
    VectorAndUnnormalizedInputEncoder,
)
from mlagents.trainers.torch.distributions import (
    CategoricalDistInstance,
    GaussianDistInstance,
)


def test_min_visual_size():
    # Make sure each EncoderType has an entry in MIN_RESOLUTION_FOR_ENCODER
    assert set(ModelUtils.MIN_RESOLUTION_FOR_ENCODER.keys()) == set(EncoderType)

    for encoder_type in EncoderType:
        good_size = ModelUtils.MIN_RESOLUTION_FOR_ENCODER[encoder_type]
        vis_input = torch.ones((1, 3, good_size, good_size))
        ModelUtils._check_resolution_for_encoder(good_size, good_size, encoder_type)
        enc_func = ModelUtils.get_encoder_for_type(encoder_type)
        enc = enc_func(good_size, good_size, 3, 1)
        enc.forward(vis_input)

        # Anything under the min size should raise an exception. If not, decrease the min size!
        with pytest.raises(Exception):
            bad_size = ModelUtils.MIN_RESOLUTION_FOR_ENCODER[encoder_type] - 1
            vis_input = torch.ones((1, 3, bad_size, bad_size))

            with pytest.raises(UnityTrainerException):
                # Make sure we'd hit a friendly error during model setup time.
                ModelUtils._check_resolution_for_encoder(
                    bad_size, bad_size, encoder_type
                )

            enc = enc_func(bad_size, bad_size, 3, 1)
            enc.forward(vis_input)


@pytest.mark.parametrize("unnormalized_inputs", [0, 1])
@pytest.mark.parametrize("num_visual", [0, 1, 2])
@pytest.mark.parametrize("num_vector", [0, 1, 2])
@pytest.mark.parametrize("normalize", [True, False])
@pytest.mark.parametrize("encoder_type", [EncoderType.SIMPLE, EncoderType.NATURE_CNN])
def test_create_encoders(
    encoder_type, normalize, num_vector, num_visual, unnormalized_inputs
):
    vec_obs_shape = (5,)
    vis_obs_shape = (84, 84, 3)
    obs_shapes = []
    for _ in range(num_vector):
        obs_shapes.append(vec_obs_shape)
    for _ in range(num_visual):
        obs_shapes.append(vis_obs_shape)
    h_size = 128
    num_layers = 3
    unnormalized_inputs = 1
    vis_enc, vec_enc = ModelUtils.create_encoders(
        obs_shapes, h_size, num_layers, encoder_type, unnormalized_inputs, normalize
    )
    vec_enc = list(vec_enc)
    vis_enc = list(vis_enc)
    assert len(vec_enc) == (
        1 if unnormalized_inputs + num_vector > 0 else 0
    )  # There's always at most one vector encoder.
    assert len(vis_enc) == num_visual

    if unnormalized_inputs > 0:
        assert isinstance(vec_enc[0], VectorAndUnnormalizedInputEncoder)
    elif num_vector > 0:
        assert isinstance(vec_enc[0], VectorEncoder)

    for enc in vis_enc:
        assert isinstance(enc, ModelUtils.get_encoder_for_type(encoder_type))


def test_decayed_value():
    test_steps = [0, 4, 9]
    # Test constant decay
    param = ModelUtils.DecayedValue(ScheduleType.CONSTANT, 1.0, 0.2, test_steps[-1])
    for _step in test_steps:
        _param = param.get_value(_step)
        assert _param == 1.0

    test_results = [1.0, 0.6444, 0.2]
    # Test linear decay
    param = ModelUtils.DecayedValue(ScheduleType.LINEAR, 1.0, 0.2, test_steps[-1])
    for _step, _result in zip(test_steps, test_results):
        _param = param.get_value(_step)
        assert _param == pytest.approx(_result, abs=0.01)

    # Test invalid
    with pytest.raises(UnityTrainerException):
        ModelUtils.DecayedValue(
            "SomeOtherSchedule", 1.0, 0.2, test_steps[-1]
        ).get_value(0)


def test_polynomial_decay():
    test_steps = [0, 4, 9]
    test_results = [1.0, 0.7, 0.2]
    for _step, _result in zip(test_steps, test_results):
        decayed = ModelUtils.polynomial_decay(
            1.0, 0.2, test_steps[-1], _step, power=0.8
        )
        assert decayed == pytest.approx(_result, abs=0.01)


def test_list_to_tensor():
    # Test converting pure list
    unconverted_list = [[1, 2], [1, 3], [1, 4]]
    tensor = ModelUtils.list_to_tensor(unconverted_list)
    # Should be equivalent to torch.tensor conversion
    assert torch.equal(tensor, torch.tensor(unconverted_list))

    # Test converting pure numpy array
    np_list = np.asarray(unconverted_list)
    tensor = ModelUtils.list_to_tensor(np_list)
    # Should be equivalent to torch.tensor conversion
    assert torch.equal(tensor, torch.tensor(unconverted_list))

    # Test converting list of numpy arrays
    list_of_np = [np.asarray(_el) for _el in unconverted_list]
    tensor = ModelUtils.list_to_tensor(list_of_np)
    # Should be equivalent to torch.tensor conversion
    assert torch.equal(tensor, torch.tensor(unconverted_list))


def test_break_into_branches():
    # Test normal multi-branch case
    all_actions = torch.tensor([[1, 2, 3, 4, 5, 6]])
    action_size = [2, 1, 3]
    broken_actions = ModelUtils.break_into_branches(all_actions, action_size)
    assert len(action_size) == len(broken_actions)
    for i, _action in enumerate(broken_actions):
        assert _action.shape == (1, action_size[i])

    # Test 1-branch case
    action_size = [6]
    broken_actions = ModelUtils.break_into_branches(all_actions, action_size)
    assert len(broken_actions) == 1
    assert broken_actions[0].shape == (1, 6)


def test_actions_to_onehot():
    all_actions = torch.tensor([[1, 0, 2], [1, 0, 2]])
    action_size = [2, 1, 3]
    oh_actions = ModelUtils.actions_to_onehot(all_actions, action_size)
    expected_result = [
        torch.tensor([[0, 1], [0, 1]], dtype=torch.float),
        torch.tensor([[1], [1]], dtype=torch.float),
        torch.tensor([[0, 0, 1], [0, 0, 1]], dtype=torch.float),
    ]
    for res, exp in zip(oh_actions, expected_result):
        assert torch.equal(res, exp)


def test_get_probs_and_entropy():
    # Test continuous
    # Add two dists to the list. This isn't done in the code but we'd like to support it.
    dist_list = [
        GaussianDistInstance(torch.zeros((1, 2)), torch.ones((1, 2))),
        GaussianDistInstance(torch.zeros((1, 2)), torch.ones((1, 2))),
    ]
    action_list = [torch.zeros((1, 2)), torch.zeros((1, 2))]
    log_probs, entropies, all_probs = ModelUtils.get_probs_and_entropy(
        action_list, dist_list
    )
    assert log_probs.shape == (1, 2, 2)
    assert entropies.shape == (1, 2, 2)
    assert all_probs is None

    for log_prob in log_probs.flatten():
        # Log prob of standard normal at 0
        assert log_prob == pytest.approx(-0.919, abs=0.01)

    for ent in entropies.flatten():
        # entropy of standard normal at 0
        assert ent == pytest.approx(1.42, abs=0.01)

    # Test discrete
    # Add two dists to the list.
    act_size = 2
    test_prob = torch.tensor(
        [[1.0 - 0.1 * (act_size - 1)] + [0.1] * (act_size - 1)]
    )  # High prob for first action
    dist_list = [CategoricalDistInstance(test_prob), CategoricalDistInstance(test_prob)]
    action_list = [torch.tensor([0]), torch.tensor([1])]
    log_probs, entropies, all_probs = ModelUtils.get_probs_and_entropy(
        action_list, dist_list
    )
    assert all_probs.shape == (1, len(dist_list * act_size))
    assert entropies.shape == (1, len(dist_list))
    # Make sure the first action has a higher probability than the others.
    assert log_probs.flatten()[0] > log_probs.flatten()[1]


def test_masked_mean():
    test_input = torch.tensor([1, 2, 3, 4, 5])
    masks = torch.ones_like(test_input).bool()
    mean = ModelUtils.masked_mean(test_input, masks=masks)
    assert mean == 3.0

    masks = torch.tensor([False, False, True, True, True])
    mean = ModelUtils.masked_mean(test_input, masks=masks)
    assert mean == 4.0

    # Make sure it works if all masks are off
    masks = torch.tensor([False, False, False, False, False])
    mean = ModelUtils.masked_mean(test_input, masks=masks)
    assert mean == 0.0

from typing import List, Dict

import torch
from torch import nn
from mlagents.trainers.torch.layers import linear_layer


class ValueHeads(nn.Module):
    def __init__(self, stream_names: List[str], input_size: int, output_size: int = 1):
        super().__init__()
        self.stream_names = stream_names
        _value_heads = {}

        for name in stream_names:
            value = linear_layer(input_size, output_size)
            _value_heads[name] = value
        self.value_heads = nn.ModuleDict(_value_heads)

    def forward(self, hidden: torch.Tensor) -> Dict[str, torch.Tensor]:
        value_outputs = {}
        for stream_name, head in self.value_heads.items():
            value_outputs[stream_name] = head(hidden).squeeze(-1)
        return value_outputs

import abc
from typing import List
import torch
from torch import nn
import numpy as np
import math
from mlagents.trainers.torch.layers import linear_layer, Initialization

EPSILON = 1e-7  # Small value to avoid divide by zero


class DistInstance(nn.Module, abc.ABC):
    @abc.abstractmethod
    def sample(self) -> torch.Tensor:
        """
        Return a sample from this distribution.
        """
        pass

    @abc.abstractmethod
    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Returns the log probabilities of a particular value.
        :param value: A value sampled from the distribution.
        :returns: Log probabilities of the given value.
        """
        pass

    @abc.abstractmethod
    def entropy(self) -> torch.Tensor:
        """
        Returns the entropy of this distribution.
        """
        pass


class DiscreteDistInstance(DistInstance):
    @abc.abstractmethod
    def all_log_prob(self) -> torch.Tensor:
        """
        Returns the log probabilities of all actions represented by this distribution.
        """
        pass


class GaussianDistInstance(DistInstance):
    def __init__(self, mean, std):
        super().__init__()
        self.mean = mean
        self.std = std

    def sample(self):
        sample = self.mean + torch.randn_like(self.mean) * self.std
        return sample

    def log_prob(self, value):
        var = self.std ** 2
        log_scale = torch.log(self.std + EPSILON)
        return (
            -((value - self.mean) ** 2) / (2 * var + EPSILON)
            - log_scale
            - math.log(math.sqrt(2 * math.pi))
        )

    def pdf(self, value):
        log_prob = self.log_prob(value)
        return torch.exp(log_prob)

    def entropy(self):
        return 0.5 * torch.log(2 * math.pi * math.e * self.std + EPSILON)


class TanhGaussianDistInstance(GaussianDistInstance):
    def __init__(self, mean, std):
        super().__init__(mean, std)
        self.transform = torch.distributions.transforms.TanhTransform(cache_size=1)

    def sample(self):
        unsquashed_sample = super().sample()
        squashed = self.transform(unsquashed_sample)
        return squashed

    def _inverse_tanh(self, value):
        capped_value = torch.clamp(value, -1 + EPSILON, 1 - EPSILON)
        return 0.5 * torch.log((1 + capped_value) / (1 - capped_value) + EPSILON)

    def log_prob(self, value):
        unsquashed = self.transform.inv(value)
        return super().log_prob(unsquashed) - self.transform.log_abs_det_jacobian(
            unsquashed, value
        )


class CategoricalDistInstance(DiscreteDistInstance):
    def __init__(self, logits):
        super().__init__()
        self.logits = logits
        self.probs = torch.softmax(self.logits, dim=-1)

    def sample(self):
        return torch.multinomial(self.probs, 1)

    def pdf(self, value):
        # This function is equivalent to torch.diag(self.probs.T[value.flatten().long()]),
        # but torch.diag is not supported by ONNX export.
        idx = torch.arange(start=0, end=len(value)).unsqueeze(-1)
        return torch.gather(
            self.probs.permute(1, 0)[value.flatten().long()], -1, idx
        ).squeeze(-1)

    def log_prob(self, value):
        return torch.log(self.pdf(value))

    def all_log_prob(self):
        return torch.log(self.probs)

    def entropy(self):
        return -torch.sum(self.probs * torch.log(self.probs), dim=-1)


class GaussianDistribution(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_outputs: int,
        conditional_sigma: bool = False,
        tanh_squash: bool = False,
    ):
        super().__init__()
        self.conditional_sigma = conditional_sigma
        self.mu = linear_layer(
            hidden_size,
            num_outputs,
            kernel_init=Initialization.KaimingHeNormal,
            kernel_gain=0.1,
            bias_init=Initialization.Zero,
        )
        self.tanh_squash = tanh_squash
        if conditional_sigma:
            self.log_sigma = linear_layer(
                hidden_size,
                num_outputs,
                kernel_init=Initialization.KaimingHeNormal,
                kernel_gain=0.1,
                bias_init=Initialization.Zero,
            )
        else:
            self.log_sigma = nn.Parameter(
                torch.zeros(1, num_outputs, requires_grad=True)
            )

    def forward(self, inputs: torch.Tensor) -> List[DistInstance]:
        mu = self.mu(inputs)
        if self.conditional_sigma:
            log_sigma = torch.clamp(self.log_sigma(inputs), min=-20, max=2)
        else:
            log_sigma = self.log_sigma
        if self.tanh_squash:
            return [TanhGaussianDistInstance(mu, torch.exp(log_sigma))]
        else:
            return [GaussianDistInstance(mu, torch.exp(log_sigma))]


class MultiCategoricalDistribution(nn.Module):
    def __init__(self, hidden_size: int, act_sizes: List[int]):
        super().__init__()
        self.act_sizes = act_sizes
        self.branches = self._create_policy_branches(hidden_size)

    def _create_policy_branches(self, hidden_size: int) -> nn.ModuleList:
        branches = []
        for size in self.act_sizes:
            branch_output_layer = linear_layer(
                hidden_size,
                size,
                kernel_init=Initialization.KaimingHeNormal,
                kernel_gain=0.1,
                bias_init=Initialization.Zero,
            )
            branches.append(branch_output_layer)
        return nn.ModuleList(branches)

    def _mask_branch(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        raw_probs = torch.nn.functional.softmax(logits, dim=-1) * mask
        normalized_probs = raw_probs / torch.sum(raw_probs, dim=-1).unsqueeze(-1)
        normalized_logits = torch.log(normalized_probs + EPSILON)
        return normalized_logits

    def _split_masks(self, masks: torch.Tensor) -> List[torch.Tensor]:
        split_masks = []
        for idx, _ in enumerate(self.act_sizes):
            start = int(np.sum(self.act_sizes[:idx]))
            end = int(np.sum(self.act_sizes[: idx + 1]))
            split_masks.append(masks[:, start:end])
        return split_masks

    def forward(self, inputs: torch.Tensor, masks: torch.Tensor) -> List[DistInstance]:
        # Todo - Support multiple branches in mask code
        branch_distributions = []
        masks = self._split_masks(masks)
        for idx, branch in enumerate(self.branches):
            logits = branch(inputs)
            norm_logits = self._mask_branch(logits, masks[idx])
            distribution = CategoricalDistInstance(norm_logits)
            branch_distributions.append(distribution)
        return branch_distributions

from typing import Tuple, Optional, Union

from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.torch.layers import linear_layer, Initialization, Swish

import torch
from torch import nn


class Normalizer(nn.Module):
    def __init__(self, vec_obs_size: int):
        super().__init__()
        self.register_buffer("normalization_steps", torch.tensor(1))
        self.register_buffer("running_mean", torch.zeros(vec_obs_size))
        self.register_buffer("running_variance", torch.ones(vec_obs_size))

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        normalized_state = torch.clamp(
            (inputs - self.running_mean)
            / torch.sqrt(self.running_variance / self.normalization_steps),
            -5,
            5,
        )
        return normalized_state

    def update(self, vector_input: torch.Tensor) -> None:
        steps_increment = vector_input.size()[0]
        total_new_steps = self.normalization_steps + steps_increment

        input_to_old_mean = vector_input - self.running_mean
        new_mean = self.running_mean + (input_to_old_mean / total_new_steps).sum(0)

        input_to_new_mean = vector_input - new_mean
        new_variance = self.running_variance + (
            input_to_new_mean * input_to_old_mean
        ).sum(0)
        # Update in-place
        self.running_mean.data.copy_(new_mean.data)
        self.running_variance.data.copy_(new_variance.data)
        self.normalization_steps.data.copy_(total_new_steps.data)

    def copy_from(self, other_normalizer: "Normalizer") -> None:
        self.normalization_steps.data.copy_(other_normalizer.normalization_steps.data)
        self.running_mean.data.copy_(other_normalizer.running_mean.data)
        self.running_variance.copy_(other_normalizer.running_variance.data)


def conv_output_shape(
    h_w: Tuple[int, int],
    kernel_size: Union[int, Tuple[int, int]] = 1,
    stride: int = 1,
    padding: int = 0,
    dilation: int = 1,
) -> Tuple[int, int]:
    """
    Calculates the output shape (height and width) of the output of a convolution layer.
    kernel_size, stride, padding and dilation correspond to the inputs of the
    torch.nn.Conv2d layer (https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html)
    :param h_w: The height and width of the input.
    :param kernel_size: The size of the kernel of the convolution (can be an int or a
    tuple [width, height])
    :param stride: The stride of the convolution
    :param padding: The padding of the convolution
    :param dilation: The dilation of the convolution
    """
    from math import floor

    if not isinstance(kernel_size, tuple):
        kernel_size = (int(kernel_size), int(kernel_size))
    h = floor(
        ((h_w[0] + (2 * padding) - (dilation * (kernel_size[0] - 1)) - 1) / stride) + 1
    )
    w = floor(
        ((h_w[1] + (2 * padding) - (dilation * (kernel_size[1] - 1)) - 1) / stride) + 1
    )
    return h, w


def pool_out_shape(h_w: Tuple[int, int], kernel_size: int) -> Tuple[int, int]:
    """
    Calculates the output shape (height and width) of the output of a max pooling layer.
    kernel_size corresponds to the inputs of the
    torch.nn.MaxPool2d layer (https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html)
    :param kernel_size: The size of the kernel of the convolution
    """
    height = (h_w[0] - kernel_size) // 2 + 1
    width = (h_w[1] - kernel_size) // 2 + 1
    return height, width


class VectorEncoder(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        normalize: bool = False,
    ):
        self.normalizer: Optional[Normalizer] = None
        super().__init__()
        self.layers = [
            linear_layer(
                input_size,
                hidden_size,
                kernel_init=Initialization.KaimingHeNormal,
                kernel_gain=1.0,
            )
        ]
        self.layers.append(Swish())
        if normalize:
            self.normalizer = Normalizer(input_size)

        for _ in range(num_layers - 1):
            self.layers.append(
                linear_layer(
                    hidden_size,
                    hidden_size,
                    kernel_init=Initialization.KaimingHeNormal,
                    kernel_gain=1.0,
                )
            )
            self.layers.append(Swish())
        self.seq_layers = nn.Sequential(*self.layers)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        if self.normalizer is not None:
            inputs = self.normalizer(inputs)
        return self.seq_layers(inputs)

    def copy_normalization(self, other_encoder: "VectorEncoder") -> None:
        if self.normalizer is not None and other_encoder.normalizer is not None:
            self.normalizer.copy_from(other_encoder.normalizer)

    def update_normalization(self, inputs: torch.Tensor) -> None:
        if self.normalizer is not None:
            self.normalizer.update(inputs)


class VectorAndUnnormalizedInputEncoder(VectorEncoder):
    """
    Encoder for concatenated vector input (can be normalized) and unnormalized vector input.
    This is used for passing inputs to the network that should not be normalized, such as
    actions in the case of a Q function or task parameterizations. It will result in an encoder with
    this structure:
     ____________       ____________       ____________
    | Vector     |     | Normalize  |     | Fully      |
    |            | --> |            | --> | Connected  |     ___________
    |____________|     |____________|     |            |    | Output    |
     ____________                         |            | -->|           |
    |Unnormalized|                        |            |    |___________|
    | Input      | ---------------------> |            |
    |____________|                        |____________|
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        unnormalized_input_size: int,
        num_layers: int,
        normalize: bool = False,
    ):
        super().__init__(
            input_size + unnormalized_input_size,
            hidden_size,
            num_layers,
            normalize=False,
        )
        if normalize:
            self.normalizer = Normalizer(input_size)
        else:
            self.normalizer = None

    def forward(  # pylint: disable=W0221
        self, inputs: torch.Tensor, unnormalized_inputs: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        if unnormalized_inputs is None:
            raise UnityTrainerException(
                "Attempted to call a VectorAndUnnormalizedInputEncoder without an unnormalized input."
            )  # Fix mypy errors about method parameters.
        if self.normalizer is not None:
            inputs = self.normalizer(inputs)
        return self.seq_layers(torch.cat([inputs, unnormalized_inputs], dim=-1))


class SimpleVisualEncoder(nn.Module):
    def __init__(
        self, height: int, width: int, initial_channels: int, output_size: int
    ):
        super().__init__()
        self.h_size = output_size
        conv_1_hw = conv_output_shape((height, width), 8, 4)
        conv_2_hw = conv_output_shape(conv_1_hw, 4, 2)
        self.final_flat = conv_2_hw[0] * conv_2_hw[1] * 32

        self.conv_layers = nn.Sequential(
            nn.Conv2d(initial_channels, 16, [8, 8], [4, 4]),
            nn.LeakyReLU(),
            nn.Conv2d(16, 32, [4, 4], [2, 2]),
            nn.LeakyReLU(),
        )
        self.dense = nn.Sequential(
            linear_layer(
                self.final_flat,
                self.h_size,
                kernel_init=Initialization.KaimingHeNormal,
                kernel_gain=1.0,
            ),
            nn.LeakyReLU(),
        )

    def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
        hidden = self.conv_layers(visual_obs)
        hidden = torch.reshape(hidden, (-1, self.final_flat))
        hidden = self.dense(hidden)
        return hidden


class NatureVisualEncoder(nn.Module):
    def __init__(self, height, width, initial_channels, output_size):
        super().__init__()
        self.h_size = output_size
        conv_1_hw = conv_output_shape((height, width), 8, 4)
        conv_2_hw = conv_output_shape(conv_1_hw, 4, 2)
        conv_3_hw = conv_output_shape(conv_2_hw, 3, 1)
        self.final_flat = conv_3_hw[0] * conv_3_hw[1] * 64

        self.conv_layers = nn.Sequential(
            nn.Conv2d(initial_channels, 32, [8, 8], [4, 4]),
            nn.LeakyReLU(),
            nn.Conv2d(32, 64, [4, 4], [2, 2]),
            nn.LeakyReLU(),
            nn.Conv2d(64, 64, [3, 3], [1, 1]),
            nn.LeakyReLU(),
        )
        self.dense = nn.Sequential(
            linear_layer(
                self.final_flat,
                self.h_size,
                kernel_init=Initialization.KaimingHeNormal,
                kernel_gain=1.0,
            ),
            nn.LeakyReLU(),
        )

    def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
        hidden = self.conv_layers(visual_obs)
        hidden = hidden.view([-1, self.final_flat])
        hidden = self.dense(hidden)
        return hidden


class ResNetBlock(nn.Module):
    def __init__(self, channel: int):
        """
        Creates a ResNet Block.
        :param channel: The number of channels in the input (and output) tensors of the
        convolutions
        """
        super().__init__()
        self.layers = nn.Sequential(
            Swish(),
            nn.Conv2d(channel, channel, [3, 3], [1, 1], padding=1),
            Swish(),
            nn.Conv2d(channel, channel, [3, 3], [1, 1], padding=1),
        )

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        return input_tensor + self.layers(input_tensor)


class ResNetVisualEncoder(nn.Module):
    def __init__(self, height, width, initial_channels, final_hidden):
        super().__init__()
        n_channels = [16, 32, 32]  # channel for each stack
        n_blocks = 2  # number of residual blocks
        self.layers = []
        last_channel = initial_channels
        for _, channel in enumerate(n_channels):
            self.layers.append(
                nn.Conv2d(last_channel, channel, [3, 3], [1, 1], padding=1)
            )
            self.layers.append(nn.MaxPool2d([3, 3], [2, 2]))
            height, width = pool_out_shape((height, width), 3)
            for _ in range(n_blocks):
                self.layers.append(ResNetBlock(channel))
            last_channel = channel
        self.layers.append(Swish())
        self.dense = linear_layer(
            n_channels[-1] * height * width,
            final_hidden,
            kernel_init=Initialization.KaimingHeNormal,
            kernel_gain=1.0,
        )

    def forward(self, visual_obs):
        batch_size = visual_obs.shape[0]
        hidden = visual_obs
        for layer in self.layers:
            hidden = layer(hidden)
        before_out = hidden.view(batch_size, -1)
        return torch.relu(self.dense(before_out))

import torch
from enum import Enum


class Swish(torch.nn.Module):
    def forward(self, data: torch.Tensor) -> torch.Tensor:
        return torch.mul(data, torch.sigmoid(data))


class Initialization(Enum):
    Zero = 0
    XavierGlorotNormal = 1
    XavierGlorotUniform = 2
    KaimingHeNormal = 3  # also known as Variance scaling
    KaimingHeUniform = 4


_init_methods = {
    Initialization.Zero: torch.zero_,
    Initialization.XavierGlorotNormal: torch.nn.init.xavier_normal_,
    Initialization.XavierGlorotUniform: torch.nn.init.xavier_uniform_,
    Initialization.KaimingHeNormal: torch.nn.init.kaiming_normal_,
    Initialization.KaimingHeUniform: torch.nn.init.kaiming_uniform_,
}


def linear_layer(
    input_size: int,
    output_size: int,
    kernel_init: Initialization = Initialization.XavierGlorotUniform,
    kernel_gain: float = 1.0,
    bias_init: Initialization = Initialization.Zero,
) -> torch.nn.Module:
    """
    Creates a torch.nn.Linear module and initializes its weights.
    :param input_size: The size of the input tensor
    :param output_size: The size of the output tensor
    :param kernel_init: The Initialization to use for the weights of the layer
    :param kernel_gain: The multiplier for the weights of the kernel. Note that in
    TensorFlow, calling variance_scaling with scale 0.01 is equivalent to calling
    KaimingHeNormal with kernel_gain of 0.1
    :param bias_init: The Initialization to use for the weights of the bias layer
    """
    layer = torch.nn.Linear(input_size, output_size)
    _init_methods[kernel_init](layer.weight.data)
    layer.weight.data *= kernel_gain
    _init_methods[bias_init](layer.bias.data)
    return layer


def lstm_layer(
    input_size: int,
    hidden_size: int,
    num_layers: int = 1,
    batch_first: bool = True,
    forget_bias: float = 1.0,
    kernel_init: Initialization = Initialization.XavierGlorotUniform,
    bias_init: Initialization = Initialization.Zero,
) -> torch.nn.Module:
    """
    Creates a torch.nn.LSTM and initializes its weights and biases. Provides a
    forget_bias offset like is done in TensorFlow.
    """
    lstm = torch.nn.LSTM(input_size, hidden_size, num_layers, batch_first=batch_first)
    # Add forget_bias to forget gate bias
    for name, param in lstm.named_parameters():
        # Each weight and bias is a concatenation of 4 matrices
        if "weight" in name:
            for idx in range(4):
                block_size = param.shape[0] // 4
                _init_methods[kernel_init](
                    param.data[idx * block_size : (idx + 1) * block_size]
                )
        if "bias" in name:
            for idx in range(4):
                block_size = param.shape[0] // 4
                _init_methods[bias_init](
                    param.data[idx * block_size : (idx + 1) * block_size]
                )
                if idx == 1:
                    param.data[idx * block_size : (idx + 1) * block_size].add_(
                        forget_bias
                    )
    return lstm
|
|||
from typing import Callable, List, Dict, Tuple, Optional |
|||
import attr |
|||
import abc |
|||
|
|||
import torch |
|||
from torch import nn |
|||
|
|||
from mlagents_envs.base_env import ActionType |
|||
from mlagents.trainers.torch.distributions import ( |
|||
GaussianDistribution, |
|||
MultiCategoricalDistribution, |
|||
DistInstance, |
|||
) |
|||
from mlagents.trainers.settings import NetworkSettings |
|||
from mlagents.trainers.torch.utils import ModelUtils |
|||
from mlagents.trainers.torch.decoders import ValueHeads |
|||
from mlagents.trainers.torch.layers import lstm_layer |
|||
|
|||
ActivationFunction = Callable[[torch.Tensor], torch.Tensor] |
|||
EncoderFunction = Callable[ |
|||
[torch.Tensor, int, ActivationFunction, int, str, bool], torch.Tensor |
|||
] |
|||
|
|||
EPSILON = 1e-7 |
|||
|
|||
|
|||
class NetworkBody(nn.Module): |
|||
def __init__( |
|||
self, |
|||
observation_shapes: List[Tuple[int, ...]], |
|||
network_settings: NetworkSettings, |
|||
encoded_act_size: int = 0, |
|||
): |
|||
super().__init__() |
|||
self.normalize = network_settings.normalize |
|||
self.use_lstm = network_settings.memory is not None |
|||
self.h_size = network_settings.hidden_units |
|||
self.m_size = ( |
|||
network_settings.memory.memory_size |
|||
if network_settings.memory is not None |
|||
else 0 |
|||
) |
|||
|
|||
self.visual_encoders, self.vector_encoders = ModelUtils.create_encoders( |
|||
observation_shapes, |
|||
self.h_size, |
|||
network_settings.num_layers, |
|||
network_settings.vis_encode_type, |
|||
unnormalized_inputs=encoded_act_size, |
|||
normalize=self.normalize, |
|||
) |
|||
|
|||
if self.use_lstm: |
|||
self.lstm = lstm_layer(self.h_size, self.m_size // 2, batch_first=True) |
|||
else: |
|||
self.lstm = None |
|||
|
|||
def update_normalization(self, vec_inputs: List[torch.Tensor]) -> None: |
|||
for vec_input, vec_enc in zip(vec_inputs, self.vector_encoders): |
|||
vec_enc.update_normalization(vec_input) |
|||
|
|||
def copy_normalization(self, other_network: "NetworkBody") -> None: |
|||
if self.normalize: |
|||
for n1, n2 in zip(self.vector_encoders, other_network.vector_encoders): |
|||
n1.copy_normalization(n2) |
|||
|
|||
def forward( |
|||
self, |
|||
vec_inputs: List[torch.Tensor], |
|||
vis_inputs: List[torch.Tensor], |
|||
actions: Optional[torch.Tensor] = None, |
|||
memories: Optional[torch.Tensor] = None, |
|||
sequence_length: int = 1, |
|||
) -> Tuple[torch.Tensor, torch.Tensor]: |
|||
vec_encodes = [] |
|||
for idx, encoder in enumerate(self.vector_encoders): |
|||
vec_input = vec_inputs[idx] |
|||
if actions is not None: |
|||
hidden = encoder(vec_input, actions) |
|||
else: |
|||
hidden = encoder(vec_input) |
|||
vec_encodes.append(hidden) |
|||
|
|||
vis_encodes = [] |
|||
for idx, encoder in enumerate(self.visual_encoders): |
|||
vis_input = vis_inputs[idx] |
|||
vis_input = vis_input.permute([0, 3, 1, 2]) |
|||
hidden = encoder(vis_input) |
|||
vis_encodes.append(hidden) |
|||
|
|||
if len(vec_encodes) > 0 and len(vis_encodes) > 0: |
|||
vec_encodes_tensor = torch.stack(vec_encodes, dim=-1).sum(dim=-1) |
|||
vis_encodes_tensor = torch.stack(vis_encodes, dim=-1).sum(dim=-1) |
|||
encoding = torch.stack( |
|||
[vec_encodes_tensor, vis_encodes_tensor], dim=-1 |
|||
).sum(dim=-1) |
|||
elif len(vec_encodes) > 0: |
|||
encoding = torch.stack(vec_encodes, dim=-1).sum(dim=-1) |
|||
elif len(vis_encodes) > 0: |
|||
encoding = torch.stack(vis_encodes, dim=-1).sum(dim=-1) |
|||
else: |
|||
raise Exception("No valid inputs to network.") |
|||
|
|||
if self.use_lstm: |
|||
# Resize to (batch, sequence length, encoding size) |
|||
encoding = encoding.reshape([-1, sequence_length, self.h_size]) |
|||
memories = torch.split(memories, self.m_size // 2, dim=-1) |
|||
encoding, memories = self.lstm(encoding, memories) |
|||
encoding = encoding.reshape([-1, self.m_size // 2]) |
|||
memories = torch.cat(memories, dim=-1) |
|||
return encoding, memories |
|||
|
|||
|
|||
class ValueNetwork(nn.Module): |
|||
def __init__( |
|||
self, |
|||
stream_names: List[str], |
|||
observation_shapes: List[Tuple[int, ...]], |
|||
network_settings: NetworkSettings, |
|||
encoded_act_size: int = 0, |
|||
outputs_per_stream: int = 1, |
|||
): |
|||
|
|||
# This is not a typo, we want to call __init__ of nn.Module |
|||
nn.Module.__init__(self) |
|||
self.network_body = NetworkBody( |
|||
observation_shapes, network_settings, encoded_act_size=encoded_act_size |
|||
) |
|||
if network_settings.memory is not None: |
|||
encoding_size = network_settings.memory.memory_size // 2 |
|||
else: |
|||
encoding_size = network_settings.hidden_units |
|||
self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream) |
|||
|
|||
def forward( |
|||
self, |
|||
vec_inputs: List[torch.Tensor], |
|||
vis_inputs: List[torch.Tensor], |
|||
actions: Optional[torch.Tensor] = None, |
|||
memories: Optional[torch.Tensor] = None, |
|||
sequence_length: int = 1, |
|||
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: |
|||
encoding, memories = self.network_body( |
|||
vec_inputs, vis_inputs, actions, memories, sequence_length |
|||
) |
|||
output = self.value_heads(encoding) |
|||
return output, memories |
|||
|
|||
|
|||
class Actor(abc.ABC): |
|||
@abc.abstractmethod |
|||
def update_normalization(self, vector_obs: List[torch.Tensor]) -> None: |
|||
""" |
|||
Updates normalization of Actor based on the provided List of vector obs. |
|||
:param vector_obs: A List of vector obs as tensors. |
|||
""" |
|||
pass |
|||
|
|||
@abc.abstractmethod |
|||
def sample_action(self, dists: List[DistInstance]) -> List[torch.Tensor]: |
|||
""" |
|||
Takes a List of Distribution instances and samples an action from each.
|||
""" |
|||
pass |
|||
|
|||
@abc.abstractmethod |
|||
def get_dists( |
|||
self, |
|||
vec_inputs: List[torch.Tensor], |
|||
vis_inputs: List[torch.Tensor], |
|||
masks: Optional[torch.Tensor] = None, |
|||
memories: Optional[torch.Tensor] = None, |
|||
sequence_length: int = 1, |
|||
) -> Tuple[List[DistInstance], Optional[torch.Tensor]]: |
|||
""" |
|||
Returns distributions from this Actor, from which actions can be sampled. |
|||
If memory is enabled, return the memories as well. |
|||
:param vec_inputs: A List of vector inputs as tensors. |
|||
:param vis_inputs: A List of visual inputs as tensors. |
|||
:param masks: If using discrete actions, a Tensor of action masks. |
|||
:param memories: If using memory, a Tensor of initial memories. |
|||
:param sequence_length: If using memory, the sequence length. |
|||
:return: A Tuple of a List of action distribution instances, and memories. |
|||
Memories will be None if not using memory. |
|||
""" |
|||
pass |
|||
|
|||
@abc.abstractmethod |
|||
def forward( |
|||
self, |
|||
vec_inputs: List[torch.Tensor], |
|||
vis_inputs: List[torch.Tensor], |
|||
masks: Optional[torch.Tensor] = None, |
|||
memories: Optional[torch.Tensor] = None, |
|||
sequence_length: int = 1, |
|||
) -> Tuple[torch.Tensor, torch.Tensor, int, int, int, int]: |
|||
""" |
|||
Forward pass of the Actor for inference. This is required for export to ONNX, and |
|||
the inputs and outputs of this method should not be changed without a respective change |
|||
in the ONNX export code. |
|||
""" |
|||
pass |
|||
|
|||
|
|||
class ActorCritic(Actor):
    @abc.abstractmethod
    def critic_pass(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        """
        Get value outputs for the given obs.
        :param vec_inputs: List of vector inputs as tensors.
        :param vis_inputs: List of visual inputs as tensors.
        :param memories: Tensor of memories, if using memory. Otherwise, None.
        :returns: Dict of reward stream to output tensor for values.
        """
        pass

    @abc.abstractmethod
    def get_dist_and_value(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
        """
        Returns distributions, from which actions can be sampled, and value estimates.
        If memory is enabled, return the memories as well.
        :param vec_inputs: A List of vector inputs as tensors.
        :param vis_inputs: A List of visual inputs as tensors.
        :param masks: If using discrete actions, a Tensor of action masks.
        :param memories: If using memory, a Tensor of initial memories.
        :param sequence_length: If using memory, the sequence length.
        :return: A Tuple of a List of action distribution instances, a Dict of reward signal
            name to value estimate, and memories. Memories will be None if not using memory.
        """
        pass

class SimpleActor(nn.Module, Actor):
    def __init__(
        self,
        observation_shapes: List[Tuple[int, ...]],
        network_settings: NetworkSettings,
        act_type: ActionType,
        act_size: List[int],
        conditional_sigma: bool = False,
        tanh_squash: bool = False,
    ):
        super().__init__()
        self.act_type = act_type
        self.act_size = act_size
        self.version_number = torch.nn.Parameter(torch.Tensor([2.0]))
        self.memory_size = torch.nn.Parameter(torch.Tensor([0]))
        self.is_continuous_int = torch.nn.Parameter(
            torch.Tensor([int(act_type == ActionType.CONTINUOUS)])
        )
        self.act_size_vector = torch.nn.Parameter(torch.Tensor(act_size))
        self.network_body = NetworkBody(observation_shapes, network_settings)
        if network_settings.memory is not None:
            self.encoding_size = network_settings.memory.memory_size // 2
        else:
            self.encoding_size = network_settings.hidden_units
        if self.act_type == ActionType.CONTINUOUS:
            self.distribution = GaussianDistribution(
                self.encoding_size,
                act_size[0],
                conditional_sigma=conditional_sigma,
                tanh_squash=tanh_squash,
            )
        else:
            self.distribution = MultiCategoricalDistribution(
                self.encoding_size, act_size
            )

    def update_normalization(self, vector_obs: List[torch.Tensor]) -> None:
        self.network_body.update_normalization(vector_obs)

    def sample_action(self, dists: List[DistInstance]) -> List[torch.Tensor]:
        actions = []
        for action_dist in dists:
            action = action_dist.sample()
            actions.append(action)
        return actions

    def get_dists(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[List[DistInstance], Optional[torch.Tensor]]:
        encoding, memories = self.network_body(
            vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
        )
        if self.act_type == ActionType.CONTINUOUS:
            dists = self.distribution(encoding)
        else:
            dists = self.distribution(encoding, masks)

        return dists, memories

    def forward(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[torch.Tensor, torch.Tensor, int, int, int, int]:
        """
        Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs.
        """
        dists, _ = self.get_dists(
            vec_inputs, vis_inputs, masks, memories, sequence_length
        )
        action_list = self.sample_action(dists)
        sampled_actions = torch.stack(action_list, dim=-1)
        return (
            sampled_actions,
            dists[0].pdf(sampled_actions),
            self.version_number,
            self.memory_size,
            self.is_continuous_int,
            self.act_size_vector,
        )

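
# Editor's sketch (illustrative, not part of the original commit): the tuple returned by
# SimpleActor.forward is the fixed contract consumed by the ONNX exporter. A hedged
# example for a vector-only continuous-control actor, assuming NetworkSettings() is
# constructible with its defaults; the observation and action sizes are hypothetical:
def _simple_actor_forward_sketch() -> None:
    actor = SimpleActor(
        observation_shapes=[(8,)],
        network_settings=NetworkSettings(),
        act_type=ActionType.CONTINUOUS,
        act_size=[2],
    )
    sampled_actions, log_probs, version, mem_size, is_continuous, act_size_vec = actor(
        [torch.zeros((1, 8))], []
    )
    # One action branch stacked on the last dim: (batch, act_size, num_branches)
    assert sampled_actions.shape == (1, 2, 1)
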
class SharedActorCritic(SimpleActor, ActorCritic):
    def __init__(
        self,
        observation_shapes: List[Tuple[int, ...]],
        network_settings: NetworkSettings,
        act_type: ActionType,
        act_size: List[int],
        stream_names: List[str],
        conditional_sigma: bool = False,
        tanh_squash: bool = False,
    ):
        super().__init__(
            observation_shapes,
            network_settings,
            act_type,
            act_size,
            conditional_sigma,
            tanh_squash,
        )
        self.stream_names = stream_names
        self.value_heads = ValueHeads(stream_names, self.encoding_size)

    def critic_pass(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        encoding, memories_out = self.network_body(
            vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
        )
        return self.value_heads(encoding), memories_out

    def get_dist_and_value(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
        encoding, memories = self.network_body(
            vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
        )
        if self.act_type == ActionType.CONTINUOUS:
            dists = self.distribution(encoding)
        else:
            dists = self.distribution(encoding, masks=masks)

        value_outputs = self.value_heads(encoding)
        return dists, value_outputs, memories

class SeparateActorCritic(SimpleActor, ActorCritic):
    def __init__(
        self,
        observation_shapes: List[Tuple[int, ...]],
        network_settings: NetworkSettings,
        act_type: ActionType,
        act_size: List[int],
        stream_names: List[str],
        conditional_sigma: bool = False,
        tanh_squash: bool = False,
    ):
        # Give the Actor only half the memories. Note that we previously validated
        # that memory_size must be a multiple of 4.
        self.use_lstm = network_settings.memory is not None
        if network_settings.memory is not None:
            self.half_mem_size = network_settings.memory.memory_size // 2
            new_memory_settings = attr.evolve(
                network_settings.memory, memory_size=self.half_mem_size
            )
            use_network_settings = attr.evolve(
                network_settings, memory=new_memory_settings
            )
        else:
            use_network_settings = network_settings
            self.half_mem_size = 0
        super().__init__(
            observation_shapes,
            use_network_settings,
            act_type,
            act_size,
            conditional_sigma,
            tanh_squash,
        )
        self.stream_names = stream_names
        self.critic = ValueNetwork(
            stream_names, observation_shapes, use_network_settings
        )

    def critic_pass(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        actor_mem, critic_mem = None, None
        if self.use_lstm:
            # Use only the back half of memories for the critic
            actor_mem, critic_mem = torch.split(memories, self.half_mem_size, -1)
        value_outputs, critic_mem_out = self.critic(
            vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
        )
        if actor_mem is not None:
            # Make memories with the actor mem unchanged
            memories_out = torch.cat([actor_mem, critic_mem_out], dim=-1)
        else:
            memories_out = None
        return value_outputs, memories_out

    def get_dist_and_value(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
        if self.use_lstm:
            # Split memories: the front half goes to the actor, the back half to the critic
            actor_mem, critic_mem = torch.split(memories, self.half_mem_size, dim=-1)
        else:
            critic_mem = None
            actor_mem = None
        dists, actor_mem_outs = self.get_dists(
            vec_inputs,
            vis_inputs,
            memories=actor_mem,
            sequence_length=sequence_length,
            masks=masks,
        )
        value_outputs, critic_mem_outs = self.critic(
            vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
        )
        if self.use_lstm:
            mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1)
        else:
            mem_out = None
        return dists, value_outputs, mem_out

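
# Editor's sketch (illustrative, not part of the original commit): with an LSTM,
# SeparateActorCritic keeps actor and critic memories side by side in a single tensor,
# front half for the actor and back half for the critic. A minimal sketch of that layout
# with a hypothetical memory size:
def _split_memory_layout_sketch() -> None:
    half_mem_size = 8
    memories = torch.arange(2 * half_mem_size, dtype=torch.float32).unsqueeze(0)
    actor_mem, critic_mem = torch.split(memories, half_mem_size, dim=-1)
    assert actor_mem.shape == critic_mem.shape == (1, half_mem_size)
    assert torch.equal(torch.cat([actor_mem, critic_mem], dim=-1), memories)
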
class GlobalSteps(nn.Module):
    def __init__(self):
        super().__init__()
        self.global_step = torch.Tensor([0])

    def increment(self, value):
        self.global_step += value


class LearningRate(nn.Module):
    def __init__(self, lr):
        # Todo: add learning rate decay
        super().__init__()
        self.learning_rate = torch.Tensor([lr])
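
# Editor's sketch (illustrative, not part of the original commit): GlobalSteps just keeps
# a running step counter in a tensor so that it can live alongside module state. A trivial
# example:
def _global_steps_sketch() -> None:
    steps = GlobalSteps()
    steps.increment(128)
    steps.increment(128)
    assert int(steps.global_step.item()) == 256
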
from typing import List, Optional, Tuple
import torch
import numpy as np
from torch import nn

from mlagents.trainers.torch.encoders import (
    SimpleVisualEncoder,
    ResNetVisualEncoder,
    NatureVisualEncoder,
    VectorEncoder,
    VectorAndUnnormalizedInputEncoder,
)
from mlagents.trainers.settings import EncoderType, ScheduleType
from mlagents.trainers.exception import UnityTrainerException
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.torch.distributions import DistInstance, DiscreteDistInstance

class ModelUtils:
    # Minimum supported side for each encoder type. If refactoring an encoder, please
    # adjust these also.
    MIN_RESOLUTION_FOR_ENCODER = {
        EncoderType.SIMPLE: 20,
        EncoderType.NATURE_CNN: 36,
        EncoderType.RESNET: 15,
    }

    class ActionFlattener:
        def __init__(self, behavior_spec: BehaviorSpec):
            self._specs = behavior_spec

        @property
        def flattened_size(self) -> int:
            if self._specs.is_action_continuous():
                return self._specs.action_size
            else:
                return sum(self._specs.discrete_action_branches)

        def forward(self, action: torch.Tensor) -> torch.Tensor:
            if self._specs.is_action_continuous():
                return action
            else:
                return torch.cat(
                    ModelUtils.actions_to_onehot(
                        torch.as_tensor(action, dtype=torch.long),
                        self._specs.discrete_action_branches,
                    ),
                    dim=1,
                )

    @staticmethod
    def update_learning_rate(optim: torch.optim.Optimizer, lr: float) -> None:
        """
        Apply a learning rate to a torch optimizer.
        :param optim: Optimizer
        :param lr: Learning rate
        """
        for param_group in optim.param_groups:
            param_group["lr"] = lr

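    # Editor's sketch (illustrative, not part of the original commit): a hedged example of
    # pushing a decayed learning rate into an existing optimizer; the layer and numbers
    # below are hypothetical.
    @staticmethod
    def _update_learning_rate_sketch() -> None:
        optim = torch.optim.Adam(nn.Linear(4, 4).parameters(), lr=3e-4)
        ModelUtils.update_learning_rate(optim, 1e-4)
        assert optim.param_groups[0]["lr"] == 1e-4
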
    class DecayedValue:
        def __init__(
            self,
            schedule: ScheduleType,
            initial_value: float,
            min_value: float,
            max_step: int,
        ):
            """
            Object that represents the value of a parameter that should be decayed, assuming it is
            a function of global_step.
            :param schedule: Type of learning rate schedule.
            :param initial_value: Initial value before decay.
            :param min_value: Decay value to this value by max_step.
            :param max_step: The final step count where the return value should equal min_value.
            """
            self.schedule = schedule
            self.initial_value = initial_value
            self.min_value = min_value
            self.max_step = max_step

        def get_value(self, global_step: int) -> float:
            """
            Get the value at a given global step.
            :param global_step: Step count.
            :returns: Decayed value at this global step.
            """
            if self.schedule == ScheduleType.CONSTANT:
                return self.initial_value
            elif self.schedule == ScheduleType.LINEAR:
                return ModelUtils.polynomial_decay(
                    self.initial_value, self.min_value, self.max_step, global_step
                )
            else:
                raise UnityTrainerException(f"The schedule {self.schedule} is invalid.")

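    # Editor's sketch (illustrative, not part of the original commit): a hedged example of
    # the two supported schedules; the hyperparameters are hypothetical.
    @staticmethod
    def _decayed_value_sketch() -> None:
        constant = ModelUtils.DecayedValue(ScheduleType.CONSTANT, 3e-4, 1e-5, 1000)
        linear = ModelUtils.DecayedValue(ScheduleType.LINEAR, 3e-4, 1e-5, 1000)
        assert constant.get_value(500) == 3e-4
        assert linear.get_value(1000) == 1e-5
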
    @staticmethod
    def polynomial_decay(
        initial_value: float,
        min_value: float,
        max_step: int,
        global_step: int,
        power: float = 1.0,
    ) -> float:
        """
        Get a decayed value based on a polynomial schedule, with respect to the current global step.
        :param initial_value: Initial value before decay.
        :param min_value: Decay value to this value by max_step.
        :param max_step: The final step count where the return value should equal min_value.
        :param global_step: The current step count.
        :param power: Power of polynomial decay. 1.0 (default) is a linear decay.
        :return: The current decayed value.
        """
        global_step = min(global_step, max_step)
        decayed_value = (initial_value - min_value) * (
            1 - float(global_step) / max_step
        ) ** (power) + min_value
        return decayed_value

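    # Editor's sketch (illustrative, not part of the original commit): a worked example of
    # the linear (power=1.0) schedule at the halfway point, with hypothetical values:
    # (3e-4 - 1e-5) * (1 - 0.5) + 1e-5 = 1.55e-4.
    @staticmethod
    def _polynomial_decay_sketch() -> None:
        value = ModelUtils.polynomial_decay(3e-4, 1e-5, max_step=1000, global_step=500)
        assert abs(value - 1.55e-4) < 1e-12
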
    @staticmethod
    def get_encoder_for_type(encoder_type: EncoderType) -> nn.Module:
        ENCODER_FUNCTION_BY_TYPE = {
            EncoderType.SIMPLE: SimpleVisualEncoder,
            EncoderType.NATURE_CNN: NatureVisualEncoder,
            EncoderType.RESNET: ResNetVisualEncoder,
        }
        return ENCODER_FUNCTION_BY_TYPE.get(encoder_type)

    @staticmethod
    def _check_resolution_for_encoder(
        height: int, width: int, vis_encoder_type: EncoderType
    ) -> None:
        min_res = ModelUtils.MIN_RESOLUTION_FOR_ENCODER[vis_encoder_type]
        if height < min_res or width < min_res:
            raise UnityTrainerException(
                f"Visual observation resolution ({width}x{height}) is too small for "
                f"the provided EncoderType ({vis_encoder_type.value}). The min dimension is {min_res}"
            )

    @staticmethod
    def create_encoders(
        observation_shapes: List[Tuple[int, ...]],
        h_size: int,
        num_layers: int,
        vis_encode_type: EncoderType,
        unnormalized_inputs: int = 0,
        normalize: bool = False,
    ) -> Tuple[nn.ModuleList, nn.ModuleList]:
        """
        Creates visual and vector encoders, along with their normalizers.
        :param observation_shapes: List of Tuples that represent the observation dimensions.
        :param h_size: Number of hidden units per layer.
        :param num_layers: Depth of MLP per encoder.
        :param vis_encode_type: Type of visual encoder to use.
        :param unnormalized_inputs: Number of additional inputs to each vector encoder that should
            not be normalized and are appended to the vector obs. Used for conditioning the network
            on other values (e.g. actions for a Q function).
        :param normalize: Normalize all vector inputs.
        :return: Tuple of visual encoders and vector encoders each as a list.
        """
        visual_encoders: List[nn.Module] = []
        vector_encoders: List[nn.Module] = []

        visual_encoder_class = ModelUtils.get_encoder_for_type(vis_encode_type)
        vector_size = 0
        for i, dimension in enumerate(observation_shapes):
            if len(dimension) == 3:
                ModelUtils._check_resolution_for_encoder(
                    dimension[0], dimension[1], vis_encode_type
                )
                visual_encoders.append(
                    visual_encoder_class(
                        dimension[0], dimension[1], dimension[2], h_size
                    )
                )
            elif len(dimension) == 1:
                vector_size += dimension[0]
            else:
                raise UnityTrainerException(
                    f"Unsupported shape of {dimension} for observation {i}"
                )
        if vector_size + unnormalized_inputs > 0:
            if unnormalized_inputs > 0:
                vector_encoders.append(
                    VectorAndUnnormalizedInputEncoder(
                        vector_size, h_size, unnormalized_inputs, num_layers, normalize
                    )
                )
            else:
                vector_encoders.append(
                    VectorEncoder(vector_size, h_size, num_layers, normalize)
                )
        return nn.ModuleList(visual_encoders), nn.ModuleList(vector_encoders)

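    # Editor's sketch (illustrative, not part of the original commit): a hedged example
    # mixing one visual and one vector observation; the shapes are hypothetical but large
    # enough to satisfy the SIMPLE encoder's 20-pixel minimum.
    @staticmethod
    def _create_encoders_sketch() -> None:
        visual, vector = ModelUtils.create_encoders(
            observation_shapes=[(84, 84, 3), (10,)],
            h_size=128,
            num_layers=2,
            vis_encode_type=EncoderType.SIMPLE,
        )
        assert len(visual) == 1 and len(vector) == 1
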
    @staticmethod
    def list_to_tensor(
        ndarray_list: List[np.ndarray], dtype: Optional[torch.dtype] = None
    ) -> torch.Tensor:
        """
        Converts a list of numpy arrays into a tensor. MUCH faster than
        calling as_tensor on the list directly.
        """
        return torch.as_tensor(np.asanyarray(ndarray_list), dtype=dtype)

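    # Editor's sketch (illustrative, not part of the original commit): batching a list of
    # numpy arrays in one conversion rather than element by element.
    @staticmethod
    def _list_to_tensor_sketch() -> None:
        arrays = [np.zeros(3, dtype=np.float32), np.ones(3, dtype=np.float32)]
        tensor = ModelUtils.list_to_tensor(arrays, dtype=torch.float32)
        assert tensor.shape == (2, 3)
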
    @staticmethod
    def break_into_branches(
        concatenated_logits: torch.Tensor, action_size: List[int]
    ) -> List[torch.Tensor]:
        """
        Takes a concatenated set of logits that represent multiple discrete action branches
        and breaks it up into one Tensor per branch.
        :param concatenated_logits: Tensor that represents the concatenated action branches
        :param action_size: List of ints containing the number of possible actions for each branch.
        :return: A List of Tensors containing one tensor per branch.
        """
        action_idx = [0] + list(np.cumsum(action_size))
        branched_logits = [
            concatenated_logits[:, action_idx[i] : action_idx[i + 1]]
            for i in range(len(action_size))
        ]
        return branched_logits

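    # Editor's sketch (illustrative, not part of the original commit): with branch sizes
    # [2, 3], a (batch, 5) logit tensor is split into a (batch, 2) and a (batch, 3) tensor.
    @staticmethod
    def _break_into_branches_sketch() -> None:
        logits = torch.arange(10.0).reshape(2, 5)
        branches = ModelUtils.break_into_branches(logits, [2, 3])
        assert branches[0].shape == (2, 2) and branches[1].shape == (2, 3)
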
    @staticmethod
    def actions_to_onehot(
        discrete_actions: torch.Tensor, action_size: List[int]
    ) -> List[torch.Tensor]:
        """
        Takes a tensor of discrete actions and turns it into a List of onehot encoding for each
        action.
        :param discrete_actions: Actions in integer form.
        :param action_size: List of branch sizes. Should be of same size as discrete_actions'
            last dimension.
        :return: List of one-hot tensors, one representing each branch.
        """
        onehot_branches = [
            torch.nn.functional.one_hot(_act.T, action_size[i]).float()
            for i, _act in enumerate(discrete_actions.long().T)
        ]
        return onehot_branches

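    # Editor's sketch (illustrative, not part of the original commit): the discrete action
    # pair (1, 2) drawn from branches of sizes [2, 3] becomes one one-hot tensor per branch.
    @staticmethod
    def _actions_to_onehot_sketch() -> None:
        actions = torch.tensor([[1, 2]])
        onehots = ModelUtils.actions_to_onehot(actions, [2, 3])
        assert torch.equal(onehots[0], torch.tensor([[0.0, 1.0]]))
        assert torch.equal(onehots[1], torch.tensor([[0.0, 0.0, 1.0]]))
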
    @staticmethod
    def dynamic_partition(
        data: torch.Tensor, partitions: torch.Tensor, num_partitions: int
    ) -> List[torch.Tensor]:
        """
        Torch implementation of dynamic_partition:
        https://www.tensorflow.org/api_docs/python/tf/dynamic_partition
        Splits the data Tensor input into num_partitions Tensors according to the indices in
        partitions.
        :param data: The Tensor data that will be split into partitions.
        :param partitions: An indices tensor that determines in which partition each element
            of data will be in.
        :param num_partitions: The number of partitions to output. Corresponds to one more
            than the maximum possible index in the partitions argument.
        :return: A list of Tensor partitions (Their indices correspond to their partition index).
        """
        res: List[torch.Tensor] = []
        for i in range(num_partitions):
            res += [data[(partitions == i).nonzero().squeeze(1)]]
        return res

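    # Editor's sketch (illustrative, not part of the original commit): elements tagged 0 go
    # to the first partition and elements tagged 1 to the second, mirroring
    # tf.dynamic_partition.
    @staticmethod
    def _dynamic_partition_sketch() -> None:
        data = torch.tensor([10.0, 20.0, 30.0, 40.0])
        partitions = torch.tensor([0, 1, 0, 1])
        first, second = ModelUtils.dynamic_partition(data, partitions, 2)
        assert torch.equal(first, torch.tensor([10.0, 30.0]))
        assert torch.equal(second, torch.tensor([20.0, 40.0]))
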
    @staticmethod
    def get_probs_and_entropy(
        action_list: List[torch.Tensor], dists: List[DistInstance]
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        log_probs_list = []
        all_probs_list = []
        entropies_list = []
        for action, action_dist in zip(action_list, dists):
            log_prob = action_dist.log_prob(action)
            log_probs_list.append(log_prob)
            entropies_list.append(action_dist.entropy())
            if isinstance(action_dist, DiscreteDistInstance):
                all_probs_list.append(action_dist.all_log_prob())
        log_probs = torch.stack(log_probs_list, dim=-1)
        entropies = torch.stack(entropies_list, dim=-1)
        if not all_probs_list:
            log_probs = log_probs.squeeze(-1)
            entropies = entropies.squeeze(-1)
            all_probs = None
        else:
            all_probs = torch.cat(all_probs_list, dim=-1)
        return log_probs, entropies, all_probs

    @staticmethod
    def masked_mean(tensor: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
        """
        Returns the mean of the tensor but ignoring the values specified by masks.
        Used for masking out loss functions.
        :param tensor: Tensor which needs mean computation.
        :param masks: Boolean tensor of masks with same dimension as tensor.
        """
        return (tensor * masks).sum() / torch.clamp(masks.float().sum(), min=1.0)
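
    # Editor's sketch (illustrative, not part of the original commit): with mask (1, 0, 1)
    # only the first and third entries contribute, so the result is (2.0 + 6.0) / 2 = 4.0.
    @staticmethod
    def _masked_mean_sketch() -> None:
        values = torch.tensor([2.0, 100.0, 6.0])
        masks = torch.tensor([1.0, 0.0, 1.0])
        assert ModelUtils.masked_mean(values, masks).item() == 4.0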