|
|
|
|
|
|
|
|
|
|
from mlagents.torch_utils import torch, nn |
|
|
|
|
|
|
|
from mlagents_envs.base_env import ActionType, ActionSpec |
|
|
|
from mlagents_envs.base_env import ActionSpec |
|
|
|
from mlagents.trainers.torch.distributions import ( |
|
|
|
GaussianDistribution, |
|
|
|
MultiCategoricalDistribution, |
|
|
|
|
|
|
): |
|
|
|
super().__init__() |
|
|
|
self.action_spec = action_spec |
|
|
|
if self.action_spec.is_continuous(): |
|
|
|
self.act_type = ActionType.CONTINUOUS |
|
|
|
else: |
|
|
|
self.act_type = ActionType.DISCRETE |
|
|
|
torch.Tensor([int(self.act_type == ActionType.CONTINUOUS)]) |
|
|
|
torch.Tensor([int(self.action_spec.is_continuous())]) |
|
|
|
) |
|
|
|
self.act_size_vector = torch.nn.Parameter( |
|
|
|
torch.Tensor([self.action_spec.total_size]), requires_grad=False |
|
|
|
|
|
|
else: |
|
|
|
self.encoding_size = network_settings.hidden_units |
|
|
|
|
|
|
|
if self.act_type == ActionType.CONTINUOUS: |
|
|
|
if self.action_spec.is_continuous(): |
|
|
|
self.distribution = GaussianDistribution( |
|
|
|
self.encoding_size, |
|
|
|
self.action_spec.continuous_size, |
|
|
|
|
|
|
encoding, memories = self.network_body( |
|
|
|
vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length |
|
|
|
) |
|
|
|
if self.act_type == ActionType.CONTINUOUS: |
|
|
|
if self.action_spec.is_continuous(): |
|
|
|
dists = self.distribution(encoding) |
|
|
|
else: |
|
|
|
dists = self.distribution(encoding, masks) |
|
|
|
|
|
|
Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs. |
|
|
|
""" |
|
|
|
dists, _ = self.get_dists(vec_inputs, vis_inputs, masks, memories, 1) |
|
|
|
if self.act_type == ActionType.CONTINUOUS: |
|
|
|
if self.action_spec.is_continuous(): |
|
|
|
action_list = self.sample_action(dists) |
|
|
|
action_out = torch.stack(action_list, dim=-1) |
|
|
|
else: |
|
|
|
|
|
|
encoding, memories = self.network_body( |
|
|
|
vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length |
|
|
|
) |
|
|
|
if self.act_type == ActionType.CONTINUOUS: |
|
|
|
if self.action_spec.is_continuous(): |
|
|
|
dists = self.distribution(encoding) |
|
|
|
else: |
|
|
|
dists = self.distribution(encoding, masks=masks) |
|
|
|