from typing import Callable, List, Dict, Tuple, Optional

import torch
from torch import nn

# These imports assume the usual ML-Agents trainer package layout.
from mlagents_envs.base_env import ActionType
from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.decoders import ValueHeads
from mlagents.trainers.torch.distributions import (
    GaussianDistribution,
    MultiCategoricalDistribution,
)

EPSILON = 1e-7


class NetworkBody(nn.Module):
    def __init__(
        self,
        observation_shapes: List[Tuple[int, ...]],
        network_settings: NetworkSettings,
        encoded_act_size: int = 0,
    ):
        super().__init__()
        self.normalize = network_settings.normalize
        self.use_lstm = network_settings.memory is not None
        self.h_size = network_settings.hidden_units
        self.m_size = (
            network_settings.memory.memory_size
            if network_settings.memory is not None
            else 0
        )

        self.visual_encoders, self.vector_encoders = ModelUtils.create_encoders(
            observation_shapes,
            self.h_size,
            network_settings.num_layers,
            network_settings.vis_encode_type,
            unnormalized_inputs=encoded_act_size,
            normalize=self.normalize,
        )

        if self.use_lstm:
            self.lstm = nn.LSTM(self.h_size, self.m_size // 2, 1)
        else:
            self.lstm = None
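
    # `m_size` stores the LSTM hidden and cell states concatenated, so each
    # state tensor gets m_size // 2 features; forward() splits the flat memory
    # tensor around the nn.LSTM call and re-joins it afterwards.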

    def update_normalization(self, vec_inputs):
        if self.normalize:
            for vec_input, vec_enc in zip(vec_inputs, self.vector_encoders):
                vec_enc.update_normalization(vec_input)

    def copy_normalization(self, other_network):
        if self.normalize:
            for n1, n2 in zip(self.vector_encoders, other_network.vector_encoders):
                n1.copy_normalization(n2)
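
    # Typical use (illustrative): after a learning network calls
    # update_normalization() on fresh observations, a target or ghost copy
    # calls copy_normalization() so both whiten inputs identically.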

    def forward(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        actions: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Encode each vector observation; the encoders handle normalization
        # and, when actions are given, concatenation of the raw actions.
        vec_embeds = []
        for idx, encoder in enumerate(self.vector_encoders):
            vec_input = vec_inputs[idx]
            if actions is not None:
                hidden = encoder(vec_input, actions)
            else:
                hidden = encoder(vec_input)
            vec_embeds.append(hidden)

        # Encode each visual observation, converting NHWC input to NCHW.
        vis_embeds = []
        for idx, encoder in enumerate(self.visual_encoders):
            vis_input = vis_inputs[idx]
            vis_input = vis_input.permute([0, 3, 1, 2])
            hidden = encoder(vis_input)
            vis_embeds.append(hidden)

        # Sum the per-observation embeddings into a single body embedding.
        if len(vec_embeds) > 0 and len(vis_embeds) > 0:
            vec_embeds_tensor = torch.stack(vec_embeds, dim=-1).sum(dim=-1)
            vis_embeds_tensor = torch.stack(vis_embeds, dim=-1).sum(dim=-1)
            embedding = torch.stack([vec_embeds_tensor, vis_embeds_tensor], dim=-1).sum(
                dim=-1
            )
        elif len(vec_embeds) > 0:
            embedding = torch.stack(vec_embeds, dim=-1).sum(dim=-1)
        elif len(vis_embeds) > 0:
            embedding = torch.stack(vis_embeds, dim=-1).sum(dim=-1)
        else:
            raise Exception("No valid inputs to network.")

        if self.lstm is not None:
            # Split the flat memory tensor into (hidden, cell) halves, run the
            # LSTM over the sequence, then re-concatenate the updated state
            # returned by the LSTM (not the stale input state).
            embedding = embedding.view([sequence_length, -1, self.h_size])
            memories_tensor = torch.split(memories, self.m_size // 2, dim=-1)
            embedding, memories_tensor = self.lstm(embedding, memories_tensor)
            embedding = embedding.view([-1, self.m_size // 2])
            memories = torch.cat(memories_tensor, dim=-1)
        return embedding, memories
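
# A minimal usage sketch for NetworkBody (illustrative only; assumes a single
# 6-dimensional vector observation and that NetworkSettings defaults suffice):
#
#   settings = NetworkSettings(hidden_units=128, num_layers=2)
#   body = NetworkBody([(6,)], settings)
#   embedding, _ = body([torch.randn(4, 6)], [])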


class ValueNetwork(nn.Module):
    def __init__(
        self,
        stream_names: List[str],
        observation_shapes: List[Tuple[int, ...]],
        network_settings: NetworkSettings,
        encoded_act_size: int = 0,
        outputs_per_stream: int = 1,
    ):
        super().__init__()
        self.network_body = NetworkBody(
            observation_shapes, network_settings, encoded_act_size=encoded_act_size
        )
        self.value_heads = ValueHeads(
            stream_names, network_settings.hidden_units, outputs_per_stream
        )
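
    # Design note: `encoded_act_size` widens the vector encoders so continuous
    # actions can be concatenated to observations (a Q(s, a) input), while
    # `outputs_per_stream` can emit one value per discrete action branch.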

    def forward(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        actions: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        embedding, memories = self.network_body(
            vec_inputs, vis_inputs, actions, memories, sequence_length
        )
        output, _ = self.value_heads(embedding)
        return output, memories
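
# Usage sketch for a SAC-style continuous Q-function (illustrative only;
# `settings` as in the NetworkBody sketch above):
#
#   q_net = ValueNetwork(["extrinsic"], [(6,)], settings, encoded_act_size=2)
#   q_values, _ = q_net([torch.randn(4, 6)], [], actions=torch.randn(4, 2))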


class ActorCritic(nn.Module):
    def __init__(
        self,
        observation_shapes: List[Tuple[int, ...]],
        network_settings: NetworkSettings,
        act_type: ActionType,
        act_size: List[int],
        stream_names: List[str],
        separate_critic: bool,
    ):
        super().__init__()
        self.act_type = act_type
        self.act_size = act_size
        self.separate_critic = separate_critic
        self.network_body = NetworkBody(observation_shapes, network_settings)
        if network_settings.memory is not None:
            embedding_size = network_settings.memory.memory_size // 2
        else:
            embedding_size = network_settings.hidden_units
        # Constants exported with the model so inference code can interpret
        # its outputs.
        self.is_continuous_int = torch.Tensor(
            [int(act_type == ActionType.CONTINUOUS)]
        )
        self.act_size_vector = torch.Tensor(act_size)
        if act_type == ActionType.CONTINUOUS:
            self.distribution = GaussianDistribution(embedding_size, act_size[0])
        else:
            self.distribution = MultiCategoricalDistribution(embedding_size, act_size)
        if separate_critic:
            self.critic = ValueNetwork(
                stream_names, observation_shapes, network_settings
            )
        else:
            self.stream_names = stream_names
            self.value_heads = ValueHeads(stream_names, embedding_size)

    def sample_action(self, dists):
        return [action_dist.sample() for action_dist in dists]

    def get_dist_and_value(
        self, vec_inputs, vis_inputs, masks=None, memories=None, sequence_length=1
    ):
        embedding, memories = self.network_body(
            vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
        )
        if self.act_type == ActionType.CONTINUOUS:
            dists = self.distribution(embedding)
        else:
            dists = self.distribution(embedding, masks=masks)
        if self.separate_critic:
            value_outputs, _ = self.critic(vec_inputs, vis_inputs)
        else:
            value_outputs, _ = self.value_heads(embedding)
        return dists, value_outputs, memories

    def forward(
        self, vec_inputs, vis_inputs=None, masks=None, memories=None, sequence_length=1
    ):
        dists, value_outputs, memories = self.get_dist_and_value(
            vec_inputs, vis_inputs, masks, memories, sequence_length
        )
        sampled_actions = torch.stack(self.sample_action(dists), dim=-1)
        return (
            sampled_actions,
            value_outputs,
            self.is_continuous_int,
            self.act_size_vector,
        )


class GlobalSteps(nn.Module):
|