|
|
|
|
|
|
for split_action in split_actions: |
|
|
|
action_list = [split_action[..., i] for i in range(split_action.shape[-1])] |
|
|
|
action_lists += action_list |
|
|
|
log_probs, entropies, _ = ModelUtils.get_probs_and_entropy(action_lists, dists) #self._get_stats(actions, dists) |
|
|
|
log_probs, entropies, _ = ModelUtils.get_probs_and_entropy(action_lists, dists) |
|
|
|
return log_probs, entropies |
|
|
|
|
|
|
|
def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Tensor: |
# Sample actions from `dists`, compute their log-probabilities and entropies,
# and return them together with the concatenated action tensor.
#
# NOTE(review): this text looks like residue of a side-by-side diff / table
# export: every row ends in a `|` delimiter, indentation is lost, and several
# statements appear twice (old and new variants). Restore it from the real
# file before relying on it.
# NOTE(review): `dists` is never assigned in the visible rows; presumably it
# is built from `inputs` and `masks` in rows that were dropped -- confirm.
# NOTE(review): the annotation says `torch.Tensor`, but the return statement
# yields a 3-tuple `(action, log_probs, entropies)`.
|
|
|
|
|
|
# Accumulator for per-distribution action tensors, concatenated further down.
action_outs : List[torch.Tensor] = [] |
|
|
|
action_lists = self._sample_action(dists) |
|
|
|
# NOTE(review): `dist` is not used by any visible row of this loop, and no
# visible row appends `action_out` to `action_outs`; the append was presumably
# in a dropped row -- confirm.
for action_list, dist in zip(action_lists, dists): |
|
|
|
# `unsqueeze(-1)` adds a trailing dimension of size 1 to the sampled action.
action_out = action_list.unsqueeze(-1)#torch.stack(action_list, dim=-1) |
|
|
|
# Same statement as the row above, minus the commented-out `torch.stack` variant.
action_out = action_list.unsqueeze(-1) |
|
|
|
# The third value returned by `get_probs_and_entropy` is discarded.
log_probs, entropies, _ = ModelUtils.get_probs_and_entropy(action_lists, dists) #self._get_stats(actions, dists)self._get_stats(action_lists, dists) |
|
|
|
# Same call as the row above, minus the trailing commented-out `_get_stats` text.
log_probs, entropies, _ = ModelUtils.get_probs_and_entropy(action_lists, dists) |
|
|
|
# Concatenate the collected action tensors along dimension 1.
# NOTE(review): if `action_outs` really stays empty (see the loop note),
# `torch.cat` would be given an empty list here.
action = torch.cat(action_outs, dim=1) |
|
|
|
return (action, log_probs, entropies) |