|
|
|
|
|
|
import enum |
|
|
|
from typing import Callable, List, Dict, Tuple, Optional |
|
|
|
import abc |
|
|
|
|
|
|
|
|
|
|
from mlagents.trainers.torch.decoders import ValueHeads |
|
|
|
from mlagents.trainers.torch.layers import LSTM, LinearEncoder |
|
|
|
from mlagents.trainers.torch.model_serialization import exporting_to_onnx |
|
|
|
from mlagents.trainers.torch.encoders import VectorInput |
|
|
|
|
|
|
|
ActivationFunction = Callable[[torch.Tensor], torch.Tensor] |
|
|
|
EncoderFunction = Callable[ |
|
|
|
|
|
|
network_settings.vis_encode_type, |
|
|
|
normalize=self.normalize, |
|
|
|
) |
|
|
|
self.observation_shapes = observation_shapes |
|
|
|
total_enc_size = encoder_input_size + encoded_act_size |
|
|
|
self.linear_encoder = LinearEncoder( |
|
|
|
total_enc_size, network_settings.num_layers, self.h_size |
|
|
|
|
|
|
|
|
|
|
def forward( |
|
|
|
self, |
|
|
|
net_inputs: List[torch.Tensor], |
|
|
|
vec_inputs: List[torch.Tensor], |
|
|
|
vis_inputs: List[torch.Tensor], |
|
|
|
masks: Optional[torch.Tensor] = None, |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
) -> Tuple[torch.Tensor, int, int, int, int]: |
|
|
|
|
|
|
dists, _ = self.get_dists(net_inputs, masks, memories, 1) |
|
|
|
concatenated_vec_obs = vec_inputs[0] |
|
|
|
inputs = [] |
|
|
|
start = 0 |
|
|
|
end = 0 |
|
|
|
vis_index = 0 |
|
|
|
for i, enc in enumerate(self.network_body.processors): |
|
|
|
if isinstance(enc, VectorInput): |
|
|
|
# This is a vec_obs |
|
|
|
vec_size = self.network_body.observation_shapes[i][0] |
|
|
|
end = start + vec_size |
|
|
|
inputs.append(concatenated_vec_obs[:, start:end]) |
|
|
|
start = end |
|
|
|
else: |
|
|
|
inputs.append(vis_inputs[vis_index]) |
|
|
|
vis_index += 1 |
|
|
|
dists, _ = self.get_dists(inputs, masks, memories, 1) |
|
|
|
if self.action_spec.is_continuous(): |
|
|
|
action_list = self.sample_action(dists) |
|
|
|
action_out = torch.stack(action_list, dim=-1) |
|
|
|