浏览代码

[pytorch] Add decoders, distributions, encoders, layers, networks, and utils (#4349)

/MLA-1734-demo-provider
GitHub 4 年前
当前提交
e3bc3352
共有 17 个文件被更改,包括 2165 次插入1 次删除
  1. 2
      .circleci/config.yml
  2. 3
      test_requirements.txt
  3. 31
      ml-agents/mlagents/trainers/tests/torch/test_decoders.py
  4. 141
      ml-agents/mlagents/trainers/tests/torch/test_distributions.py
  5. 110
      ml-agents/mlagents/trainers/tests/torch/test_encoders.py
  6. 40
      ml-agents/mlagents/trainers/tests/torch/test_layers.py
  7. 219
      ml-agents/mlagents/trainers/tests/torch/test_networks.py
  8. 216
      ml-agents/mlagents/trainers/tests/torch/test_utils.py
  9. 0
      ml-agents/mlagents/trainers/torch/__init__.py
  10. 23
      ml-agents/mlagents/trainers/torch/decoders.py
  11. 206
      ml-agents/mlagents/trainers/torch/distributions.py
  12. 302
      ml-agents/mlagents/trainers/torch/encoders.py
  13. 84
      ml-agents/mlagents/trainers/torch/layers.py
  14. 493
      ml-agents/mlagents/trainers/torch/networks.py
  15. 296
      ml-agents/mlagents/trainers/torch/utils.py

2
.circleci/config.yml


. venv/bin/activate
mkdir test-reports
pip freeze > test-reports/pip_versions.txt
pytest -n 2 --cov=ml-agents --cov=ml-agents-envs --cov=gym-unity --cov-report html --junitxml=test-reports/junit.xml -p no:warnings
pytest --cov=ml-agents --cov=ml-agents-envs --cov=gym-unity --cov-report html --junitxml=test-reports/junit.xml -p no:warnings
- run:
name: Verify there are no hidden/missing metafiles.

3
test_requirements.txt


pytest-cov==2.6.1
pytest-xdist
# PyTorch tests are here for the time being, before they are used in the codebase.
torch>=1.5.0
# onnx doesn't currently have a wheel for 3.8
tf2onnx>=1.5.5;python_version<'3.8'

31
ml-agents/mlagents/trainers/tests/torch/test_decoders.py


import pytest
import torch
from mlagents.trainers.torch.decoders import ValueHeads
def test_valueheads():
stream_names = [f"reward_signal_{num}" for num in range(5)]
input_size = 5
batch_size = 4
# Test default 1 value per head
value_heads = ValueHeads(stream_names, input_size)
input_data = torch.ones((batch_size, input_size))
value_out = value_heads(input_data) # Note: mean value will be removed shortly
for stream_name in stream_names:
assert value_out[stream_name].shape == (batch_size,)
# Test that inputting the wrong size input will throw an error
with pytest.raises(Exception):
value_out = value_heads(torch.ones((batch_size, input_size + 2)))
# Test multiple values per head (e.g. discrete Q function)
output_size = 4
value_heads = ValueHeads(stream_names, input_size, output_size)
input_data = torch.ones((batch_size, input_size))
value_out = value_heads(input_data)
for stream_name in stream_names:
assert value_out[stream_name].shape == (batch_size, output_size)

141
ml-agents/mlagents/trainers/tests/torch/test_distributions.py


import pytest
import torch
from mlagents.trainers.torch.distributions import (
GaussianDistribution,
MultiCategoricalDistribution,
GaussianDistInstance,
TanhGaussianDistInstance,
CategoricalDistInstance,
)
@pytest.mark.parametrize("tanh_squash", [True, False])
@pytest.mark.parametrize("conditional_sigma", [True, False])
def test_gaussian_distribution(conditional_sigma, tanh_squash):
torch.manual_seed(0)
hidden_size = 16
act_size = 4
sample_embedding = torch.ones((1, 16))
gauss_dist = GaussianDistribution(
hidden_size,
act_size,
conditional_sigma=conditional_sigma,
tanh_squash=tanh_squash,
)
# Make sure backprop works
force_action = torch.zeros((1, act_size))
optimizer = torch.optim.Adam(gauss_dist.parameters(), lr=3e-3)
for _ in range(50):
dist_inst = gauss_dist(sample_embedding)[0]
if tanh_squash:
assert isinstance(dist_inst, TanhGaussianDistInstance)
else:
assert isinstance(dist_inst, GaussianDistInstance)
log_prob = dist_inst.log_prob(force_action)
loss = torch.nn.functional.mse_loss(log_prob, -2 * torch.ones(log_prob.shape))
optimizer.zero_grad()
loss.backward()
optimizer.step()
for prob in log_prob.flatten():
assert prob == pytest.approx(-2, abs=0.1)
def test_multi_categorical_distribution():
torch.manual_seed(0)
hidden_size = 16
act_size = [3, 3, 4]
sample_embedding = torch.ones((1, 16))
gauss_dist = MultiCategoricalDistribution(hidden_size, act_size)
# Make sure backprop works
optimizer = torch.optim.Adam(gauss_dist.parameters(), lr=3e-3)
def create_test_prob(size: int) -> torch.Tensor:
test_prob = torch.tensor(
[[1.0 - 0.01 * (size - 1)] + [0.01] * (size - 1)]
) # High prob for first action
return test_prob.log()
for _ in range(100):
dist_insts = gauss_dist(sample_embedding, masks=torch.ones((1, sum(act_size))))
loss = 0
for i, dist_inst in enumerate(dist_insts):
assert isinstance(dist_inst, CategoricalDistInstance)
log_prob = dist_inst.all_log_prob()
test_log_prob = create_test_prob(act_size[i])
# Force log_probs to match the high probability for the first action generated by
# create_test_prob
loss += torch.nn.functional.mse_loss(log_prob, test_log_prob)
optimizer.zero_grad()
loss.backward()
optimizer.step()
for dist_inst, size in zip(dist_insts, act_size):
# Check that the log probs are close to the fake ones that we generated.
test_log_probs = create_test_prob(size)
for _prob, _test_prob in zip(
dist_inst.all_log_prob().flatten().tolist(),
test_log_probs.flatten().tolist(),
):
assert _prob == pytest.approx(_test_prob, abs=0.1)
# Test masks
masks = []
for branch in act_size:
masks += [0] * (branch - 1) + [1]
masks = torch.tensor([masks])
dist_insts = gauss_dist(sample_embedding, masks=masks)
for dist_inst in dist_insts:
log_prob = dist_inst.all_log_prob()
assert log_prob.flatten()[-1] == pytest.approx(0, abs=0.001)
def test_gaussian_dist_instance():
torch.manual_seed(0)
act_size = 4
dist_instance = GaussianDistInstance(
torch.zeros(1, act_size), torch.ones(1, act_size)
)
action = dist_instance.sample()
assert action.shape == (1, act_size)
for log_prob in dist_instance.log_prob(torch.zeros((1, act_size))).flatten():
# Log prob of standard normal at 0
assert log_prob == pytest.approx(-0.919, abs=0.01)
for ent in dist_instance.entropy().flatten():
# entropy of standard normal at 0, based on 1/2 + ln(sqrt(2pi)sigma)
assert ent == pytest.approx(1.42, abs=0.01)
def test_tanh_gaussian_dist_instance():
torch.manual_seed(0)
act_size = 4
dist_instance = TanhGaussianDistInstance(
torch.zeros(1, act_size), torch.ones(1, act_size)
)
for _ in range(10):
action = dist_instance.sample()
assert action.shape == (1, act_size)
assert torch.max(action) < 1.0 and torch.min(action) > -1.0
def test_categorical_dist_instance():
torch.manual_seed(0)
act_size = 4
test_prob = torch.tensor(
[[1.0 - 0.1 * (act_size - 1)] + [0.1] * (act_size - 1)]
) # High prob for first action
dist_instance = CategoricalDistInstance(test_prob)
for _ in range(10):
action = dist_instance.sample()
assert action.shape == (1, 1)
assert action < act_size
# Make sure the first action as higher probability than the others.
prob_first_action = dist_instance.log_prob(torch.tensor([0]))
for i in range(1, act_size):
assert dist_instance.log_prob(torch.tensor([i])) < prob_first_action

110
ml-agents/mlagents/trainers/tests/torch/test_encoders.py


import torch
from unittest import mock
import pytest
from mlagents.trainers.torch.encoders import (
VectorEncoder,
VectorAndUnnormalizedInputEncoder,
Normalizer,
SimpleVisualEncoder,
ResNetVisualEncoder,
NatureVisualEncoder,
)
# This test will also reveal issues with states not being saved in the state_dict.
def compare_models(module_1, module_2):
is_same = True
for key_item_1, key_item_2 in zip(
module_1.state_dict().items(), module_2.state_dict().items()
):
# Compare tensors in state_dict and not the keys.
is_same = torch.equal(key_item_1[1], key_item_2[1]) and is_same
return is_same
def test_normalizer():
input_size = 2
norm = Normalizer(input_size)
# These three inputs should mean to 0.5, and variance 2
# with the steps starting at 1
vec_input1 = torch.tensor([[1, 1]])
vec_input2 = torch.tensor([[1, 1]])
vec_input3 = torch.tensor([[0, 0]])
norm.update(vec_input1)
norm.update(vec_input2)
norm.update(vec_input3)
# Test normalization
for val in norm(vec_input1)[0]:
assert val == pytest.approx(0.707, abs=0.001)
# Test copy normalization
norm2 = Normalizer(input_size)
assert not compare_models(norm, norm2)
norm2.copy_from(norm)
assert compare_models(norm, norm2)
for val in norm2(vec_input1)[0]:
assert val == pytest.approx(0.707, abs=0.001)
@mock.patch("mlagents.trainers.torch.encoders.Normalizer")
def test_vector_encoder(mock_normalizer):
mock_normalizer_inst = mock.Mock()
mock_normalizer.return_value = mock_normalizer_inst
input_size = 64
hidden_size = 128
num_layers = 3
normalize = False
vector_encoder = VectorEncoder(input_size, hidden_size, num_layers, normalize)
output = vector_encoder(torch.ones((1, input_size)))
assert output.shape == (1, hidden_size)
normalize = True
vector_encoder = VectorEncoder(input_size, hidden_size, num_layers, normalize)
new_vec = torch.ones((1, input_size))
vector_encoder.update_normalization(new_vec)
mock_normalizer.assert_called_with(input_size)
mock_normalizer_inst.update.assert_called_with(new_vec)
vector_encoder2 = VectorEncoder(input_size, hidden_size, num_layers, normalize)
vector_encoder.copy_normalization(vector_encoder2)
mock_normalizer_inst.copy_from.assert_called_with(mock_normalizer_inst)
@mock.patch("mlagents.trainers.torch.encoders.Normalizer")
def test_vector_and_unnormalized_encoder(mock_normalizer):
mock_normalizer_inst = mock.Mock()
mock_normalizer.return_value = mock_normalizer_inst
input_size = 64
unnormalized_size = 32
hidden_size = 128
num_layers = 3
normalize = True
mock_normalizer_inst.return_value = torch.ones((1, input_size))
vector_encoder = VectorAndUnnormalizedInputEncoder(
input_size, hidden_size, unnormalized_size, num_layers, normalize
)
# Make sure normalizer is only called on input_size
mock_normalizer.assert_called_with(input_size)
normal_input = torch.ones((1, input_size))
unnormalized_input = torch.ones((1, 32))
output = vector_encoder(normal_input, unnormalized_input)
mock_normalizer_inst.assert_called_with(normal_input)
assert output.shape == (1, hidden_size)
@pytest.mark.parametrize("image_size", [(36, 36, 3), (84, 84, 4), (256, 256, 5)])
@pytest.mark.parametrize(
"vis_class", [SimpleVisualEncoder, ResNetVisualEncoder, NatureVisualEncoder]
)
def test_visual_encoder(vis_class, image_size):
num_outputs = 128
enc = vis_class(image_size[0], image_size[1], image_size[2], num_outputs)
# Note: NCHW not NHWC
sample_input = torch.ones((1, image_size[2], image_size[0], image_size[1]))
encoding = enc(sample_input)
assert encoding.shape == (1, num_outputs)

40
ml-agents/mlagents/trainers/tests/torch/test_layers.py


import torch
from mlagents.trainers.torch.layers import (
Swish,
linear_layer,
lstm_layer,
Initialization,
)
def test_swish():
layer = Swish()
input_tensor = torch.Tensor([[1, 2, 3], [4, 5, 6]])
target_tensor = torch.mul(input_tensor, torch.sigmoid(input_tensor))
assert torch.all(torch.eq(layer(input_tensor), target_tensor))
def test_initialization_layer():
torch.manual_seed(0)
# Test Zero
layer = linear_layer(
3, 4, kernel_init=Initialization.Zero, bias_init=Initialization.Zero
)
assert torch.all(torch.eq(layer.weight.data, torch.zeros_like(layer.weight.data)))
assert torch.all(torch.eq(layer.bias.data, torch.zeros_like(layer.bias.data)))
def test_lstm_layer():
torch.manual_seed(0)
# Test zero for LSTM
layer = lstm_layer(
4, 4, kernel_init=Initialization.Zero, bias_init=Initialization.Zero
)
for name, param in layer.named_parameters():
if "weight" in name:
assert torch.all(torch.eq(param.data, torch.zeros_like(param.data)))
elif "bias" in name:
assert torch.all(
torch.eq(param.data[4:8], torch.ones_like(param.data[4:8]))
)

219
ml-agents/mlagents/trainers/tests/torch/test_networks.py


import pytest
import torch
from mlagents.trainers.torch.networks import (
NetworkBody,
ValueNetwork,
SimpleActor,
SharedActorCritic,
SeparateActorCritic,
)
from mlagents.trainers.settings import NetworkSettings
from mlagents_envs.base_env import ActionType
from mlagents.trainers.torch.distributions import (
GaussianDistInstance,
CategoricalDistInstance,
)
def test_networkbody_vector():
torch.manual_seed(0)
obs_size = 4
network_settings = NetworkSettings()
obs_shapes = [(obs_size,)]
networkbody = NetworkBody(obs_shapes, network_settings, encoded_act_size=2)
optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3)
sample_obs = 0.1 * torch.ones((1, obs_size))
sample_act = 0.1 * torch.ones((1, 2))
for _ in range(300):
encoded, _ = networkbody([sample_obs], [], sample_act)
assert encoded.shape == (1, network_settings.hidden_units)
# Try to force output to 1
loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
optimizer.zero_grad()
loss.backward()
optimizer.step()
# In the last step, values should be close to 1
for _enc in encoded.flatten():
assert _enc == pytest.approx(1.0, abs=0.1)
def test_networkbody_lstm():
torch.manual_seed(0)
obs_size = 4
seq_len = 16
network_settings = NetworkSettings(
memory=NetworkSettings.MemorySettings(sequence_length=seq_len, memory_size=12)
)
obs_shapes = [(obs_size,)]
networkbody = NetworkBody(obs_shapes, network_settings)
optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-4)
sample_obs = torch.ones((1, seq_len, obs_size))
for _ in range(200):
encoded, _ = networkbody([sample_obs], [], memories=torch.ones(1, seq_len, 12))
# Try to force output to 1
loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
optimizer.zero_grad()
loss.backward()
optimizer.step()
# In the last step, values should be close to 1
for _enc in encoded.flatten():
assert _enc == pytest.approx(1.0, abs=0.1)
def test_networkbody_visual():
torch.manual_seed(0)
vec_obs_size = 4
obs_size = (84, 84, 3)
network_settings = NetworkSettings()
obs_shapes = [(vec_obs_size,), obs_size]
torch.random.manual_seed(0)
networkbody = NetworkBody(obs_shapes, network_settings)
optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3)
sample_obs = torch.ones((1, 84, 84, 3))
sample_vec_obs = torch.ones((1, vec_obs_size))
for _ in range(150):
encoded, _ = networkbody([sample_vec_obs], [sample_obs])
assert encoded.shape == (1, network_settings.hidden_units)
# Try to force output to 1
loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
optimizer.zero_grad()
loss.backward()
optimizer.step()
# In the last step, values should be close to 1
for _enc in encoded.flatten():
assert _enc == pytest.approx(1.0, abs=0.1)
def test_valuenetwork():
torch.manual_seed(0)
obs_size = 4
num_outputs = 2
network_settings = NetworkSettings()
obs_shapes = [(obs_size,)]
stream_names = [f"stream_name{n}" for n in range(4)]
value_net = ValueNetwork(
stream_names, obs_shapes, network_settings, outputs_per_stream=num_outputs
)
optimizer = torch.optim.Adam(value_net.parameters(), lr=3e-3)
for _ in range(50):
sample_obs = torch.ones((1, obs_size))
values, _ = value_net([sample_obs], [])
loss = 0
for s_name in stream_names:
assert values[s_name].shape == (1, num_outputs)
# Try to force output to 1
loss += torch.nn.functional.mse_loss(
values[s_name], torch.ones((1, num_outputs))
)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# In the last step, values should be close to 1
for value in values.values():
for _out in value:
assert _out[0] == pytest.approx(1.0, abs=0.1)
@pytest.mark.parametrize("action_type", [ActionType.DISCRETE, ActionType.CONTINUOUS])
def test_simple_actor(action_type):
obs_size = 4
network_settings = NetworkSettings()
obs_shapes = [(obs_size,)]
act_size = [2]
masks = None if action_type == ActionType.CONTINUOUS else torch.ones((1, 1))
actor = SimpleActor(obs_shapes, network_settings, action_type, act_size)
# Test get_dist
sample_obs = torch.ones((1, obs_size))
dists, _ = actor.get_dists([sample_obs], [], masks=masks)
for dist in dists:
if action_type == ActionType.CONTINUOUS:
assert isinstance(dist, GaussianDistInstance)
else:
assert isinstance(dist, CategoricalDistInstance)
# Test sample_actions
actions = actor.sample_action(dists)
for act in actions:
if action_type == ActionType.CONTINUOUS:
assert act.shape == (1, act_size[0])
else:
assert act.shape == (1, 1)
# Test forward
actions, probs, ver_num, mem_size, is_cont, act_size_vec = actor.forward(
[sample_obs], [], masks=masks
)
for act in actions:
if action_type == ActionType.CONTINUOUS:
assert act.shape == (
act_size[0],
1,
) # This is different from above for ONNX export
else:
assert act.shape == (1, 1)
# TODO: Once export works properly. fix the shapes here.
assert mem_size == 0
assert is_cont == int(action_type == ActionType.CONTINUOUS)
assert act_size_vec == torch.tensor(act_size)
@pytest.mark.parametrize("ac_type", [SharedActorCritic, SeparateActorCritic])
@pytest.mark.parametrize("lstm", [True, False])
def test_actor_critic(ac_type, lstm):
obs_size = 4
network_settings = NetworkSettings(
memory=NetworkSettings.MemorySettings() if lstm else None
)
obs_shapes = [(obs_size,)]
act_size = [2]
stream_names = [f"stream_name{n}" for n in range(4)]
actor = ac_type(
obs_shapes, network_settings, ActionType.CONTINUOUS, act_size, stream_names
)
if lstm:
sample_obs = torch.ones((1, network_settings.memory.sequence_length, obs_size))
memories = torch.ones(
(
1,
network_settings.memory.sequence_length,
network_settings.memory.memory_size,
)
)
else:
sample_obs = torch.ones((1, obs_size))
memories = torch.tensor([])
# memories isn't always set to None, the network should be able to
# deal with that.
# Test critic pass
value_out, memories_out = actor.critic_pass([sample_obs], [], memories=memories)
for stream in stream_names:
if lstm:
assert value_out[stream].shape == (network_settings.memory.sequence_length,)
assert memories_out.shape == memories.shape
else:
assert value_out[stream].shape == (1,)
# Test get_dist_and_value
dists, value_out, mem_out = actor.get_dist_and_value(
[sample_obs], [], memories=memories
)
if mem_out is not None:
assert mem_out.shape == memories.shape
for dist in dists:
assert isinstance(dist, GaussianDistInstance)
for stream in stream_names:
if lstm:
assert value_out[stream].shape == (network_settings.memory.sequence_length,)
else:
assert value_out[stream].shape == (1,)

216
ml-agents/mlagents/trainers/tests/torch/test_utils.py


import pytest
import torch
import numpy as np
from mlagents.trainers.settings import EncoderType, ScheduleType
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.torch.encoders import (
VectorEncoder,
VectorAndUnnormalizedInputEncoder,
)
from mlagents.trainers.torch.distributions import (
CategoricalDistInstance,
GaussianDistInstance,
)
def test_min_visual_size():
# Make sure each EncoderType has an entry in MIS_RESOLUTION_FOR_ENCODER
assert set(ModelUtils.MIN_RESOLUTION_FOR_ENCODER.keys()) == set(EncoderType)
for encoder_type in EncoderType:
good_size = ModelUtils.MIN_RESOLUTION_FOR_ENCODER[encoder_type]
vis_input = torch.ones((1, 3, good_size, good_size))
ModelUtils._check_resolution_for_encoder(good_size, good_size, encoder_type)
enc_func = ModelUtils.get_encoder_for_type(encoder_type)
enc = enc_func(good_size, good_size, 3, 1)
enc.forward(vis_input)
# Anything under the min size should raise an exception. If not, decrease the min size!
with pytest.raises(Exception):
bad_size = ModelUtils.MIN_RESOLUTION_FOR_ENCODER[encoder_type] - 1
vis_input = torch.ones((1, 3, bad_size, bad_size))
with pytest.raises(UnityTrainerException):
# Make sure we'd hit a friendly error during model setup time.
ModelUtils._check_resolution_for_encoder(
bad_size, bad_size, encoder_type
)
enc = enc_func(bad_size, bad_size, 3, 1)
enc.forward(vis_input)
@pytest.mark.parametrize("unnormalized_inputs", [0, 1])
@pytest.mark.parametrize("num_visual", [0, 1, 2])
@pytest.mark.parametrize("num_vector", [0, 1, 2])
@pytest.mark.parametrize("normalize", [True, False])
@pytest.mark.parametrize("encoder_type", [EncoderType.SIMPLE, EncoderType.NATURE_CNN])
def test_create_encoders(
encoder_type, normalize, num_vector, num_visual, unnormalized_inputs
):
vec_obs_shape = (5,)
vis_obs_shape = (84, 84, 3)
obs_shapes = []
for _ in range(num_vector):
obs_shapes.append(vec_obs_shape)
for _ in range(num_visual):
obs_shapes.append(vis_obs_shape)
h_size = 128
num_layers = 3
unnormalized_inputs = 1
vis_enc, vec_enc = ModelUtils.create_encoders(
obs_shapes, h_size, num_layers, encoder_type, unnormalized_inputs, normalize
)
vec_enc = list(vec_enc)
vis_enc = list(vis_enc)
assert len(vec_enc) == (
1 if unnormalized_inputs + num_vector > 0 else 0
) # There's always at most one vector encoder.
assert len(vis_enc) == num_visual
if unnormalized_inputs > 0:
assert isinstance(vec_enc[0], VectorAndUnnormalizedInputEncoder)
elif num_vector > 0:
assert isinstance(vec_enc[0], VectorEncoder)
for enc in vis_enc:
assert isinstance(enc, ModelUtils.get_encoder_for_type(encoder_type))
def test_decayed_value():
test_steps = [0, 4, 9]
# Test constant decay
param = ModelUtils.DecayedValue(ScheduleType.CONSTANT, 1.0, 0.2, test_steps[-1])
for _step in test_steps:
_param = param.get_value(_step)
assert _param == 1.0
test_results = [1.0, 0.6444, 0.2]
# Test linear decay
param = ModelUtils.DecayedValue(ScheduleType.LINEAR, 1.0, 0.2, test_steps[-1])
for _step, _result in zip(test_steps, test_results):
_param = param.get_value(_step)
assert _param == pytest.approx(_result, abs=0.01)
# Test invalid
with pytest.raises(UnityTrainerException):
ModelUtils.DecayedValue(
"SomeOtherSchedule", 1.0, 0.2, test_steps[-1]
).get_value(0)
def test_polynomial_decay():
test_steps = [0, 4, 9]
test_results = [1.0, 0.7, 0.2]
for _step, _result in zip(test_steps, test_results):
decayed = ModelUtils.polynomial_decay(
1.0, 0.2, test_steps[-1], _step, power=0.8
)
assert decayed == pytest.approx(_result, abs=0.01)
def test_list_to_tensor():
# Test converting pure list
unconverted_list = [[1, 2], [1, 3], [1, 4]]
tensor = ModelUtils.list_to_tensor(unconverted_list)
# Should be equivalent to torch.tensor conversion
assert torch.equal(tensor, torch.tensor(unconverted_list))
# Test converting pure numpy array
np_list = np.asarray(unconverted_list)
tensor = ModelUtils.list_to_tensor(np_list)
# Should be equivalent to torch.tensor conversion
assert torch.equal(tensor, torch.tensor(unconverted_list))
# Test converting list of numpy arrays
list_of_np = [np.asarray(_el) for _el in unconverted_list]
tensor = ModelUtils.list_to_tensor(list_of_np)
# Should be equivalent to torch.tensor conversion
assert torch.equal(tensor, torch.tensor(unconverted_list))
def test_break_into_branches():
# Test normal multi-branch case
all_actions = torch.tensor([[1, 2, 3, 4, 5, 6]])
action_size = [2, 1, 3]
broken_actions = ModelUtils.break_into_branches(all_actions, action_size)
assert len(action_size) == len(broken_actions)
for i, _action in enumerate(broken_actions):
assert _action.shape == (1, action_size[i])
# Test 1-branch case
action_size = [6]
broken_actions = ModelUtils.break_into_branches(all_actions, action_size)
assert len(broken_actions) == 1
assert broken_actions[0].shape == (1, 6)
def test_actions_to_onehot():
all_actions = torch.tensor([[1, 0, 2], [1, 0, 2]])
action_size = [2, 1, 3]
oh_actions = ModelUtils.actions_to_onehot(all_actions, action_size)
expected_result = [
torch.tensor([[0, 1], [0, 1]], dtype=torch.float),
torch.tensor([[1], [1]], dtype=torch.float),
torch.tensor([[0, 0, 1], [0, 0, 1]], dtype=torch.float),
]
for res, exp in zip(oh_actions, expected_result):
assert torch.equal(res, exp)
def test_get_probs_and_entropy():
# Test continuous
# Add two dists to the list. This isn't done in the code but we'd like to support it.
dist_list = [
GaussianDistInstance(torch.zeros((1, 2)), torch.ones((1, 2))),
GaussianDistInstance(torch.zeros((1, 2)), torch.ones((1, 2))),
]
action_list = [torch.zeros((1, 2)), torch.zeros((1, 2))]
log_probs, entropies, all_probs = ModelUtils.get_probs_and_entropy(
action_list, dist_list
)
assert log_probs.shape == (1, 2, 2)
assert entropies.shape == (1, 2, 2)
assert all_probs is None
for log_prob in log_probs.flatten():
# Log prob of standard normal at 0
assert log_prob == pytest.approx(-0.919, abs=0.01)
for ent in entropies.flatten():
# entropy of standard normal at 0
assert ent == pytest.approx(1.42, abs=0.01)
# Test continuous
# Add two dists to the list.
act_size = 2
test_prob = torch.tensor(
[[1.0 - 0.1 * (act_size - 1)] + [0.1] * (act_size - 1)]
) # High prob for first action
dist_list = [CategoricalDistInstance(test_prob), CategoricalDistInstance(test_prob)]
action_list = [torch.tensor([0]), torch.tensor([1])]
log_probs, entropies, all_probs = ModelUtils.get_probs_and_entropy(
action_list, dist_list
)
assert all_probs.shape == (1, len(dist_list * act_size))
assert entropies.shape == (1, len(dist_list))
# Make sure the first action has high probability than the others.
assert log_probs.flatten()[0] > log_probs.flatten()[1]
def test_masked_mean():
test_input = torch.tensor([1, 2, 3, 4, 5])
masks = torch.ones_like(test_input).bool()
mean = ModelUtils.masked_mean(test_input, masks=masks)
assert mean == 3.0
masks = torch.tensor([False, False, True, True, True])
mean = ModelUtils.masked_mean(test_input, masks=masks)
assert mean == 4.0
# Make sure it works if all masks are off
masks = torch.tensor([False, False, False, False, False])
mean = ModelUtils.masked_mean(test_input, masks=masks)
assert mean == 0.0

0
ml-agents/mlagents/trainers/torch/__init__.py

23
ml-agents/mlagents/trainers/torch/decoders.py


from typing import List, Dict
import torch
from torch import nn
from mlagents.trainers.torch.layers import linear_layer
class ValueHeads(nn.Module):
def __init__(self, stream_names: List[str], input_size: int, output_size: int = 1):
super().__init__()
self.stream_names = stream_names
_value_heads = {}
for name in stream_names:
value = linear_layer(input_size, output_size)
_value_heads[name] = value
self.value_heads = nn.ModuleDict(_value_heads)
def forward(self, hidden: torch.Tensor) -> Dict[str, torch.Tensor]:
value_outputs = {}
for stream_name, head in self.value_heads.items():
value_outputs[stream_name] = head(hidden).squeeze(-1)
return value_outputs

206
ml-agents/mlagents/trainers/torch/distributions.py


import abc
from typing import List
import torch
from torch import nn
import numpy as np
import math
from mlagents.trainers.torch.layers import linear_layer, Initialization
EPSILON = 1e-7 # Small value to avoid divide by zero
class DistInstance(nn.Module, abc.ABC):
@abc.abstractmethod
def sample(self) -> torch.Tensor:
"""
Return a sample from this distribution.
"""
pass
@abc.abstractmethod
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
"""
Returns the log probabilities of a particular value.
:param value: A value sampled from the distribution.
:returns: Log probabilities of the given value.
"""
pass
@abc.abstractmethod
def entropy(self) -> torch.Tensor:
"""
Returns the entropy of this distribution.
"""
pass
class DiscreteDistInstance(DistInstance):
@abc.abstractmethod
def all_log_prob(self) -> torch.Tensor:
"""
Returns the log probabilities of all actions represented by this distribution.
"""
pass
class GaussianDistInstance(DistInstance):
def __init__(self, mean, std):
super().__init__()
self.mean = mean
self.std = std
def sample(self):
sample = self.mean + torch.randn_like(self.mean) * self.std
return sample
def log_prob(self, value):
var = self.std ** 2
log_scale = torch.log(self.std + EPSILON)
return (
-((value - self.mean) ** 2) / (2 * var + EPSILON)
- log_scale
- math.log(math.sqrt(2 * math.pi))
)
def pdf(self, value):
log_prob = self.log_prob(value)
return torch.exp(log_prob)
def entropy(self):
return 0.5 * torch.log(2 * math.pi * math.e * self.std + EPSILON)
class TanhGaussianDistInstance(GaussianDistInstance):
def __init__(self, mean, std):
super().__init__(mean, std)
self.transform = torch.distributions.transforms.TanhTransform(cache_size=1)
def sample(self):
unsquashed_sample = super().sample()
squashed = self.transform(unsquashed_sample)
return squashed
def _inverse_tanh(self, value):
capped_value = torch.clamp(value, -1 + EPSILON, 1 - EPSILON)
return 0.5 * torch.log((1 + capped_value) / (1 - capped_value) + EPSILON)
def log_prob(self, value):
unsquashed = self.transform.inv(value)
return super().log_prob(unsquashed) - self.transform.log_abs_det_jacobian(
unsquashed, value
)
class CategoricalDistInstance(DiscreteDistInstance):
def __init__(self, logits):
super().__init__()
self.logits = logits
self.probs = torch.softmax(self.logits, dim=-1)
def sample(self):
return torch.multinomial(self.probs, 1)
def pdf(self, value):
# This function is equivalent to torch.diag(self.probs.T[value.flatten().long()]),
# but torch.diag is not supported by ONNX export.
idx = torch.arange(start=0, end=len(value)).unsqueeze(-1)
return torch.gather(
self.probs.permute(1, 0)[value.flatten().long()], -1, idx
).squeeze(-1)
def log_prob(self, value):
return torch.log(self.pdf(value))
def all_log_prob(self):
return torch.log(self.probs)
def entropy(self):
return -torch.sum(self.probs * torch.log(self.probs), dim=-1)
class GaussianDistribution(nn.Module):
def __init__(
self,
hidden_size: int,
num_outputs: int,
conditional_sigma: bool = False,
tanh_squash: bool = False,
):
super().__init__()
self.conditional_sigma = conditional_sigma
self.mu = linear_layer(
hidden_size,
num_outputs,
kernel_init=Initialization.KaimingHeNormal,
kernel_gain=0.1,
bias_init=Initialization.Zero,
)
self.tanh_squash = tanh_squash
if conditional_sigma:
self.log_sigma = linear_layer(
hidden_size,
num_outputs,
kernel_init=Initialization.KaimingHeNormal,
kernel_gain=0.1,
bias_init=Initialization.Zero,
)
else:
self.log_sigma = nn.Parameter(
torch.zeros(1, num_outputs, requires_grad=True)
)
def forward(self, inputs: torch.Tensor) -> List[DistInstance]:
mu = self.mu(inputs)
if self.conditional_sigma:
log_sigma = torch.clamp(self.log_sigma(inputs), min=-20, max=2)
else:
log_sigma = self.log_sigma
if self.tanh_squash:
return [TanhGaussianDistInstance(mu, torch.exp(log_sigma))]
else:
return [GaussianDistInstance(mu, torch.exp(log_sigma))]
class MultiCategoricalDistribution(nn.Module):
def __init__(self, hidden_size: int, act_sizes: List[int]):
super().__init__()
self.act_sizes = act_sizes
self.branches = self._create_policy_branches(hidden_size)
def _create_policy_branches(self, hidden_size: int) -> nn.ModuleList:
branches = []
for size in self.act_sizes:
branch_output_layer = linear_layer(
hidden_size,
size,
kernel_init=Initialization.KaimingHeNormal,
kernel_gain=0.1,
bias_init=Initialization.Zero,
)
branches.append(branch_output_layer)
return nn.ModuleList(branches)
def _mask_branch(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
raw_probs = torch.nn.functional.softmax(logits, dim=-1) * mask
normalized_probs = raw_probs / torch.sum(raw_probs, dim=-1).unsqueeze(-1)
normalized_logits = torch.log(normalized_probs + EPSILON)
return normalized_logits
def _split_masks(self, masks: torch.Tensor) -> List[torch.Tensor]:
split_masks = []
for idx, _ in enumerate(self.act_sizes):
start = int(np.sum(self.act_sizes[:idx]))
end = int(np.sum(self.act_sizes[: idx + 1]))
split_masks.append(masks[:, start:end])
return split_masks
def forward(self, inputs: torch.Tensor, masks: torch.Tensor) -> List[DistInstance]:
# Todo - Support multiple branches in mask code
branch_distributions = []
masks = self._split_masks(masks)
for idx, branch in enumerate(self.branches):
logits = branch(inputs)
norm_logits = self._mask_branch(logits, masks[idx])
distribution = CategoricalDistInstance(norm_logits)
branch_distributions.append(distribution)
return branch_distributions

302
ml-agents/mlagents/trainers/torch/encoders.py


from typing import Tuple, Optional, Union
from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.torch.layers import linear_layer, Initialization, Swish
import torch
from torch import nn
class Normalizer(nn.Module):
def __init__(self, vec_obs_size: int):
super().__init__()
self.register_buffer("normalization_steps", torch.tensor(1))
self.register_buffer("running_mean", torch.zeros(vec_obs_size))
self.register_buffer("running_variance", torch.ones(vec_obs_size))
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
normalized_state = torch.clamp(
(inputs - self.running_mean)
/ torch.sqrt(self.running_variance / self.normalization_steps),
-5,
5,
)
return normalized_state
def update(self, vector_input: torch.Tensor) -> None:
steps_increment = vector_input.size()[0]
total_new_steps = self.normalization_steps + steps_increment
input_to_old_mean = vector_input - self.running_mean
new_mean = self.running_mean + (input_to_old_mean / total_new_steps).sum(0)
input_to_new_mean = vector_input - new_mean
new_variance = self.running_variance + (
input_to_new_mean * input_to_old_mean
).sum(0)
# Update in-place
self.running_mean.data.copy_(new_mean.data)
self.running_variance.data.copy_(new_variance.data)
self.normalization_steps.data.copy_(total_new_steps.data)
def copy_from(self, other_normalizer: "Normalizer") -> None:
self.normalization_steps.data.copy_(other_normalizer.normalization_steps.data)
self.running_mean.data.copy_(other_normalizer.running_mean.data)
self.running_variance.copy_(other_normalizer.running_variance.data)
def conv_output_shape(
h_w: Tuple[int, int],
kernel_size: Union[int, Tuple[int, int]] = 1,
stride: int = 1,
padding: int = 0,
dilation: int = 1,
) -> Tuple[int, int]:
"""
Calculates the output shape (height and width) of the output of a convolution layer.
kernel_size, stride, padding and dilation correspond to the inputs of the
torch.nn.Conv2d layer (https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html)
:param h_w: The height and width of the input.
:param kernel_size: The size of the kernel of the convolution (can be an int or a
tuple [width, height])
:param stride: The stride of the convolution
:param padding: The padding of the convolution
:param dilation: The dilation of the convolution
"""
from math import floor
if not isinstance(kernel_size, tuple):
kernel_size = (int(kernel_size), int(kernel_size))
h = floor(
((h_w[0] + (2 * padding) - (dilation * (kernel_size[0] - 1)) - 1) / stride) + 1
)
w = floor(
((h_w[1] + (2 * padding) - (dilation * (kernel_size[1] - 1)) - 1) / stride) + 1
)
return h, w
def pool_out_shape(h_w: Tuple[int, int], kernel_size: int) -> Tuple[int, int]:
"""
Calculates the output shape (height and width) of the output of a max pooling layer.
kernel_size corresponds to the inputs of the
torch.nn.MaxPool2d layer (https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html)
:param kernel_size: The size of the kernel of the convolution
"""
height = (h_w[0] - kernel_size) // 2 + 1
width = (h_w[1] - kernel_size) // 2 + 1
return height, width
class VectorEncoder(nn.Module):
def __init__(
self,
input_size: int,
hidden_size: int,
num_layers: int,
normalize: bool = False,
):
self.normalizer: Optional[Normalizer] = None
super().__init__()
self.layers = [
linear_layer(
input_size,
hidden_size,
kernel_init=Initialization.KaimingHeNormal,
kernel_gain=1.0,
)
]
self.layers.append(Swish())
if normalize:
self.normalizer = Normalizer(input_size)
for _ in range(num_layers - 1):
self.layers.append(
linear_layer(
hidden_size,
hidden_size,
kernel_init=Initialization.KaimingHeNormal,
kernel_gain=1.0,
)
)
self.layers.append(Swish())
self.seq_layers = nn.Sequential(*self.layers)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
if self.normalizer is not None:
inputs = self.normalizer(inputs)
return self.seq_layers(inputs)
def copy_normalization(self, other_encoder: "VectorEncoder") -> None:
if self.normalizer is not None and other_encoder.normalizer is not None:
self.normalizer.copy_from(other_encoder.normalizer)
def update_normalization(self, inputs: torch.Tensor) -> None:
if self.normalizer is not None:
self.normalizer.update(inputs)
class VectorAndUnnormalizedInputEncoder(VectorEncoder):
"""
Encoder for concatenated vector input (can be normalized) and unnormalized vector input.
This is used for passing inputs to the network that should not be normalized, such as
actions in the case of a Q function or task parameterizations. It will result in an encoder with
this structure:
____________ ____________ ____________
| Vector | | Normalize | | Fully |
| | --> | | --> | Connected | ___________
|____________| |____________| | | | Output |
____________ | | --> | |
|Unnormalized| | | |___________|
| Input | ---------------------> | |
|____________| |____________|
"""
def __init__(
self,
input_size: int,
hidden_size: int,
unnormalized_input_size: int,
num_layers: int,
normalize: bool = False,
):
super().__init__(
input_size + unnormalized_input_size,
hidden_size,
num_layers,
normalize=False,
)
if normalize:
self.normalizer = Normalizer(input_size)
else:
self.normalizer = None
def forward( # pylint: disable=W0221
self, inputs: torch.Tensor, unnormalized_inputs: Optional[torch.Tensor] = None
) -> None:
if unnormalized_inputs is None:
raise UnityTrainerException(
"Attempted to call an VectorAndUnnormalizedInputEncoder without an unnormalized input."
) # Fix mypy errors about method parameters.
if self.normalizer is not None:
inputs = self.normalizer(inputs)
return self.seq_layers(torch.cat([inputs, unnormalized_inputs], dim=-1))
class SimpleVisualEncoder(nn.Module):
def __init__(
self, height: int, width: int, initial_channels: int, output_size: int
):
super().__init__()
self.h_size = output_size
conv_1_hw = conv_output_shape((height, width), 8, 4)
conv_2_hw = conv_output_shape(conv_1_hw, 4, 2)
self.final_flat = conv_2_hw[0] * conv_2_hw[1] * 32
self.conv_layers = nn.Sequential(
nn.Conv2d(initial_channels, 16, [8, 8], [4, 4]),
nn.LeakyReLU(),
nn.Conv2d(16, 32, [4, 4], [2, 2]),
nn.LeakyReLU(),
)
self.dense = nn.Sequential(
linear_layer(
self.final_flat,
self.h_size,
kernel_init=Initialization.KaimingHeNormal,
kernel_gain=1.0,
),
nn.LeakyReLU(),
)
def forward(self, visual_obs: torch.Tensor) -> None:
hidden = self.conv_layers(visual_obs)
hidden = torch.reshape(hidden, (-1, self.final_flat))
hidden = self.dense(hidden)
return hidden
class NatureVisualEncoder(nn.Module):
def __init__(self, height, width, initial_channels, output_size):
super().__init__()
self.h_size = output_size
conv_1_hw = conv_output_shape((height, width), 8, 4)
conv_2_hw = conv_output_shape(conv_1_hw, 4, 2)
conv_3_hw = conv_output_shape(conv_2_hw, 3, 1)
self.final_flat = conv_3_hw[0] * conv_3_hw[1] * 64
self.conv_layers = nn.Sequential(
nn.Conv2d(initial_channels, 32, [8, 8], [4, 4]),
nn.LeakyReLU(),
nn.Conv2d(32, 64, [4, 4], [2, 2]),
nn.LeakyReLU(),
nn.Conv2d(64, 64, [3, 3], [1, 1]),
nn.LeakyReLU(),
)
self.dense = nn.Sequential(
linear_layer(
self.final_flat,
self.h_size,
kernel_init=Initialization.KaimingHeNormal,
kernel_gain=1.0,
),
nn.LeakyReLU(),
)
def forward(self, visual_obs: torch.Tensor) -> None:
hidden = self.conv_layers(visual_obs)
hidden = hidden.view([-1, self.final_flat])
hidden = self.dense(hidden)
return hidden
class ResNetBlock(nn.Module):
def __init__(self, channel: int):
"""
Creates a ResNet Block.
:param channel: The number of channels in the input (and output) tensors of the
convolutions
"""
super().__init__()
self.layers = nn.Sequential(
Swish(),
nn.Conv2d(channel, channel, [3, 3], [1, 1], padding=1),
Swish(),
nn.Conv2d(channel, channel, [3, 3], [1, 1], padding=1),
)
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
return input_tensor + self.layers(input_tensor)
class ResNetVisualEncoder(nn.Module):
def __init__(self, height, width, initial_channels, final_hidden):
super().__init__()
n_channels = [16, 32, 32] # channel for each stack
n_blocks = 2 # number of residual blocks
self.layers = []
last_channel = initial_channels
for _, channel in enumerate(n_channels):
self.layers.append(
nn.Conv2d(last_channel, channel, [3, 3], [1, 1], padding=1)
)
self.layers.append(nn.MaxPool2d([3, 3], [2, 2]))
height, width = pool_out_shape((height, width), 3)
for _ in range(n_blocks):
self.layers.append(ResNetBlock(channel))
last_channel = channel
self.layers.append(Swish())
self.dense = linear_layer(
n_channels[-1] * height * width,
final_hidden,
kernel_init=Initialization.KaimingHeNormal,
kernel_gain=1.0,
)
def forward(self, visual_obs):
batch_size = visual_obs.shape[0]
hidden = visual_obs
for layer in self.layers:
hidden = layer(hidden)
before_out = hidden.view(batch_size, -1)
return torch.relu(self.dense(before_out))

84
ml-agents/mlagents/trainers/torch/layers.py


import torch
from enum import Enum
class Swish(torch.nn.Module):
def forward(self, data: torch.Tensor) -> torch.Tensor:
return torch.mul(data, torch.sigmoid(data))
class Initialization(Enum):
Zero = 0
XavierGlorotNormal = 1
XavierGlorotUniform = 2
KaimingHeNormal = 3 # also known as Variance scaling
KaimingHeUniform = 4
_init_methods = {
Initialization.Zero: torch.zero_,
Initialization.XavierGlorotNormal: torch.nn.init.xavier_normal_,
Initialization.XavierGlorotUniform: torch.nn.init.xavier_uniform_,
Initialization.KaimingHeNormal: torch.nn.init.kaiming_normal_,
Initialization.KaimingHeUniform: torch.nn.init.kaiming_uniform_,
}
def linear_layer(
input_size: int,
output_size: int,
kernel_init: Initialization = Initialization.XavierGlorotUniform,
kernel_gain: float = 1.0,
bias_init: Initialization = Initialization.Zero,
) -> torch.nn.Module:
"""
Creates a torch.nn.Linear module and initializes its weights.
:param input_size: The size of the input tensor
:param output_size: The size of the output tensor
:param kernel_init: The Initialization to use for the weights of the layer
:param kernel_gain: The multiplier for the weights of the kernel. Note that in
TensorFlow, calling variance_scaling with scale 0.01 is equivalent to calling
KaimingHeNormal with kernel_gain of 0.1
:param bias_init: The Initialization to use for the weights of the bias layer
"""
layer = torch.nn.Linear(input_size, output_size)
_init_methods[kernel_init](layer.weight.data)
layer.weight.data *= kernel_gain
_init_methods[bias_init](layer.bias.data)
return layer
def lstm_layer(
input_size: int,
hidden_size: int,
num_layers: int = 1,
batch_first: bool = True,
forget_bias: float = 1.0,
kernel_init: Initialization = Initialization.XavierGlorotUniform,
bias_init: Initialization = Initialization.Zero,
) -> torch.nn.Module:
"""
Creates a torch.nn.LSTM and initializes its weights and biases. Provides a
forget_bias offset like is done in TensorFlow.
"""
lstm = torch.nn.LSTM(input_size, hidden_size, num_layers, batch_first=batch_first)
# Add forget_bias to forget gate bias
for name, param in lstm.named_parameters():
# Each weight and bias is a concatenation of 4 matrices
if "weight" in name:
for idx in range(4):
block_size = param.shape[0] // 4
_init_methods[kernel_init](
param.data[idx * block_size : (idx + 1) * block_size]
)
if "bias" in name:
for idx in range(4):
block_size = param.shape[0] // 4
_init_methods[bias_init](
param.data[idx * block_size : (idx + 1) * block_size]
)
if idx == 1:
param.data[idx * block_size : (idx + 1) * block_size].add_(
forget_bias
)
return lstm

493
ml-agents/mlagents/trainers/torch/networks.py


from typing import Callable, List, Dict, Tuple, Optional
import attr
import abc
import torch
from torch import nn
from mlagents_envs.base_env import ActionType
from mlagents.trainers.torch.distributions import (
GaussianDistribution,
MultiCategoricalDistribution,
DistInstance,
)
from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.decoders import ValueHeads
from mlagents.trainers.torch.layers import lstm_layer
ActivationFunction = Callable[[torch.Tensor], torch.Tensor]
EncoderFunction = Callable[
[torch.Tensor, int, ActivationFunction, int, str, bool], torch.Tensor
]
EPSILON = 1e-7
class NetworkBody(nn.Module):
def __init__(
self,
observation_shapes: List[Tuple[int, ...]],
network_settings: NetworkSettings,
encoded_act_size: int = 0,
):
super().__init__()
self.normalize = network_settings.normalize
self.use_lstm = network_settings.memory is not None
self.h_size = network_settings.hidden_units
self.m_size = (
network_settings.memory.memory_size
if network_settings.memory is not None
else 0
)
self.visual_encoders, self.vector_encoders = ModelUtils.create_encoders(
observation_shapes,
self.h_size,
network_settings.num_layers,
network_settings.vis_encode_type,
unnormalized_inputs=encoded_act_size,
normalize=self.normalize,
)
if self.use_lstm:
self.lstm = lstm_layer(self.h_size, self.m_size // 2, batch_first=True)
else:
self.lstm = None
def update_normalization(self, vec_inputs: List[torch.Tensor]) -> None:
for vec_input, vec_enc in zip(vec_inputs, self.vector_encoders):
vec_enc.update_normalization(vec_input)
def copy_normalization(self, other_network: "NetworkBody") -> None:
if self.normalize:
for n1, n2 in zip(self.vector_encoders, other_network.vector_encoders):
n1.copy_normalization(n2)
def forward(
self,
vec_inputs: List[torch.Tensor],
vis_inputs: List[torch.Tensor],
actions: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
vec_encodes = []
for idx, encoder in enumerate(self.vector_encoders):
vec_input = vec_inputs[idx]
if actions is not None:
hidden = encoder(vec_input, actions)
else:
hidden = encoder(vec_input)
vec_encodes.append(hidden)
vis_encodes = []
for idx, encoder in enumerate(self.visual_encoders):
vis_input = vis_inputs[idx]
vis_input = vis_input.permute([0, 3, 1, 2])
hidden = encoder(vis_input)
vis_encodes.append(hidden)
if len(vec_encodes) > 0 and len(vis_encodes) > 0:
vec_encodes_tensor = torch.stack(vec_encodes, dim=-1).sum(dim=-1)
vis_encodes_tensor = torch.stack(vis_encodes, dim=-1).sum(dim=-1)
encoding = torch.stack(
[vec_encodes_tensor, vis_encodes_tensor], dim=-1
).sum(dim=-1)
elif len(vec_encodes) > 0:
encoding = torch.stack(vec_encodes, dim=-1).sum(dim=-1)
elif len(vis_encodes) > 0:
encoding = torch<