from typing import List, Optional, Tuple, NamedTuple, Dict

from mlagents.torch_utils import torch, nn
import numpy as np

from mlagents_envs.base_env import ActionSpec
from mlagents.trainers.torch.distributions import DistInstance, DiscreteDistInstance

|
class AgentAction(NamedTuple):
    """
    A NamedTuple containing the tensor for continuous actions and list of tensors for
    discrete actions. Utility functions provide numpy <=> tensor conversions to be
    sent as actions to the environment manager as well as used by the optimizers.
    :param continuous_tensor: Torch tensor corresponding to continuous actions
    :param discrete_list: List of Torch tensors each corresponding to discrete actions
    """

    continuous_tensor: torch.Tensor
    discrete_list: List[torch.Tensor]

    @property
    def discrete_tensor(self):
        """
        Returns the discrete action list as a stacked tensor
        """
        return torch.stack(self.discrete_list, dim=-1)

    def to_numpy_dict(self) -> Dict[str, np.ndarray]:
        """
        Returns a Dict of np arrays with an entry corresponding to the continuous action
        and an entry corresponding to the discrete action. "continuous_action" and
        "discrete_action" are added to the agent's buffer individually to maintain a flat buffer.
        """
        array_dict: Dict[str, np.ndarray] = {}
        if self.continuous_tensor is not None:
            array_dict["continuous_action"] = ModelUtils.to_numpy(
                self.continuous_tensor
            )
        if self.discrete_list is not None:
            array_dict["discrete_action"] = ModelUtils.to_numpy(
                self.discrete_tensor[:, 0, :]
            )
        return array_dict

    def to_tensor_list(self) -> List[torch.Tensor]:
        """
        Returns the tensors in the AgentAction as a flat List of torch Tensors. This will be removed
        when the ActionModel is merged.
        """
        tensor_list: List[torch.Tensor] = []
        if self.continuous_tensor is not None:
            tensor_list.append(self.continuous_tensor)
        if self.discrete_list is not None:
            tensor_list += self.discrete_list  # Note this is different for ActionLogProbs
        return tensor_list

    @staticmethod
    def create(
        tensor_list: List[torch.Tensor], action_spec: ActionSpec
    ) -> "AgentAction":
        """
        A static method that converts a list of torch Tensors into an AgentAction using the ActionSpec.
        This will change (and may be removed) in the ActionModel.
        """
        continuous: torch.Tensor = None
        discrete: List[torch.Tensor] = None  # type: ignore
        _offset = 0
        if action_spec.continuous_size > 0:
            continuous = tensor_list[0]
            _offset = 1
        if action_spec.discrete_size > 0:
            discrete = tensor_list[_offset:]
        return AgentAction(continuous, discrete)

    @staticmethod
    def from_dict(buff: Dict[str, np.ndarray]) -> "AgentAction":
        """
        A static method that accesses continuous and discrete action fields in an AgentBuffer
        and constructs the corresponding AgentAction from the retrieved np arrays.
        """
        continuous: torch.Tensor = None
        discrete: List[torch.Tensor] = None  # type: ignore
        if "continuous_action" in buff:
            continuous = ModelUtils.list_to_tensor(buff["continuous_action"])
        if "discrete_action" in buff:
            discrete_tensor = ModelUtils.list_to_tensor(
                buff["discrete_action"], dtype=torch.long
            )
            discrete = [
                discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1])
            ]
        return AgentAction(continuous, discrete)
|
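# Illustrative usage (a sketch, not part of this module): for a hybrid action space,
# a policy could wrap its sampled action tensors and write them to the buffer roughly
# as follows. `sampled_tensor_list` and `action_spec` are assumed names for the example.
#
#   action = AgentAction.create(sampled_tensor_list, action_spec)
#   buffer_entries = action.to_numpy_dict()
#   # buffer_entries["continuous_action"] / buffer_entries["discrete_action"]
#   # are flat np.ndarrays suitable for an AgentBuffer.
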
class ActionLogProbs(NamedTuple):
    """
    A NamedTuple containing the tensor for continuous log probs and list of tensors for
    discrete log probs of individual actions, as well as all the log probs for an entire branch.
    Utility functions provide numpy <=> tensor conversions to be used by the optimizers.
    :param continuous_tensor: Torch tensor corresponding to log probs of continuous actions
    :param discrete_list: List of Torch tensors each corresponding to log probs of the discrete actions that were
    sampled.
    :param all_discrete_list: List of Torch tensors each corresponding to all log probs of
    a discrete action branch, including the discrete actions that were not sampled. Each Tensor in
    all_discrete_list contains the log probabilities of one discrete branch.
    """

    continuous_tensor: torch.Tensor
    discrete_list: List[torch.Tensor]
    all_discrete_list: Optional[List[torch.Tensor]]

    @property
    def discrete_tensor(self):
        """
        Returns the discrete log probs list as a stacked tensor
        """
        return torch.stack(self.discrete_list, dim=-1)

    @property
    def all_discrete_tensor(self):
        """
        Returns the discrete log probs of each branch, concatenated into a single tensor
        """
        return torch.cat(self.all_discrete_list, dim=1)

|
|
|
    def to_numpy_dict(self) -> Dict[str, np.ndarray]:
        """
        Returns a Dict of np arrays with an entry corresponding to the continuous log probs
        and an entry corresponding to the discrete log probs. "continuous_log_probs" and
        "discrete_log_probs" are added to the agent's buffer individually to maintain a flat buffer.
        """
        array_dict: Dict[str, np.ndarray] = {}
        if self.continuous_tensor is not None:
            array_dict["continuous_log_probs"] = ModelUtils.to_numpy(
                self.continuous_tensor
            )
        if self.discrete_list is not None:
            array_dict["discrete_log_probs"] = ModelUtils.to_numpy(self.discrete_tensor)
        return array_dict

|
|
|
    def _to_tensor_list(self) -> List[torch.Tensor]:
        """
        Returns the tensors in the ActionLogProbs as a flat List of torch Tensors. This
        is private and serves as a utility for self.flatten()
        """
        tensor_list: List[torch.Tensor] = []
        if self.continuous_tensor is not None:
            tensor_list.append(self.continuous_tensor)
        if self.discrete_list is not None:
            tensor_list.append(self.discrete_tensor)  # Note this is different for AgentAction
        return tensor_list

|
|
|
    def flatten(self) -> torch.Tensor:
        """
        A utility method that returns all log probs in ActionLogProbs as a flattened tensor.
        This is useful for algorithms like PPO which can treat all log probs in the same way.
        """
        return torch.cat(self._to_tensor_list(), dim=1)

|
|
|
    @staticmethod
    def create(
        log_prob_list: List[torch.Tensor],
        action_spec: ActionSpec,
        all_log_prob_list: List[torch.Tensor] = None,
    ) -> "ActionLogProbs":
        """
        A static method that converts a list of torch Tensors into an ActionLogProbs using the ActionSpec.
        This will change (and may be removed) in the ActionModel.
        """
        continuous: torch.Tensor = None
        discrete: List[torch.Tensor] = None  # type: ignore
        _offset = 0
        if action_spec.continuous_size > 0:
            continuous = log_prob_list[0]
            _offset = 1
        if action_spec.discrete_size > 0:
            discrete = log_prob_list[_offset:]
        return ActionLogProbs(continuous, discrete, all_log_prob_list)

|
|
|
    @staticmethod
    def from_dict(buff: Dict[str, np.ndarray]) -> "ActionLogProbs":
        """
        A static method that accesses continuous and discrete log probs fields in an AgentBuffer
        and constructs the corresponding ActionLogProbs from the retrieved np arrays.
        """
        continuous: torch.Tensor = None
        discrete: List[torch.Tensor] = None  # type: ignore

        if "continuous_log_probs" in buff:
            continuous = ModelUtils.list_to_tensor(buff["continuous_log_probs"])
        if "discrete_log_probs" in buff:
            discrete_tensor = ModelUtils.list_to_tensor(buff["discrete_log_probs"])
            discrete = [
                discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1])
            ]
        return ActionLogProbs(continuous, discrete, None)
|
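# Illustrative usage (a sketch, not part of this module): losses that treat all log
# probs uniformly, e.g. a PPO-style surrogate, can flatten an ActionLogProbs into a
# single tensor. `mini_batch` and `new_log_probs` are assumed names for the example.
#
#   old_log_probs = ActionLogProbs.from_dict(mini_batch).flatten()
#   ratio = torch.exp(new_log_probs.flatten() - old_log_probs)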
|
|
|
|
class ModelUtils:
    # Minimum supported side for each encoder type. If refactoring an encoder, please
    # adjust these also.

    class ActionFlattener:
        # NOTE: the enclosing class, __init__, and property headers here are
        # reconstructed from a partial fragment; only the branch bodies were preserved.
        def __init__(self, action_spec: ActionSpec):
            self._specs = action_spec

        @property
        def flattened_size(self) -> int:
            if self._specs.is_continuous():
                return self._specs.continuous_size
            else:
                return sum(self._specs.discrete_branches)

        def forward(self, action: AgentAction) -> torch.Tensor:
            # Continuous actions pass through unchanged; discrete actions are converted
            # to concatenated one-hot vectors, one block per branch.
            if self._specs.is_continuous():
                return action.continuous_tensor
            return torch.cat(
                ModelUtils.actions_to_onehot(
                    torch.as_tensor(action.discrete_tensor, dtype=torch.long),
                    self._specs.discrete_branches,
                ),
                dim=1,
            )
|
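    # Illustrative sketch (assumed usage, not from this module): components that need a
    # fixed-size action vector, e.g. a curiosity inverse model, could use the flattener.
    #
    #   flattener = ModelUtils.ActionFlattener(action_spec)
    #   flat_action = flattener.forward(agent_action)  # (batch, flattener.flattened_size)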
|
    @staticmethod
    def get_probs_and_entropy(
        action_list: List[torch.Tensor], dists: List[DistInstance]
    ) -> Tuple[List[torch.Tensor], torch.Tensor, List[torch.Tensor]]:
        log_probs_list = []
        all_probs_list = []
        entropies_list = []
        # NOTE: this loop is reconstructed from a partial fragment of the original.
        for action, action_dist in zip(action_list, dists):
            log_probs_list.append(action_dist.log_prob(action))
            entropies_list.append(action_dist.entropy())
            if isinstance(action_dist, DiscreteDistInstance):
                all_probs_list.append(action_dist.all_log_prob())
        entropies = torch.stack(entropies_list, dim=-1)
        if not all_probs_list:
            # Continuous-only case: no per-branch log probs were collected.
            entropies = entropies.squeeze(-1)
        return log_probs_list, entropies, all_probs_list
|
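    # Illustrative sketch (assumed usage, not from this module): a policy can feed the
    # outputs of get_probs_and_entropy into ActionLogProbs.create. `dists` and
    # `action_spec` are assumed names for the example.
    #
    #   log_probs_list, entropies, all_probs_list = ModelUtils.get_probs_and_entropy(
    #       agent_action.to_tensor_list(), dists
    #   )
    #   action_log_probs = ActionLogProbs.create(log_probs_list, action_spec, all_probs_list)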
|
|
|
    @staticmethod
    def masked_mean(tensor: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
|