|
|
|
|
|
|
def to_numpy_dict(self) -> Dict[str, np.ndarray]: |
|
|
|
action_arrays_dict: Dict[str, np.ndarray] = {} |
|
|
|
if self.continuous is not None: |
|
|
|
action_arrays_dict["continuous_action"] = ModelUtils.to_numpy(self.continuous) |
|
|
|
action_arrays_dict["continuous_action"] = ModelUtils.to_numpy(self.continuous.unsqueeze(-1)[:, :, 0]) |
|
|
|
action_arrays_dict["discrete_action"] = np.array([ModelUtils.to_numpy(_disc) for _disc in self.discrete]) |
|
|
|
discrete_tensor = torch.stack(self.discrete, dim=-1) |
|
|
|
action_arrays_dict["discrete_action"] = ModelUtils.to_numpy(discrete_tensor[:, 0, :]) |
|
|
|
return action_arrays_dict |
|
|
|
|
|
|
|
def to_tensor_list(self) -> List[torch.Tensor]: |
|
|
|
|
|
|
if "continuous_action" in buff: |
|
|
|
continuous = ModelUtils.list_to_tensor(buff["continuous_action"]) |
|
|
|
if "discrete_action" in buff: |
|
|
|
discrete = ModelUtils.list_to_tensor(buff["discrete_action"]) |
|
|
|
discrete_tensor = ModelUtils.list_to_tensor(buff["discrete_action"]) |
|
|
|
discrete = [discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1])] |
|
|
|
return AgentAction(continuous, discrete) |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
|
|
if self.continuous is not None: |
|
|
|
log_prob_arrays_dict["continuous_log_probs"] = ModelUtils.to_numpy(self.continuous) |
|
|
|
if self.discrete is not None: |
|
|
|
log_prob_arrays_dict["discrete_log_probs"] = np.array([ModelUtils.to_numpy(_disc) for _disc in self.discrete]) |
|
|
|
discrete_tensor = torch.stack(self.discrete, dim=-1) |
|
|
|
log_prob_arrays_dict["discrete_log_probs"] = ModelUtils.to_numpy(discrete_tensor.squeeze(1)) |
|
|
|
return log_prob_arrays_dict |
|
|
|
|
|
|
|
def to_tensor_list(self) -> List[torch.Tensor]: |
|
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def get_probs_and_entropy( |
|
|
|
agent_action: AgentAction, dists: List[DistInstance] |
|
|
|
action_list: List[torch.Tensor], dists: List[DistInstance] |
|
|
|
action_list = agent_action.to_tensor_list() |
|
|
|
for action, action_dist in zip(action_list, dists): |
|
|
|
log_prob = action_dist.log_prob(action) |
|
|
|
log_probs_list.append(log_prob) |
|
|
|