|
|
|
|
|
|
self.discrete_distribution = MultiCategoricalDistribution(self.encoding_size, discrete_act_size) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
continuous_actions, discrete_actions= torch.split(actions, self.continuous_act_size, dim=1) |
|
|
|
continuous_actions, discrete_actions = torch.split(actions, [self.continuous_act_size, len(self.discrete_act_size)], dim=1) |
|
|
|
|
|
|
|
discrete_action_list = [discrete_actions[..., i] for i in range(discrete_actions.shape[-1])] |
|
|
|
discrete_action_list = [discrete_actions[:, i] for i in range(len(self.discrete_act_size))] |
|
|
|
log_probs = torch.add(continuous_log_probs, discrete_log_probs) |
|
|
|
entropies = torch.add(continuous_entropies, discrete_entropies) |
|
|
|
log_probs = torch.cat([continuous_log_probs, discrete_log_probs], dim=1) |
|
|
|
entropies = torch.cat([continuous_entropies, torch.mean(discrete_entropies, dim=0).unsqueeze(0)], dim=1) |
|
|
|
return log_probs, entropies |
|
|
|
|
|
|
|
def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Tensor: |
|
|
|
|
|
|
) |
|
|
|
continuous_actions = torch.stack(continuous_action_list, dim=-1) |
|
|
|
continuous_actions = continuous_actions[:, :, 0] |
|
|
|
|
|
|
|
|
|
|
|
discrete_action_list = self._sample_action(discrete_dists) |
|
|
|
discrete_entropies, discrete_log_probs, discrete_all_probs = ModelUtils.get_probs_and_entropy( |
|
|
|
|
|
|
discrete_actions = discrete_actions[:, 0, :] |
|
|
|
|
|
|
|
action = torch.cat([continuous_actions, discrete_actions.type(torch.float)], axis=1) |
|
|
|
log_probs = torch.add(continuous_log_probs, discrete_log_probs) |
|
|
|
entropies = torch.add(continuous_entropies, discrete_entropies) |
|
|
|
|
|
|
|
#print("ac",action) |
|
|
|
#print("clp",continuous_log_probs) |
|
|
|
#print("dlp",discrete_log_probs) |
|
|
|
#print("lp",log_probs) |
|
|
|
#print("en",entropies) |
|
|
|
action = torch.cat([continuous_actions, discrete_actions.type(torch.float)], dim=1) |
|
|
|
log_probs = torch.cat([continuous_log_probs, discrete_log_probs], dim=1) |
|
|
|
entropies = torch.cat([continuous_entropies, discrete_entropies], dim=1) |
|
|
|
return (action, log_probs, entropies) |