|
|
|
|
|
|
def _get_masks_from_nans(self, obs_tensors: List[List[torch.Tensor]]) -> torch.Tensor:
    """
    Build attention masks from raw observations: padded (absent) agents are
    filled with NaNs upstream, so a NaN anywhere along the feature axis marks
    an entity that self-attention should ignore.
    """
    # One observation per agent is enough to detect NaN padding.
    only_first_obs = [agent_obs[0] for agent_obs in obs_tensors]
    obs_for_mask = torch.stack(only_first_obs, dim=1)
    attn_mask = torch.any(obs_for_mask.isnan(), dim=2).type(torch.FloatTensor)
    return attn_mask
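# Illustrative sketch (assumed shapes, not from this file): if two agents are
# stacked into obs_for_mask of shape (batch, 2, obs_size) and the second agent
# is padding (all NaN), the resulting mask row is [0.0, 1.0]; a 1.0 tells the
# ResidualSelfAttention layer to ignore that entity.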
|
|
|
|
|
|
|
|
def q_net(
    self,
    obs: List[List[torch.Tensor]],
    actions: List[AgentAction],
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    self_attn_masks = []
    concat_f_inp = []
    for inputs, action in zip(obs, actions):
        encodes = []
        for idx, processor in enumerate(self.processors):
            obs_input = inputs[idx]
            obs_input[obs_input.isnan()] = 0.0  # Remove NaNs
            processed_obs = processor(obs_input)
            encodes.append(processed_obs)
        # Each agent becomes one (encoded obs, flattened action) entity.
        cat_encodes = [
            torch.cat(encodes, dim=-1),
            action.to_flat(self.action_spec.discrete_branches),
        ]
        concat_f_inp.append(torch.cat(cat_encodes, dim=1))

    f_inp = torch.stack(concat_f_inp, dim=1)
    self_attn_masks.append(self._get_masks_from_nans(obs))
    # Q-function: every entity carries an action, so only the f stream is used.
    encoding, memories = self.forward(
        f_inp,
        None,
        self_attn_masks,
        memories=memories,
        sequence_length=sequence_length,
    )
    return encoding, memories
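# The two entity streams built by this network body:
#   f: per-agent encodings of (obs, action), embedded by obs_action_encoder
#   g: per-agent encodings of obs alone, embedded by obs_encoder
# q_net() feeds only f (every action is known); baseline() below adds a g
# entity for the current agent so that its own action is left out -- the
# counterfactual-baseline construction used for multi-agent credit assignment
# (this mirrors ML-Agents' MA-POCA setup).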
|
|
|
|
|
|
|
def baseline(
    self,
    self_obs: List[List[torch.Tensor]],
    obs: List[List[torch.Tensor]],
    actions: List[AgentAction],
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    self_attn_masks = []
    f_inp = None
    concat_f_inp = []
    for inputs, action in zip(obs, actions):
        encodes = []
        for idx, processor in enumerate(self.processors):
            obs_input = inputs[idx]
            obs_input[obs_input.isnan()] = 0.0  # Remove NaNs
            processed_obs = processor(obs_input)
            encodes.append(processed_obs)
        cat_encodes = [
            torch.cat(encodes, dim=-1),
            action.to_flat(self.action_spec.discrete_branches),
        ]
        concat_f_inp.append(torch.cat(cat_encodes, dim=1))

    if concat_f_inp:
        f_inp = torch.stack(concat_f_inp, dim=1)
        self_attn_masks.append(self._get_masks_from_nans(obs))

    # Get the self encoding separately, but keep it in the entities
    concat_encoded_obs = []
    for inputs in self_obs:
        encodes = []
        for idx, processor in enumerate(self.processors):
            obs_input = inputs[idx]
            obs_input[obs_input.isnan()] = 0.0  # Remove NaNs
            processed_obs = processor(obs_input)
            encodes.append(processed_obs)
        concat_encoded_obs.append(torch.cat(encodes, dim=-1))
    g_inp = torch.stack(concat_encoded_obs, dim=1)
    # Get the mask from NaNs
    self_attn_masks.append(self._get_masks_from_nans(self_obs))

    encoding, memories = self.forward(
        f_inp,
        g_inp,
        self_attn_masks,
        memories=memories,
        sequence_length=sequence_length,
    )
    return encoding, memories
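# Shape sketch (illustrative; B = batch, N = agents in `obs`, H = the entity
# embedding size, A = flattened action size):
#   f_inp: (B, N, obs_enc + A)         -> obs_action_encoder -> (B, N, H)
#   g_inp: (B, len(self_obs), obs_enc) -> obs_encoder        -> (B, ., H)
# forward() concatenates both streams along the entity axis before applying
# masked self-attention.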
|
|
|
|
|
|
|
def forward(
    self,
    f_enc: torch.Tensor,
    g_enc: torch.Tensor,
    self_attn_masks: List[torch.Tensor],
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Tensors that go into ResidualSelfAttention
    self_attn_inputs = []
    if f_enc is not None:
        self_attn_inputs.append(self.obs_action_encoder(None, f_enc))
    if g_enc is not None:
        self_attn_inputs.append(self.obs_encoder(None, g_enc))

    encoded_entity = torch.cat(self_attn_inputs, dim=1)
    encoded_state = self.self_attn(encoded_entity, self_attn_masks)

    # NOTE: tail reconstructed to match the encoding sizes expected below
    # (hidden_units, or memory_size // 2 when an LSTM is configured).
    encoding = self.linear_encoder(encoded_state)
    if self.use_lstm:
        # Resize to (batch, sequence length, encoding size)
        encoding = encoding.reshape([-1, sequence_length, self.h_size])
        encoding, memories = self.lstm(encoding, memories)
        encoding = encoding.reshape([-1, self.m_size // 2])
    return encoding, memories
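# Masking convention (assumed from the mask construction above): entries of
# 1.0 in self_attn_masks mark NaN-padded agents, and the self-attention layer
# excludes those entities when mixing information across agents.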
|
|
|
|
|
|
def __init__(self, stream_names, observation_specs, network_settings,
             action_spec, outputs_per_stream=1):
    # NOTE: signature and memory branch reconstructed; names are inferred
    # from the surviving lines below and the surrounding code.
    super().__init__()
    self.network_body = MultiAgentNetworkBody(observation_specs, network_settings, action_spec)
    if network_settings.memory is not None:
        encoding_size = network_settings.memory.memory_size // 2
    else:
        encoding_size = network_settings.hidden_units
    self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream)
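# The memory tensor packs the LSTM hidden and cell states together, so the
# usable encoding coming out of the network body is memory_size // 2 wide.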
|
|
|
|
|
|
|
def q_net(
    self,
    obs: List[List[torch.Tensor]],
    actions: List[AgentAction],
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    encoding, memories = self.network_body.q_net(
        obs, actions, memories, sequence_length
    )
    output = self.value_heads(encoding)
    return output, memories
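# Illustrative call for a two-agent team (names assumed, not from this file):
#   q_out, mem = value_net.q_net(
#       obs=[agent0_obs_list, agent1_obs_list],
#       actions=[agent0_action, agent1_action],
#   )
# q_out holds one value-head output per reward stream.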
|
|
|
|
|
|
|
def baseline(
    self,
    self_obs: List[List[torch.Tensor]],
    obs: List[List[torch.Tensor]],
    actions: List[AgentAction],
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    encoding, memories = self.network_body.baseline(
        self_obs, obs, actions, memories, sequence_length
    )
    output = self.value_heads(encoding)
    return output, memories
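# Counterpart to q_net above: the current agent's observations go in self_obs
# and its action is omitted, e.g. (illustrative)
#   b_out, mem = value_net.baseline(
#       self_obs=[agent0_obs_list],
#       obs=[agent1_obs_list],
#       actions=[agent1_action],
#   )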
|
|
|
|
|
|
|
def forward(
    self,
    inputs: List[torch.Tensor],
    actions: AgentAction,
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
    team_obs: List[List[torch.Tensor]] = None,
    team_act: List[AgentAction] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # NOTE: assumes the memory-split helper used by ml-agents' separate
    # actor-critic: the actor half passes through unchanged while the
    # critic half is updated below.
    actor_mem, critic_mem = self._get_actor_critic_mem(memories)

    # The joint Q-value conditions on every agent's obs and action.
    all_obs = [inputs]
    if team_obs is not None and team_obs:
        all_obs.extend(team_obs)
    all_acts = [actions]
    if team_act is not None and team_act:
        all_acts.extend(team_act)

    # The baseline leaves out this agent's action: its own obs enter
    # action-free while teammates contribute (obs, action) pairs.
    baseline_outputs, _ = self.critic.baseline(
        [inputs],
        team_obs,
        team_act,
        memories=critic_mem,
        sequence_length=sequence_length,
    )

    value_outputs, critic_mem_out = self.critic.q_net(
        all_obs, all_acts, memories=critic_mem, sequence_length=sequence_length
    )

    if actor_mem is not None:
        # Make memories with the actor mem unchanged
        memories_out = torch.cat([actor_mem, critic_mem_out], dim=-1)
    else:
        memories_out = None

    return value_outputs, baseline_outputs, memories_out
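# value_outputs conditions on every agent's action while baseline_outputs
# marginalizes out this agent's own action; their difference yields a
# per-agent counterfactual advantage (COMA-style credit assignment).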
|
|
|
|
|
|
|
def get_stats_and_value( |
|
|
|
self, |
|
|
|