|
|
|
|
|
|
from typing import Optional, Dict, List, Tuple
|
|
|
import numpy as np
from mlagents.torch_utils import torch, default_device

from mlagents.trainers.buffer import AgentBuffer
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.networks import NetworkBody
|
|
|
from mlagents.trainers.torch.layers import linear_layer, Initialization
|
|
|
from mlagents.trainers.settings import GAILSettings, NetworkSettings, EncoderType
|
|
|
from mlagents.trainers.demo_loader import demo_to_buffer


class DiscriminatorNetwork(torch.nn.Module):
    z_size = 128
    EPSILON = 1e-7

    def __init__(self, specs: BehaviorSpec, settings: GAILSettings) -> None:
        super().__init__()
        self._use_vail = settings.use_vail
        self._settings = settings

        encoder_settings = NetworkSettings(
            normalize=False,
            hidden_units=settings.encoding_size,
            num_layers=2,
            vis_encode_type=EncoderType.SIMPLE,
        )
        self._action_flattener = ModelUtils.ActionFlattener(specs)
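        # The input the encoder sees is the set of observations plus, when
        # use_actions is enabled, the flattened action concatenated with the done
        # flag; everything below sizes the network accordingly.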
|
|
|
|
|
|
|
|
|
|
|
        unencoded_size = (
            self._action_flattener.flattened_size + 1 if settings.use_actions else 0
        )  # +1 is for dones
        self.encoder = NetworkBody(
            specs.observation_shapes, encoder_settings, unencoded_size
        )
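        # Illustrative sizing (hypothetical numbers, not part of the original
        # file): with encoding_size=64, a continuous action of size 2, and
        # use_actions=True, unencoded_size is 2 + 1 = 3, and NetworkBody encodes
        # the observations together with those 3 extra inputs into a 64-dim
        # hidden vector.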
|
|
|
|
|
|
|
        estimator_input_size = settings.encoding_size
        # VAIL (https://arxiv.org/abs/1810.00821) feeds the encoding through a
        # learned Gaussian code before the estimator.
        if settings.use_vail:
            estimator_input_size = self.z_size
            self._z_sigma = torch.nn.Parameter(
                torch.ones((self.z_size), dtype=torch.float), requires_grad=True
            )
            self._z_mu_layer = linear_layer(
                settings.encoding_size,
                self.z_size,
                kernel_init=Initialization.KaimingHeNormal,
                kernel_gain=0.1,
            )
        self._estimator = torch.nn.Sequential(
            linear_layer(estimator_input_size, 1), torch.nn.Sigmoid()
        )

    def get_action_input(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Creates the action Tensor. In the continuous case, corresponds to the
        action. In the discrete case, corresponds to the concatenation of one-hot
        action Tensors.
        """
        return self._action_flattener.forward(
            torch.as_tensor(mini_batch["actions"], dtype=torch.float)
        )
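        # Worked example (hypothetical action spec, not from the original file):
        # with discrete branches of sizes (3, 2), the action [2, 0] flattens to
        # the concatenated one-hot vector [0, 0, 1, 1, 0]; continuous actions
        # pass through unchanged.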
|
|
|
|
|
|
|
    def get_state_inputs(
        self, mini_batch: AgentBuffer
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """
        Creates the observation inputs.
        """
        n_vis = len(self.encoder.visual_processors)
        n_vec = len(self.encoder.vector_processors)
        vec_inputs = (
            [ModelUtils.list_to_tensor(mini_batch["vector_obs"], dtype=torch.float)]
            if n_vec > 0
            else []
        )
        vis_inputs = [
            ModelUtils.list_to_tensor(mini_batch["visual_obs%d" % i], dtype=torch.float)
            for i in range(n_vis)
        ]
        return vec_inputs, vis_inputs
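        # Note: inputs are returned as lists because NetworkBody expects one
        # tensor per observation branch: a single vector-observation tensor
        # (when present) and one tensor per visual observation.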
|
|
|
|
|
|
|
    def compute_estimate(
        self, mini_batch: AgentBuffer, use_vail_noise: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Given a mini_batch, computes the estimate (how much the discriminator
        believes the data was sampled from the demonstration data).
        :param mini_batch: The AgentBuffer of data
        :param use_vail_noise: Only when using VAIL: if True, samples the code;
        if False, returns the mean of the code.
        """
        vec_inputs, vis_inputs = self.get_state_inputs(mini_batch)
        if self._settings.use_actions:
            actions = self.get_action_input(mini_batch)
            dones = torch.as_tensor(mini_batch["done"], dtype=torch.float).unsqueeze(1)
            action_inputs = torch.cat([actions, dones], dim=1)
            hidden, _ = self.encoder(vec_inputs, vis_inputs, action_inputs)
        else:
            hidden, _ = self.encoder(vec_inputs, vis_inputs)
        z_mu: Optional[torch.Tensor] = None
        if self._settings.use_vail:
            z_mu = self._z_mu_layer(hidden)
            hidden = torch.normal(z_mu, self._z_sigma * use_vail_noise)
        estimate = self._estimator(hidden)
        return estimate, z_mu
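    # Sketch of how the estimate turns into a reward (a hedged illustration of
    # the surrounding reward provider's evaluate step; `discriminator` stands
    # for an instance of this class):
    #
    #     estimates, _ = discriminator.compute_estimate(mini_batch)
    #     reward = -torch.log(
    #         1.0 - estimates.squeeze(dim=1) * (1.0 - discriminator.EPSILON)
    #     )
    #
    # High estimates ("looks like expert data") yield high reward, so the policy
    # is driven toward expert-like behavior.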
|
|
|
|
|
|
    def compute_gradient_magnitude(
        self, policy_batch: AgentBuffer, expert_batch: AgentBuffer
    ) -> torch.Tensor:
        """
        Gradient penalty from https://arxiv.org/pdf/1704.00028. Adds stability,
        especially for off-policy training. Computes gradients w.r.t. randomly
        interpolated input.
        """
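        # Per batch element this computes (||grad_xhat D(xhat)||_2 - 1)^2 with
        # xhat = eps * x_policy + (1 - eps) * x_expert, eps ~ U(0, 1); the
        # interpolation is done separately for each input branch (vector obs,
        # each visual obs, and optionally actions/dones).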
|
|
|
        policy_vec_inputs, policy_vis_inputs = self.get_state_inputs(policy_batch)
        expert_vec_inputs, expert_vis_inputs = self.get_state_inputs(expert_batch)
        interp_vec_inputs = []
        for policy_vec_input, expert_vec_input in zip(
            policy_vec_inputs, expert_vec_inputs
        ):
            obs_epsilon = torch.rand(policy_vec_input.shape)
            interp_vec_input = (
                obs_epsilon * policy_vec_input + (1 - obs_epsilon) * expert_vec_input
            )
            interp_vec_input.requires_grad = True  # For gradient calculation
            interp_vec_inputs.append(interp_vec_input)
        interp_vis_inputs = []
        for policy_vis_input, expert_vis_input in zip(
            policy_vis_inputs, expert_vis_inputs
        ):
            obs_epsilon = torch.rand(policy_vis_input.shape)
            interp_vis_input = (
                obs_epsilon * policy_vis_input + (1 - obs_epsilon) * expert_vis_input
            )
            interp_vis_input.requires_grad = True  # For gradient calculation
            interp_vis_inputs.append(interp_vis_input)
        if self._settings.use_actions:
            policy_action = self.get_action_input(policy_batch)
            expert_action = self.get_action_input(expert_batch)
|
|
|
|
|
|
expert_batch["done"], dtype=torch.float |
|
|
|
).unsqueeze(1) |
|
|
|
dones_epsilon = torch.rand(policy_dones.shape) |
|
|
|
            action_inputs = torch.cat(
                [
                    action_epsilon * policy_action
                    + (1 - action_epsilon) * expert_action,
                    dones_epsilon * policy_dones + (1 - dones_epsilon) * expert_dones,
                ],
                dim=1,
            )
|
|
|
            action_inputs.requires_grad = True  # For gradient calculation
            hidden, _ = self.encoder(
                interp_vec_inputs, interp_vis_inputs, action_inputs
            )
            encoder_input = tuple(
                interp_vec_inputs + interp_vis_inputs + [action_inputs]
            )
        else:
            hidden, _ = self.encoder(interp_vec_inputs, interp_vis_inputs)
            encoder_input = tuple(interp_vec_inputs + interp_vis_inputs)
|
|
|
|
|
|
|
        estimate = self._estimator(hidden).squeeze(1).sum()
        gradient = torch.autograd.grad(estimate, encoder_input, create_graph=True)[0]
        # Norm's gradient could be NaN at 0. Use our own safe_norm
        safe_norm = (torch.sum(gradient ** 2, dim=1) + self.EPSILON).sqrt()
        gradient_mag = torch.mean((safe_norm - 1) ** 2)
        return gradient_mag
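    # Sketch of how this magnitude typically enters the discriminator loss (a
    # hedged illustration; `gradient_penalty_weight` is assumed to be a constant
    # such as the 10.0 suggested in the WGAN-GP paper):
    #
    #     discriminator_loss += gradient_penalty_weight * self.compute_gradient_magnitude(
    #         policy_batch, expert_batch
    #     )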