|
|
|
|
|
|
self.normalize = network_settings.normalize |
|
|
|
self.use_lstm = network_settings.memory is not None |
|
|
|
# Scale network depending on num agents |
|
|
|
self.h_size = network_settings.hidden_units * num_obs_heads |
|
|
|
self.h_size = network_settings.hidden_units |
|
|
|
self.m_size = ( |
|
|
|
network_settings.memory.memory_size |
|
|
|
if network_settings.memory is not None |
|
|
|
|
|
|
normalize=self.normalize, |
|
|
|
) |
|
|
|
self.processors.append(_proc) |
|
|
|
encoder_input_size += _input_size |
|
|
|
encoder_input_size += sum(_input_size) |
|
|
|
|
|
|
|
total_enc_size = encoder_input_size + encoded_act_size |
|
|
|
self.linear_encoder = LinearEncoder( |
|
|
|
|
|
|
if network_settings.memory is not None: |
|
|
|
encoding_size = network_settings.memory.memory_size // 2 |
|
|
|
else: |
|
|
|
encoding_size = network_settings.hidden_units * num_agents |
|
|
|
encoding_size = network_settings.hidden_units |
|
|
|
self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream) |
|
|
|
|
|
|
|
def forward( |
|
|
|
|
|
|
critic_obs: Optional[List[List[torch.Tensor]]] = None, |
|
|
|
) -> Tuple[ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor]]: |
|
|
|
encoding, memories = self.network_body( |
|
|
|
inputs, memories=memories, sequence_length=sequence_length |
|
|
|
inputs, memories=memories, sequence_length=sequence_length, |
|
|
|
) |
|
|
|
log_probs, entropies = self.action_model.evaluate(encoding, masks, actions) |
|
|
|
value_outputs = self.value_heads(encoding) |
|
|
|
|
|
|
if critic_obs is not None: |
|
|
|
all_net_inputs.extend(critic_obs) |
|
|
|
value_outputs, critic_mem_outs = self.critic( |
|
|
|
inputs, memories=critic_mem, sequence_length=sequence_length |
|
|
|
all_net_inputs, |
|
|
|
memories=critic_mem, |
|
|
|
sequence_length=sequence_length, |
|
|
|
) |
|
|
|
|
|
|
|
return log_probs, entropies, value_outputs |
|
|
|