|
|
|
|
|
|
) |
|
|
|
return encoding, memories |
|
|
|
|
|
|
|
|
|
|
|
def forward( |
|
|
|
self, |
|
|
|
f_enc: torch.Tensor, |
|
|
|
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
sequence_length: int = 1, |
|
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
|
encoding, memories = self.network_body.value( |
|
|
|
obs, memories, sequence_length |
|
|
|
) |
|
|
|
encoding, memories = self.network_body.value(obs, memories, sequence_length) |
|
|
|
output = self.value_heads(encoding) |
|
|
|
return output, memories |
|
|
|
|
|
|
|
|
|
|
self.critic = CentralizedValueNetwork( |
|
|
|
stream_names, sensor_specs, network_settings, action_spec=action_spec |
|
|
|
) |
|
|
|
self.target = CentralizedValueNetwork( |
|
|
|
stream_names, sensor_specs, network_settings, action_spec=action_spec |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
@property |
|
|
|
def memory_size(self) -> int: |
|
|
|
|
|
|
if team_obs is not None and team_obs: |
|
|
|
all_obs.extend(team_obs) |
|
|
|
|
|
|
|
value_outputs, _ = self.target.value( |
|
|
|
all_obs, |
|
|
|
memories=critic_mem, |
|
|
|
sequence_length=sequence_length, |
|
|
|
value_outputs, critic_mem_out = self.critic.value( |
|
|
|
all_obs, memories=critic_mem, sequence_length=sequence_length |
|
|
|
) |
|
|
|
|
|
|
|
# if mar_value_outputs is None: |
|
|
|
|
|
|
if team_obs is not None and team_obs: |
|
|
|
all_obs.extend(team_obs) |
|
|
|
|
|
|
|
value_outputs, _ = self.critic.value( |
|
|
|
all_obs, |
|
|
|
memories=critic_mem, |
|
|
|
sequence_length=sequence_length, |
|
|
|
value_outputs, critic_mem_out = self.critic.value( |
|
|
|
all_obs, memories=critic_mem, sequence_length=sequence_length |
|
|
|
) |
|
|
|
|
|
|
|
# if mar_value_outputs is None: |
|
|
|
|
|
|
if team_act is not None and team_act: |
|
|
|
all_acts.extend(team_act) |
|
|
|
|
|
|
|
baseline_outputs, _ = self.target.baseline( |
|
|
|
baseline_outputs, _ = self.critic.baseline( |
|
|
|
inputs, |
|
|
|
team_obs, |
|
|
|
team_act, |
|
|
|
|
|
|
|
|
|
|
value_outputs, critic_mem_out = self.target.q_net( |
|
|
|
value_outputs, critic_mem_out = self.critic.q_net( |
|
|
|
all_obs, all_acts, memories=critic_mem, sequence_length=sequence_length |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
team_obs=team_obs, |
|
|
|
team_act=team_act, |
|
|
|
) |
|
|
|
value_outputs, _ = self.target_critic_value(inputs, memories=critic_mem, sequence_length=sequence_length, team_obs=team_obs) |
|
|
|
value_outputs, _ = self.target_critic_value( |
|
|
|
inputs, |
|
|
|
memories=critic_mem, |
|
|
|
sequence_length=sequence_length, |
|
|
|
team_obs=team_obs, |
|
|
|
) |
|
|
|
|
|
|
|
return log_probs, entropies, q_outputs, baseline_outputs, value_outputs |
|
|
|
|
|
|
|