import torch

from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.networks import NetworkBody
from mlagents.trainers.torch.layers import LinearEncoder, linear_layer
from mlagents.trainers.settings import NetworkSettings, EncoderType, CuriositySettings


class CuriosityNetwork(torch.nn.Module):
    """
    Inverse- and forward-model network used to compute the curiosity intrinsic reward.
    """

    def __init__(self, specs: BehaviorSpec, settings: CuriositySettings) -> None:
        super().__init__()
        self._policy_specs = specs
        # The observation encoder (a NetworkBody built from NetworkSettings and
        # EncoderType) that produces the state embeddings used below is omitted here.

        # Flattens continuous or branched discrete actions into a single vector so
        # they can be concatenated with state embeddings.
        self._action_flattener = ModelUtils.ActionFlattener(specs)
        # Inverse model: predicts the action taken between two consecutive state
        # embeddings. LinearEncoder(input_size, num_layers, hidden_size) bundles the
        # former linear_layer + Swish pair into a single block.
        self.inverse_model_action_prediction = torch.nn.Sequential(
            LinearEncoder(2 * settings.encoding_size, 1, 256),
            linear_layer(256, self._action_flattener.flattened_size),
        )

        # Forward model: predicts the next state embedding from the current state
        # embedding concatenated with the flattened action.
        self.forward_model_next_state_prediction = torch.nn.Sequential(
            LinearEncoder(
                settings.encoding_size + self._action_flattener.flattened_size, 1, 256
            ),
            linear_layer(256, settings.encoding_size),
        )
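
    # Hedged sketch (an illustrative helper, not verbatim from the original file): how
    # the forward head is fed. The current state embedding is concatenated with the
    # flattened action, which is why the head's input width above is
    # encoding_size + flattened_size; the curiosity reward comes from this head's
    # prediction error against the next state embedding. The module's own next-state
    # prediction derives the flattened action from the mini batch; here it is passed
    # in explicitly to keep the sketch self-contained.
    def predict_next_state_from(
        self, mini_batch: AgentBuffer, flattened_action: torch.Tensor
    ) -> torch.Tensor:
        forward_model_input = torch.cat(
            (self.get_current_state(mini_batch), flattened_action), dim=1
        )
        return self.forward_model_next_state_prediction(forward_model_input)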

    def predict_action(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Uses the embeddings of the current and next states in the mini batch to
        predict the action that was taken in between (the inverse model).
        """
        inverse_model_input = torch.cat(
            (self.get_current_state(mini_batch), self.get_next_state(mini_batch)), dim=1
        )
        hidden = self.inverse_model_action_prediction(inverse_model_input)
        if self._policy_specs.is_action_continuous():
            return hidden
        else:
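            # Hedged sketch of the discrete case (following the usual ML-Agents
            # treatment of branched discrete actions; the exact post-processing in
            # the original file may differ): split the concatenated logits into one
            # chunk per branch, normalize each with a softmax, and concatenate.
            branches = ModelUtils.break_into_branches(
                hidden, self._policy_specs.discrete_action_branches
            )
            branches = [torch.softmax(b, dim=1) for b in branches]
            return torch.cat(branches, dim=1)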