|
|
|
|
|
|
) |
|
|
|
return encoding, memories |
|
|
|
|
|
|
|
|
|
|
|
def forward(
    self,
    f_enc: torch.Tensor,
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the value network on an observation batch and return value estimates.

    :param f_enc: Observation input fed to the network body.
        (NOTE(review): name suggests a pre-computed encoding, but it is passed
        straight to ``network_body.value`` like raw observations — confirm.)
    :param memories: Optional recurrent state; passed through and updated.
    :param sequence_length: LSTM sequence length used when memories are present.
    :return: Tuple of (value-head outputs, updated memories).
    """
    # Bug fix: the original body referenced an undefined name ``obs`` and
    # issued the same ``network_body.value`` call twice (diff residue).
    encoding, memories = self.network_body.value(
        f_enc, memories, sequence_length
    )
    output = self.value_heads(encoding)
    return output, memories
|
|
|
|
|
|
|
|
|
|
stream_names, sensor_specs, network_settings, action_spec=action_spec |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
@property
def memory_size(self) -> int:
    """Width of the combined recurrent state: actor body memories
    followed by critic memories."""
    body_size = self.network_body.memory_size
    critic_size = self.critic.memory_size
    return body_size + critic_size
|
|
|
|
|
|
def target_critic_value(
    self,
    inputs: List[torch.Tensor],
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
    team_obs: Optional[List[List[torch.Tensor]]] = None,
) -> Tuple[Dict[str, torch.Tensor], Optional[torch.Tensor]]:
    """Evaluate the *target* critic on this agent's observations plus any
    teammate observations.

    NOTE(review): this block was reconstructed from a garbled diff; the
    header and memory handling were not visible — confirm against the
    original file. The fragment contained the same ``self.target.value``
    call twice (old+new diff lines); a single call is kept.

    :param inputs: This agent's observations.
    :param memories: Combined actor+critic recurrent state, split via
        ``_get_actor_critic_mem``.
    :param sequence_length: LSTM sequence length.
    :param team_obs: Optional list of teammates' observation lists,
        appended after this agent's observations.
    :return: Tuple of (value outputs per reward stream, updated memories
        or ``None`` when no memories were supplied).
    """
    actor_mem, critic_mem = self._get_actor_critic_mem(memories)
    all_obs = [inputs]
    if team_obs:  # equivalent to `team_obs is not None and team_obs`
        all_obs.extend(team_obs)
    value_outputs, critic_mem_out = self.target.value(
        all_obs, memories=critic_mem, sequence_length=sequence_length
    )
    # Recombine actor and critic memories; presumably both halves are
    # present together — TODO confirm against callers.
    if actor_mem is not None:
        memories_out = torch.cat([actor_mem, critic_mem_out], dim=-1)
    else:
        memories_out = None
    return value_outputs, memories_out
|
|
|
|
|
|
if team_obs is not None and team_obs: |
|
|
|
all_obs.extend(team_obs) |
|
|
|
|
|
|
|
value_outputs, _ = self.critic.value( |
|
|
|
all_obs, |
|
|
|
memories=critic_mem, |
|
|
|
sequence_length=sequence_length, |
|
|
|
value_outputs, critic_mem_out = self.critic.value( |
|
|
|
all_obs, memories=critic_mem, sequence_length=sequence_length |
|
|
|
) |
|
|
|
|
|
|
|
# if mar_value_outputs is None: |
|
|
|
|
|
|
def critic_pass(
    self,
    inputs: List[torch.Tensor],
    actions: "AgentAction",
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
    team_obs: Optional[List[List[torch.Tensor]]] = None,
    team_act: Optional[List["AgentAction"]] = None,
) -> Tuple[Dict[str, torch.Tensor], Optional[torch.Tensor]]:
    """Compute the counterfactual baseline for this agent given teammate
    observations and actions.

    NOTE(review): reconstructed from a garbled diff that interleaved two
    variants (a ``q_net`` path over ``all_obs``/``all_acts`` and a
    ``baseline`` path). The baseline variant is kept, matching the
    two-element return annotation and ``return baseline_outputs,
    memories_out`` present in the fragment — confirm which variant is
    current. ``actions`` is currently unused by the baseline call but is
    kept for interface compatibility with callers.

    :param inputs: This agent's observations.
    :param actions: This agent's actions (kept for caller compatibility).
    :param memories: Combined actor+critic recurrent state.
    :param sequence_length: LSTM sequence length.
    :param team_obs: Teammates' observation lists.
    :param team_act: Teammates' actions.
    :return: Tuple of (baseline outputs per reward stream, recombined
        memories or ``None``).
    """
    actor_mem, critic_mem = self._get_actor_critic_mem(memories)
    baseline_outputs, critic_mem_out = self.critic.baseline(
        inputs,
        team_obs,
        team_act,
        memories=critic_mem,
        sequence_length=sequence_length,
    )
    # Recombine actor and critic halves of the recurrent state; assumes
    # both are present when memories were supplied — TODO confirm.
    if actor_mem is not None:
        memories_out = torch.cat([actor_mem, critic_mem_out], dim=-1)
    else:
        memories_out = None
    return baseline_outputs, memories_out
|
|
|
|
|
|
|
def get_stats_and_value(
    self,
    inputs: List[torch.Tensor],
    actions: "AgentAction",
    masks: Optional[torch.Tensor] = None,
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
    team_obs: Optional[List[List[torch.Tensor]]] = None,
    team_act: Optional[List["AgentAction"]] = None,
) -> Tuple["ActionLogProbs", torch.Tensor, Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
    """Evaluate log-probs/entropies for given actions plus the baseline and
    target value estimates, for a training update.

    NOTE(review): reconstructed from a garbled diff — the parameter list and
    the encoder call were not fully visible, and the fragment contained two
    variants of the critic calls and two contradictory returns (with and
    without ``q_outputs``). The variant without ``q_outputs`` is kept,
    matching the baseline-only ``critic_pass`` return — confirm against the
    original file.

    :param inputs: This agent's observations.
    :param actions: Actions to evaluate log-probabilities for.
    :param masks: Optional action masks passed to the action model.
    :param memories: Combined actor+critic recurrent state.
    :param sequence_length: LSTM sequence length.
    :param team_obs: Teammates' observation lists.
    :param team_act: Teammates' actions.
    :return: Tuple of (log_probs, entropies, baseline outputs, target value
        outputs).
    """
    actor_mem, critic_mem = self._get_actor_critic_mem(memories)
    # Encode observations with the actor body; reconstructed line — the
    # original call producing ``encoding`` was garbled out. TODO confirm.
    encoding, _ = self.network_body(
        inputs, memories=actor_mem, sequence_length=sequence_length
    )
    log_probs, entropies = self.action_model.evaluate(encoding, masks, actions)
    baseline_outputs, _ = self.critic_pass(
        inputs,
        actions,
        memories=critic_mem,
        sequence_length=sequence_length,
        team_obs=team_obs,
        team_act=team_act,
    )
    value_outputs, _ = self.target_critic_value(
        inputs,
        memories=critic_mem,
        sequence_length=sequence_length,
        team_obs=team_obs,
    )
    return log_probs, entropies, baseline_outputs, value_outputs
|
|
|
|
|
|
|
def get_action_stats( |
|
|
|
self, |
|
|
|