import abc
import enum
from typing import Callable, List, Dict, Tuple, Optional

import torch
from torch import nn

from mlagents_envs.base_env import ActionType
from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.torch.decoders import ValueHeads
from mlagents.trainers.torch.distributions import DistInstance
from mlagents.trainers.torch.layers import LSTM, LinearEncoder
from mlagents.trainers.torch.model_serialization import exporting_to_onnx
from mlagents.trainers.torch.utils import ModelUtils


# NOTE: this class will be replaced with a multi-head attention module when the time comes
class MultiInputNetworkBody(nn.Module):
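    """
    Network body that encodes observations from multiple sources (one set of
    input processors per observation head, e.g. one per agent), concatenates
    the encodings, and feeds them through a shared LinearEncoder and an
    optional LSTM.
    """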

    def __init__(
        self,
        observation_shapes: List[Tuple[int, ...]],
        network_settings: NetworkSettings,
        encoded_act_size: int = 0,
        num_obs_heads: int = 1,
    ):
        super().__init__()
        self.normalize = network_settings.normalize
        self.use_lstm = network_settings.memory is not None
        self.h_size = network_settings.hidden_units
        self.m_size = (
            network_settings.memory.memory_size
            if network_settings.memory is not None
            else 0
        )
        # One set of input processors per observation head (e.g. per agent).
        self.processors = []
        encoder_input_size = 0
        for _ in range(num_obs_heads):
            _proc, _input_size = ModelUtils.create_input_processors(
                observation_shapes,
                self.h_size,
                network_settings.vis_encode_type,
                normalize=self.normalize,
            )
            self.processors.append(_proc)
            encoder_input_size += _input_size

        total_enc_size = encoder_input_size + encoded_act_size
        self.linear_encoder = LinearEncoder(
            total_enc_size, network_settings.num_layers, self.h_size
        )

        if self.use_lstm:
            self.lstm = LSTM(self.h_size, self.m_size)
        else:
            self.lstm = None  # type: ignore

    def update_normalization(self, net_inputs: List[torch.Tensor]) -> None:
        for _proc in self.processors:
            for _in, enc in zip(net_inputs, _proc):
                enc.update_normalization(_in)

    def copy_normalization(self, other_network: "NetworkBody") -> None:
        if self.normalize:
            for _proc in self.processors:
                for n1, n2 in zip(_proc, other_network.processors):
                    n1.copy_normalization(n2)

    @property
    def memory_size(self) -> int:
        return self.lstm.memory_size if self.use_lstm else 0

    def forward(
        self,
        all_net_inputs: List[List[torch.Tensor]],
        actions: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
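        """
        Encodes each head's observations, concatenates the encodings (plus the
        optional pre-encoded actions), and runs them through the linear
        encoder and, when configured, the LSTM.
        :param all_net_inputs: One list of observation tensors per head.
        :param actions: Optional pre-encoded actions to concatenate in.
        :param memories: Optional LSTM memories.
        :param sequence_length: Sequence length when using LSTM.
        :return: Tuple of encoding and output memories.
        """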
        encodes = []
        for net_inputs, processor_set in zip(all_net_inputs, self.processors):
            for idx, processor in enumerate(processor_set):
                net_input = net_inputs[idx]
                if not exporting_to_onnx.is_exporting() and len(net_input.shape) > 3:
                    # Visual obs arrive channels-last (NHWC); PyTorch wants NCHW.
                    net_input = net_input.permute([0, 3, 1, 2])
                processed_vec = processor(net_input)
                encodes.append(processed_vec)

        if len(encodes) == 0:
            raise Exception("No valid inputs to network.")

        # Constants don't work in Barracuda
        if actions is not None:
            inputs = torch.cat(encodes + [actions], dim=-1)
        else:
            inputs = torch.cat(encodes, dim=-1)
        encoding = self.linear_encoder(inputs)

        if self.use_lstm:
            # Resize to (batch, sequence length, encoding size)
            encoding = encoding.reshape([-1, sequence_length, self.h_size])
            encoding, memories = self.lstm(encoding, memories)
            # Memories hold both hidden and cell state, so the output encoding
            # is half the memory size.
            encoding = encoding.reshape([-1, self.m_size // 2])
        return encoding, memories
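
# Illustrative input layout for MultiInputNetworkBody.forward (assumed shapes):
# with num_obs_heads=2 and observation_shapes=[(8,), (84, 84, 3)], the forward
# pass expects one list of observation tensors per head, ordered the same way
# as observation_shapes:
#
#     all_net_inputs = [
#         [agent0_vector_obs, agent0_visual_obs],  # head 0
#         [agent1_vector_obs, agent1_visual_obs],  # head 1
#     ]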


class ValueNetwork(nn.Module):
    def __init__(
        self,
        stream_names: List[str],
        observation_shapes: List[Tuple[int, ...]],
        network_settings: NetworkSettings,
        encoded_act_size: int = 0,
        outputs_per_stream: int = 1,
    ):
        # This is not a typo, we want to call __init__ of nn.Module
        nn.Module.__init__(self)
        # NetworkBody is the single-input encoder defined elsewhere in this module.
        self.network_body = NetworkBody(
            observation_shapes, network_settings, encoded_act_size=encoded_act_size
        )
        if network_settings.memory is not None:
            encoding_size = network_settings.memory.memory_size // 2
        else:
            encoding_size = network_settings.hidden_units
        self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream)

    @property
    def memory_size(self) -> int:
        return self.network_body.memory_size

    def forward(
        self,
        net_inputs: List[torch.Tensor],
        actions: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        encoding, memories = self.network_body(
            net_inputs, actions, memories, sequence_length
        )
        output = self.value_heads(encoding)
        return output, memories


class CentralizedValueNetwork(ValueNetwork):
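    """
    A ValueNetwork whose body is a MultiInputNetworkBody, so the value
    estimate can be conditioned on observations from several agents (one
    observation head per agent), i.e. a centralized critic.
    """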

    def __init__(
        self,
        stream_names: List[str],
        observation_shapes: List[Tuple[int, ...]],
        network_settings: NetworkSettings,
        encoded_act_size: int = 0,
        outputs_per_stream: int = 1,
        num_agents: int = 1,
    ):
        # This is not a typo, we want to call __init__ of nn.Module
        nn.Module.__init__(self)
        self.network_body = MultiInputNetworkBody(
            observation_shapes,
            network_settings,
            encoded_act_size=encoded_act_size,
            num_obs_heads=num_agents,
        )
        if network_settings.memory is not None:
            encoding_size = network_settings.memory.memory_size // 2
        else:
            encoding_size = network_settings.hidden_units
        self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream)

    def forward(
        self,
        all_net_inputs: List[List[torch.Tensor]],
        actions: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        encoding, memories = self.network_body(
            all_net_inputs, actions, memories, sequence_length
        )
        output = self.value_heads(encoding)
        return output, memories
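
# Illustrative usage (assumed placeholder names): a centralized critic over the
# observations of three agents, with a single "extrinsic" reward stream:
#
#     critic = CentralizedValueNetwork(
#         ["extrinsic"], obs_shapes, network_settings, num_agents=3
#     )
#     values, _ = critic(all_agent_obs)  # one list of observations per agent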


class Actor(abc.ABC):
    @abc.abstractmethod
    def update_normalization(self, net_inputs: List[torch.Tensor]) -> None:
        """
        Updates normalization of Actor based on the provided List of inputs.
        :param net_inputs: A List of input tensors.
        """
        pass

    @abc.abstractmethod
    def get_dist_and_value(
        self,
        net_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        critic_obs: Optional[List[List[torch.Tensor]]] = None,
        sequence_length: int = 1,
    ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
        """
        Returns action distributions, value estimates, and output memories
        from the given inputs. Observations of other agents may be passed in
        via critic_obs for use by a centralized critic.
        :param net_inputs: This agent's observations as tensors.
        :param masks: Action masks for discrete actions.
        :param memories: Optional LSTM memories.
        :param critic_obs: Observations of other agents, for the centralized critic.
        :param sequence_length: Sequence length when using LSTM.
        :return: Tuple of distributions, value estimates, and memories.
        """
        pass


class SeparateActorCritic(SimpleActor, ActorCritic):
    # SimpleActor and ActorCritic are defined elsewhere in this module.
    def __init__(
        self,
        observation_shapes: List[Tuple[int, ...]],
        network_settings: NetworkSettings,
        act_type: ActionType,
        act_size: List[int],
        stream_names: List[str],
        conditional_sigma: bool = False,
        tanh_squash: bool = False,
    ):
        self.use_lstm = network_settings.memory is not None
        super().__init__(
            observation_shapes,
            network_settings,
            act_type,
            act_size,
            conditional_sigma,
            tanh_squash,
        )
        self.stream_names = stream_names
        self.critic = CentralizedValueNetwork(
            stream_names, observation_shapes, network_settings, num_agents=3
        )

    @property
    def memory_size(self) -> int:
        return self.network_body.memory_size + self.critic.memory_size

    def critic_pass(
        self,
        net_inputs: List[torch.Tensor],
        memories: Optional[torch.Tensor] = None,
        critic_obs: Optional[List[List[torch.Tensor]]] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        actor_mem, critic_mem = None, None
        if self.use_lstm:
            # Use only the back half of memories for the critic.
            actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, -1)
        all_net_inputs = [net_inputs]
        if critic_obs is not None:
            all_net_inputs.extend(critic_obs)
        value_outputs, critic_mem_out = self.critic(
            all_net_inputs, memories=critic_mem, sequence_length=sequence_length
        )
        if actor_mem is not None:
            # Make memories with the actor mem unchanged
            memories_out = torch.cat([actor_mem, critic_mem_out], dim=-1)
        else:
            memories_out = None
        return value_outputs, memories_out

    def get_dist_and_value(
        self,
        net_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        critic_obs: Optional[List[List[torch.Tensor]]] = None,
        sequence_length: int = 1,
    ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
        if self.use_lstm:
            # Use the front half of memories for the actor and the back half
            # for the critic.
            actor_mem, critic_mem = torch.split(
                memories, self.memory_size // 2, dim=-1
            )
        else:
            critic_mem = None
            actor_mem = None
        dists, actor_mem_outs = self.get_dists(
            net_inputs, memories=actor_mem, sequence_length=sequence_length, masks=masks
        )
        all_net_inputs = [net_inputs]
        if critic_obs is not None:
            all_net_inputs.extend(critic_obs)
        value_outputs, critic_mem_outs = self.critic(
            all_net_inputs, memories=critic_mem, sequence_length=sequence_length
        )
        if self.use_lstm:
            mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1)
        else:
            mem_out = None
        return dists, value_outputs, mem_out
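
# Illustrative wiring (assumed placeholder names): each agent's actor sees only
# its own observations, while the centralized critic additionally receives the
# observations of its teammates:
#
#     policy = SeparateActorCritic(
#         obs_shapes, network_settings, act_type, act_size, ["extrinsic"]
#     )
#     dists, values, mem_out = policy.get_dist_and_value(
#         own_obs, critic_obs=[teammate1_obs, teammate2_obs]
#     )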