|
|
|
|
|
|
return encoding, memories |
|
|
|
|
|
|
|
|
|
|
|
# NOTE: this class will be replaced with a multi-head attention when the time comes |
|
|
|
class MultiInputNetworkBody(nn.Module): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
|
|
|
|
|
|
|
def forward( |
|
|
|
self, |
|
|
|
all_net_inputs: List[List[torch.Tensor]], |
|
|
|
self_inputs: List[torch.Tensor], |
|
|
|
aux_inputs: Optional[List[List[torch.Tensor]]] = None, |
|
|
|
inputs = all_net_inputs[0] |
|
|
|
obs_input = inputs[idx] |
|
|
|
obs_input = self_inputs[idx] |
|
|
|
all_net_inputs = [self_inputs] |
|
|
|
if aux_inputs is not None: |
|
|
|
all_net_inputs.extend(aux_inputs) |
|
|
|
# Get attention masks by grabbing an arbitrary obs across all the agents |
|
|
|
# Since these are raw obs, the padded values are still NaN |
|
|
|
only_first_obs = [_all_obs[0] for _all_obs in all_net_inputs] |
|
|
|
|
|
|
|
|
|
|
def forward( |
|
|
|
self, |
|
|
|
inputs: List[List[torch.Tensor]], |
|
|
|
inputs: List[torch.Tensor], |
|
|
|
aux_inputs: Optional[List[List[torch.Tensor]]] = None, |
|
|
|
inputs, actions, memories, sequence_length |
|
|
|
inputs, actions, memories, sequence_length, aux_inputs |
|
|
|
) |
|
|
|
output = self.value_heads(encoding) |
|
|
|
return output, memories |
|
|
|
|
|
|
masks: Optional[torch.Tensor] = None, |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
sequence_length: int = 1, |
|
|
|
critic_obs: Optional[List[List[torch.Tensor]]] = None, |
|
|
|
) -> Tuple[ |
|
|
|
AgentAction, ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor |
|
|
|
]: |
|
|
|
|
|
|
critic_obs: List[List[torch.Tensor]] = None, |
|
|
|
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: |
|
|
|
actor_mem, critic_mem = self._get_actor_critic_mem(memories) |
|
|
|
all_net_inputs = [inputs] |
|
|
|
if critic_obs is not None: |
|
|
|
all_net_inputs.extend(critic_obs) |
|
|
|
all_net_inputs, memories=critic_mem, sequence_length=sequence_length |
|
|
|
inputs, |
|
|
|
memories=critic_mem, |
|
|
|
aux_inputs=critic_obs, |
|
|
|
sequence_length=sequence_length, |
|
|
|
) |
|
|
|
if actor_mem is not None: |
|
|
|
# Make memories with the actor mem unchanged |
|
|
|
|
|
|
inputs, memories=actor_mem, sequence_length=sequence_length |
|
|
|
) |
|
|
|
log_probs, entropies = self.action_model.evaluate(encoding, masks, actions) |
|
|
|
all_net_inputs = [inputs] |
|
|
|
if critic_obs is not None: |
|
|
|
all_net_inputs.extend(critic_obs) |
|
|
|
all_net_inputs, memories=critic_mem, sequence_length=sequence_length |
|
|
|
inputs, |
|
|
|
memories=critic_mem, |
|
|
|
aux_inputs=critic_obs, |
|
|
|
sequence_length=sequence_length, |
|
|
|
) |
|
|
|
|
|
|
|
return log_probs, entropies, value_outputs |
|
|
|
|
|
|
inputs, memories=actor_mem, sequence_length=sequence_length |
|
|
|
) |
|
|
|
action, log_probs, entropies = self.action_model(encoding, masks) |
|
|
|
all_net_inputs = [inputs] |
|
|
|
if critic_obs is not None: |
|
|
|
all_net_inputs.extend(critic_obs) |
|
|
|
|
|
|
|
all_net_inputs, memories=critic_mem, sequence_length=sequence_length |
|
|
|
inputs, |
|
|
|
memories=critic_mem, |
|
|
|
aux_inputs=critic_obs, |
|
|
|
sequence_length=sequence_length, |
|
|
|
) |
|
|
|
if self.use_lstm: |
|
|
|
mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1) |
|
|
|