from typing import Callable, List, Optional, Tuple

from mlagents.torch_utils import torch, nn
from mlagents_envs.base_env import SensorSpec
from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.trajectory import ObsUtil
from mlagents.trainers.torch.encoders import VectorInput
from mlagents.trainers.torch.layers import LSTM, LinearEncoder
from mlagents.trainers.torch.decoders import ValueHeads
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.attention import SmallestAttention, SimpleTransformer

ActivationFunction = Callable[[torch.Tensor], torch.Tensor]

class NetworkBody(nn.Module):
    def __init__(
        self,
        sensor_specs: List[SensorSpec],
        network_settings: NetworkSettings,
        encoded_act_size: int = 0,
        num_obs_heads: int = 1,
    ):
        super().__init__()
        self.normalize = network_settings.normalize
        self.use_lstm = network_settings.memory is not None
        self.h_size = network_settings.hidden_units
        self.m_size = (
            network_settings.memory.memory_size
            if network_settings.memory is not None
            else 0
        )
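        # m_size is the full LSTM memory: hidden and cell state concatenated,
        # which is why later code splits it with `// 2`.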
        # A single set of input encoders is shared across observation heads;
        # the attention layer below handles a variable number of agents, so
        # per-head encoders (num_obs_heads) are not needed here.
        self.processors, _input_size = ModelUtils.create_input_processors(
            sensor_specs,
            self.h_size,
            network_settings.vis_encode_type,
            normalize=self.normalize,
        )
        self.transformer = SmallestAttention(
            sum(_input_size), [sum(_input_size)], self.h_size, self.h_size
        )
        encoder_input_size = self.h_size + sum(_input_size)
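        # Assumption: SmallestAttention returns the agent's own encoding
        # concatenated with the attended summary, hence the input width of
        # h_size + sum(_input_size) for the encoder below.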
        total_enc_size = encoder_input_size + encoded_act_size
        self.linear_encoder = LinearEncoder(
            total_enc_size, network_settings.num_layers, self.h_size
        )

        if self.use_lstm:
            self.lstm = LSTM(self.h_size, self.m_size)
        else:
            self.lstm = None  # type: ignore
    def update_normalization(self, buffer: AgentBuffer) -> None:
        obs = ObsUtil.from_buffer(buffer, len(self.processors))
        for vec_input, enc in zip(obs, self.processors):
            if isinstance(enc, VectorInput):
                enc.update_normalization(torch.as_tensor(vec_input))

    def copy_normalization(self, other_network: "NetworkBody") -> None:
        if self.normalize:
            for n1, n2 in zip(self.processors, other_network.processors):
                if isinstance(n1, VectorInput) and isinstance(n2, VectorInput):
                    n1.copy_normalization(n2)
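    # Usage sketch (hypothetical variable names): after collecting experience,
    #     body.update_normalization(agent_buffer)
    # and, to keep a target network's statistics in sync,
    #     target_body.copy_normalization(body)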
    @property
    def memory_size(self) -> int:
        return self.lstm.memory_size if self.use_lstm else 0
    def forward(
        self,
        all_net_inputs: List[List[torch.Tensor]],
        actions: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Encode this agent's own observations (the first set of inputs).
        self_encodes = []
        inputs = all_net_inputs[0]
        for idx, processor in enumerate(self.processors):
            obs_input = inputs[idx]
            processed_obs = processor(obs_input)
            self_encodes.append(processed_obs)
        x_self = torch.cat(self_encodes, dim=-1)
        # Get the self encoding separately, but keep it in the entities so the
        # agent can also attend over itself.
        concat_encoded_obs = [x_self]
        for inputs in all_net_inputs[1:]:
            encodes = []
            for idx, processor in enumerate(self.processors):
                obs_input = inputs[idx]
                processed_obs = processor(obs_input)
                encodes.append(processed_obs)
            concat_encoded_obs.append(torch.cat(encodes, dim=-1))
        concat_entites = torch.stack(concat_encoded_obs, dim=1)

        encoded_state = self.transformer(
            x_self, [concat_entites], SimpleTransformer.get_masks([concat_entites])
        )
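        # Assumption: SimpleTransformer.get_masks flags all-zero entity rows
        # as padding, so attention ignores absent agents when the group size
        # varies.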
        if actions is not None:
            inputs = torch.cat([encoded_state, actions], dim=-1)
        else:
            inputs = encoded_state
        encoding = self.linear_encoder(inputs)

        if self.use_lstm:
            # Resize to (batch, sequence length, encoding size) for the LSTM,
            # then flatten back out for the layers downstream.
            encoding = encoding.reshape([-1, sequence_length, self.h_size])
            encoding, memories = self.lstm(encoding, memories)
            encoding = encoding.reshape([-1, self.m_size // 2])
        return encoding, memories
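    # Shape sketch (hypothetical two-agent call, no LSTM): with batch size B
    # and one vector observation per agent,
    #     all_net_inputs = [[obs_self], [obs_other]]  # each of shape (B, obs_dim)
    #     encoding, _ = body(all_net_inputs)          # encoding: (B, h_size)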


class CentralizedValueNetwork(nn.Module):
    def __init__(
        self,
        stream_names: List[str],
        sensor_specs: List[SensorSpec],
        network_settings: NetworkSettings,
        encoded_act_size: int = 0,
        outputs_per_stream: int = 1,
        num_agents: int = 1,
    ):
        super().__init__()
        self.network_body = NetworkBody(
            sensor_specs,
            network_settings,
            encoded_act_size=encoded_act_size,
            num_obs_heads=num_agents,
        )
        if network_settings.memory is not None:
            encoding_size = network_settings.memory.memory_size // 2
        else:
            encoding_size = network_settings.hidden_units
        self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream)


        self.stream_names = stream_names
        # Centralized critic: a single value network that sees every agent's
        # observations (num_agents=2 hardcodes the two-agent case).
        self.critic = CentralizedValueNetwork(
            stream_names, sensor_specs, network_settings, num_agents=2
        )
    @property
    def memory_size(self) -> int:
        return self.network_body.memory_size + self.critic.memory_size
        if critic_obs is not None:
            all_net_inputs.extend(critic_obs)
        value_outputs, critic_mem_outs = self.critic(
            all_net_inputs, memories=critic_mem, sequence_length=sequence_length
        )

        return log_probs, entropies, value_outputs
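        # Hypothetical call sketch for this stats/value pass: the other
        # agent's observations are supplied as critic_obs so the centralized
        # critic can condition on them, e.g.
        #     log_probs, entropies, values = policy.get_stats_and_value(
        #         obs, actions, masks, memories, seq_len, critic_obs=[other_obs]
        #     )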