import numpy as np
from typing import Dict, NamedTuple

from mlagents.torch_utils import torch, default_device

from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.torch.components.reward_providers.base_reward_provider import (
    BaseRewardProvider,
)
from mlagents.trainers.torch.agent_action import AgentAction
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.networks import NetworkBody
from mlagents.trainers.torch.layers import LinearEncoder, linear_layer
from mlagents.trainers.settings import CuriositySettings, NetworkSettings, EncoderType


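# Prediction of the inverse model for each part of a (possibly hybrid) action:
# raw values for the continuous part, per-branch probabilities for the discrete part.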
class ActionPredictionTuple(NamedTuple):
    continuous: torch.Tensor
    discrete: torch.Tensor


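# ICM-style curiosity: an inverse model predicts the action taken between two
# consecutive state embeddings, while a forward model predicts the next state
# embedding from the current embedding and the action. The forward model's
# prediction error provides the intrinsic reward.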
class CuriosityRewardProvider(BaseRewardProvider):
    """
    Wraps and trains the CuriosityNetwork defined below and exposes its
    prediction error as an intrinsic reward signal.
    """

    def __init__(self, specs: BehaviorSpec, settings: CuriositySettings) -> None:
        super().__init__(specs, settings)
        self._ignore_done = True
        self._network = CuriosityNetwork(specs, settings)
        self._network.to(default_device())
        self.optimizer = torch.optim.Adam(
            self._network.parameters(), lr=settings.learning_rate
        )


class CuriosityNetwork(torch.nn.Module):
    EPSILON = 1e-10

    def __init__(self, specs: BehaviorSpec, settings: CuriositySettings) -> None:
        super().__init__()
        self._action_spec = specs.action_spec
        # State encoder: maps observations to an embedding of size
        # settings.encoding_size, shared by the inverse and forward models.
        state_encoder_settings = NetworkSettings(
            normalize=False,
            hidden_units=settings.encoding_size,
            num_layers=2,
            vis_encode_type=EncoderType.SIMPLE,
            memory=None,
        )
        self._state_encoder = NetworkBody(
            specs.observation_shapes, state_encoder_settings
        )

        self._action_flattener = ModelUtils.ActionFlattener(self._action_spec)

        # Inverse model: encodes the concatenation of the current and next
        # state embeddings; the heads below predict the action that was taken.
        self.inverse_model_action_encoding = torch.nn.Sequential(
            LinearEncoder(2 * settings.encoding_size, 1, 256)
        )

        if self._action_spec.continuous_size > 0:
            self.continuous_action_prediction = linear_layer(
                256, self._action_spec.continuous_size
            )
        if self._action_spec.discrete_size > 0:
            self.discrete_action_prediction = linear_layer(
                256, sum(self._action_spec.discrete_branches)
            )

        # Forward model: predicts the next state embedding from the current
        # state embedding and the flattened action.
        self.forward_model_next_state_prediction = torch.nn.Sequential(
            LinearEncoder(
                settings.encoding_size + self._action_flattener.flattened_size, 1, 256
            ),
            linear_layer(256, settings.encoding_size),
        )

    def get_current_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Extracts the current state embedding from a mini_batch.
        """
        # NOTE: the observation buffer keys used here and in get_next_state are
        # assumed to follow the AgentBuffer conventions of the torch trainers.
        n_vis = len(self._state_encoder.visual_processors)
        hidden, _ = self._state_encoder.forward(
            vec_inputs=[
                ModelUtils.list_to_tensor(mini_batch["vector_obs"], dtype=torch.float)
            ],
            vis_inputs=[
                ModelUtils.list_to_tensor(
                    mini_batch["visual_obs%d" % i], dtype=torch.float
                )
                for i in range(n_vis)
            ],
        )
        return hidden

    def get_next_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Extracts the next state embedding from a mini_batch.
        """
        n_vis = len(self._state_encoder.visual_processors)
        hidden, _ = self._state_encoder.forward(
            vec_inputs=[
                ModelUtils.list_to_tensor(
                    mini_batch["next_vector_in"], dtype=torch.float
                )
            ],
            vis_inputs=[
                ModelUtils.list_to_tensor(
                    mini_batch["next_visual_obs%d" % i], dtype=torch.float
                )
                for i in range(n_vis)
            ],
        )
        return hidden

    def predict_action(self, mini_batch: AgentBuffer) -> ActionPredictionTuple:
        """
        In the continuous case, returns the predicted action.
        In the discrete case, returns the per-branch action probabilities.
        """
        inverse_model_input = torch.cat(
            (self.get_current_state(mini_batch), self.get_next_state(mini_batch)),
            dim=1,
        )

        continuous_pred = None
        discrete_pred = None
        hidden = self.inverse_model_action_encoding(inverse_model_input)
        if self._action_spec.continuous_size > 0:
            continuous_pred = self.continuous_action_prediction(hidden)
        if self._action_spec.discrete_size > 0:
            raw_discrete_pred = self.discrete_action_prediction(hidden)
            branches = ModelUtils.break_into_branches(
                raw_discrete_pred, self._action_spec.discrete_branches
            )
            branches = [torch.softmax(b, dim=1) for b in branches]
            discrete_pred = torch.cat(branches, dim=1)
        return ActionPredictionTuple(continuous_pred, discrete_pred)

    def predict_next_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Uses the current state embedding and the action of the mini_batch to
        predict the next state embedding.
        """
        actions = AgentAction.from_dict(mini_batch)
        flattened_action = self._action_flattener.forward(actions)
        forward_model_input = torch.cat(
            (self.get_current_state(mini_batch), flattened_action), dim=1
        )

        return self.forward_model_next_state_prediction(forward_model_input)

    def compute_inverse_loss(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Computes the inverse loss for a mini_batch. Corresponds to the error on the
        action prediction (given the current and next state).
        """
        predicted_action = self.predict_action(mini_batch)
        actions = AgentAction.from_dict(mini_batch)
        _inverse_loss = 0
        if self._action_spec.continuous_size > 0:
            sq_difference = (
                actions.continuous_tensor - predicted_action.continuous
            ) ** 2
            sq_difference = torch.sum(sq_difference, dim=1)
            # Average only over the active (unmasked) steps of the mini_batch.
            _inverse_loss += torch.mean(
                ModelUtils.dynamic_partition(
                    sq_difference,
                    ModelUtils.list_to_tensor(mini_batch["masks"], dtype=torch.float),
                    2,
                )[1]
            )
        if self._action_spec.discrete_size > 0:
            true_action = torch.cat(
                ModelUtils.actions_to_onehot(
                    actions.discrete_tensor, self._action_spec.discrete_branches
                ),
                dim=1,
            )
            cross_entropy = torch.sum(
                -torch.log(predicted_action.discrete + self.EPSILON) * true_action,
                dim=1,
            )
            _inverse_loss += torch.mean(
                ModelUtils.dynamic_partition(
                    cross_entropy,
                    ModelUtils.list_to_tensor(mini_batch["masks"], dtype=torch.float),
                    2,
                )[1]
            )
        return _inverse_loss

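    # The curiosity reward itself is the forward model's prediction error: the
    # less predictable the next state embedding, the larger the intrinsic reward.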
    def compute_reward(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Calculates the curiosity reward for the mini_batch. Corresponds to the
        error between the predicted and actual next state.
        """