|
|
|
|
|
|
from typing import Callable, List, Dict, Tuple, Optional |
|
|
|
from typing import Callable, List, Dict, Tuple, Optional, Union |
|
|
|
import abc |
|
|
|
|
|
|
|
from mlagents.torch_utils import torch, nn |
|
|
|
|
|
|
vis_inputs: List[torch.Tensor], |
|
|
|
masks: Optional[torch.Tensor] = None, |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
) -> Tuple[torch.Tensor, int, int, int, int]: |
|
|
|
) -> Tuple[Union[int, torch.Tensor], ...]: |
|
|
|
""" |
|
|
|
Forward pass of the Actor for inference. This is required for export to ONNX, and |
|
|
|
the inputs and outputs of this method should not be changed without a respective change |
|
|
|
|
|
|
super().__init__() |
|
|
|
self.action_spec = action_spec |
|
|
|
self.version_number = torch.nn.Parameter(torch.Tensor([2.0])) |
|
|
|
self.is_continuous_int = torch.nn.Parameter( |
|
|
|
self.is_continuous_int_deprecated = torch.nn.Parameter( |
|
|
|
self.act_size_vector = torch.nn.Parameter( |
|
|
|
self.continuous_act_size_vector = torch.nn.Parameter( |
|
|
|
torch.Tensor([int(self.action_spec.continuous_size)]), requires_grad=False |
|
|
|
) |
|
|
|
# TODO: export list of branch sizes instead of sum |
|
|
|
self.discrete_act_size_vector = torch.nn.Parameter( |
|
|
|
torch.Tensor([sum(self.action_spec.discrete_branches)]), requires_grad=False |
|
|
|
) |
|
|
|
self.act_size_vector_deprecated = torch.nn.Parameter( |
|
|
|
torch.Tensor( |
|
|
|
[ |
|
|
|
self.action_spec.continuous_size |
|
|
|
|
|
|
vis_inputs: List[torch.Tensor], |
|
|
|
masks: Optional[torch.Tensor] = None, |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
) -> Tuple[torch.Tensor, int, int, int, int]: |
|
|
|
) -> Tuple[Union[int, torch.Tensor], ...]: |
|
|
|
|
|
|
|
At this moment, torch.onnx.export() doesn't accept None as tensor to be exported, |
|
|
|
so the size of return tuple varies with action spec. |
|
|
|
# TODO: How this is written depends on how the inference model is structured |
|
|
|
action_out = self.action_model.get_action_out(encoding, masks) |
|
|
|
return ( |
|
|
|
action_out, |
|
|
|
cont_action_out, disc_action_out, action_out_deprecated = self.action_model.get_action_out( |
|
|
|
encoding, masks |
|
|
|
) |
|
|
|
export_out = [ |
|
|
|
self.is_continuous_int, |
|
|
|
self.act_size_vector, |
|
|
|
) |
|
|
|
] |
|
|
|
if self.action_spec.continuous_size > 0: |
|
|
|
export_out += [cont_action_out, self.continuous_act_size_vector] |
|
|
|
if self.action_spec.discrete_size > 0: |
|
|
|
export_out += [disc_action_out, self.discrete_act_size_vector] |
|
|
|
# Only export deprecated nodes with non-hybrid action spec |
|
|
|
if self.action_spec.continuous_size == 0 or self.action_spec.discrete_size == 0: |
|
|
|
export_out += [ |
|
|
|
action_out_deprecated, |
|
|
|
self.is_continuous_int_deprecated, |
|
|
|
self.act_size_vector_deprecated, |
|
|
|
] |
|
|
|
return tuple(export_out) |
|
|
|
|
|
|
|
|
|
|
|
class SharedActorCritic(SimpleActor, ActorCritic): |
|
|
|