|
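"""Behavioral cloning (BC) module for the PyTorch trainers.

Trains a policy to match the actions recorded in a demonstration file,
either as pretraining or inline with reinforcement learning.
"""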
|
|
|
|
|
from typing import Dict, Any |
|
|
|
import numpy as np |
|
|
|
import torch |
|
|
|
from mlagents.trainers.policy.torch_policy import TorchPolicy
|
|
|
from mlagents.trainers.demo_loader import demo_to_buffer |
|
|
|
from mlagents.trainers.settings import BehavioralCloningSettings |
|
|
|
from mlagents.trainers.torch.utils import ModelUtils |
|
|
|
|
|
|
class BCModule:
    def __init__(
        self,
        policy: TorchPolicy,
        settings: BehavioralCloningSettings,
        policy_learning_rate: float,
        default_batch_size: int,
        default_num_epoch: int,
    ):
        """
        A BC trainer that can be used inline with RL.

        :param policy: Policy of the learning model.
        :param settings: Settings for BehavioralCloning, including strength,
            batch_size, num_epoch, samples_per_update and the demo path.
        :param policy_learning_rate: The initial learning rate of the policy,
            used to set the learning rate of the BC module.
        :param default_batch_size: Batch size to use if settings.batch_size isn't provided.
        :param default_num_epoch: Number of epochs to use if settings.num_epoch isn't provided.
        """
        self.policy = policy
        # The BC learning rate is the policy learning rate scaled by the
        # configured strength of the behavioral-cloning signal.
        self.current_lr = policy_learning_rate * settings.strength
|
|
|
        params = list(self.policy.actor_critic.parameters())
        self.optimizer = torch.optim.Adam(params, lr=self.current_lr)
|
|
|
|
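        # Load the demonstration file into a buffer of experiences; the
        # policy's behavior spec is used to check that the demo is compatible.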
|
|
|
        _, self.demonstration_buffer = demo_to_buffer(
            settings.demo_path, policy.sequence_length, policy.behavior_spec
        )

        self.batch_size = (
            settings.batch_size if settings.batch_size else default_batch_size
        )
        self.num_epoch = settings.num_epoch if settings.num_epoch else default_num_epoch
        # Number of sequences per mini-batch; always at least one, even when
        # the buffer holds fewer experiences than a full batch.
        self.n_sequences = max(
            min(self.batch_size, self.demonstration_buffer.num_experiences)
            // policy.sequence_length,
            1,
        )

        self.has_updated = False
        self.use_recurrent = self.policy.use_recurrent
        self.samples_per_update = settings.samples_per_update
|
|
|
|
|
|
|
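    # Typical usage (a sketch; the surrounding trainer normally constructs
    # this module, and the demo path shown here is hypothetical):
    #
    #   bc_settings = BehavioralCloningSettings(demo_path="Demos/expert.demo")
    #   bc_module = BCModule(
    #       policy,
    #       bc_settings,
    #       policy_learning_rate=3e-4,
    #       default_batch_size=512,
    #       default_num_epoch=3,
    #   )
    #   stats = bc_module.update()  # {"Losses/Pretraining Loss": ...}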
|
|
|
    def update(self) -> Dict[str, Any]:
        """
        Updates the policy on mini-batches drawn from the demonstration buffer.
        :return: A dict of stats including the mean BC loss over the update.
        """
        batch_losses = []
        possible_demo_batches = (
            self.demonstration_buffer.num_experiences // self.n_sequences
        )
        # samples_per_update == 0 means "use every possible batch".
        max_batches = self.samples_per_update // self.n_sequences
        for _ in range(self.num_epoch):
            self.demonstration_buffer.shuffle(
                sequence_length=self.policy.sequence_length
            )
            if max_batches == 0:
                num_batches = possible_demo_batches
            else:
                num_batches = min(possible_demo_batches, max_batches)
            for i in range(num_batches // self.policy.sequence_length):
                demo_update_buffer = self.demonstration_buffer
                start = i * self.n_sequences * self.policy.sequence_length
                end = (i + 1) * self.n_sequences * self.policy.sequence_length
                mini_batch_demo = demo_update_buffer.make_mini_batch(start, end)
                run_out = self._update_batch(mini_batch_demo, self.n_sequences)
                loss = run_out["loss"]
                batch_losses.append(loss)

        self.has_updated = True
        update_stats = {"Losses/Pretraining Loss": np.mean(batch_losses)}
        return update_stats
|
|
|
    def _behavioral_cloning_loss(self, selected_actions, log_probs, expert_actions):
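        """
        The BC loss: mean squared error between sampled and expert actions for
        continuous control, and per-branch cross-entropy against the one-hot
        expert actions for discrete control.
        """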
|
|
|
        if self.policy.use_continuous_act:
            bc_loss = torch.nn.functional.mse_loss(selected_actions, expert_actions)
|
|
|
        else:
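            # Cross-entropy per discrete action branch, against the one-hot
            # expert actions:
            #   loss = mean_b sum_a -log_softmax(logits_b)[a] * expert_onehot_b[a]
            # log_softmax is numerically stable, so no epsilon term is needed.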
|
|
|
            log_prob_branches = ModelUtils.break_into_branches(
                log_probs, self.policy.act_size
            )
            bc_loss = torch.mean(
                torch.stack(
                    [
                        torch.sum(
                            -torch.nn.functional.log_softmax(log_prob_branch, dim=1)
                            * expert_actions_branch,
                            dim=1,
                        )
                        for log_prob_branch, expert_actions_branch in zip(
                            log_prob_branches, expert_actions
                        )
                    ]
                )
            )
|
|
|
        return bc_loss
|
|
|
|
|
|
|
    def _update_batch(
        self, mini_batch_demo: Dict[str, Any], n_sequences: int
    ) -> Dict[str, Any]:
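        """
        Runs one gradient step of the BC loss on a single demonstration
        mini-batch and returns the (detached) loss.
        """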
|
|
|
|
|
|
        vec_obs = [ModelUtils.list_to_tensor(mini_batch_demo["vector_obs"])]
        act_masks = None
        if self.policy.use_continuous_act:
            expert_actions = ModelUtils.list_to_tensor(mini_batch_demo["actions"])
        else:
            # Convert the expert actions to one-hot per branch, and mark every
            # discrete action as available while sampling.
            raw_expert_actions = ModelUtils.list_to_tensor(
                mini_batch_demo["actions"], dtype=torch.long
            )
            expert_actions = ModelUtils.actions_to_onehot(
                raw_expert_actions, self.policy.act_size
            )
            act_masks = ModelUtils.list_to_tensor(
                np.ones(
                    (
                        self.n_sequences * self.policy.sequence_length,
                        sum(self.policy.behavior_spec.discrete_action_branches),
                    ),
                    dtype=np.float32,
                )
            )
|
|
|
        memories = []
        if self.policy.use_recurrent:
            # Take the stored memory at the start of each sequence.
            memories = [
                ModelUtils.list_to_tensor(mini_batch_demo["memory"][i])
                for i in range(
                    0, len(mini_batch_demo["memory"]), self.policy.sequence_length
                )
            ]
            if len(memories) > 0:
                memories = torch.stack(memories).unsqueeze(0)
|
|
|
|
|
|
        vis_obs = []
        if self.policy.use_vis_obs:
            for idx, _ in enumerate(
                self.policy.actor_critic.network_body.visual_encoders
            ):
                vis_ob = ModelUtils.list_to_tensor(
                    mini_batch_demo["visual_obs%d" % idx]
                )
                vis_obs.append(vis_ob)
|
|
|
        selected_actions, all_log_probs, _, _, _ = self.policy.sample_actions(
            vec_obs,
            vis_obs,
            masks=act_masks,
            memories=memories,
            seq_len=self.policy.sequence_length,
            all_log_probs=True,
        )
|
|
|
|
|
|
|
        bc_loss = self._behavioral_cloning_loss(
            selected_actions, all_log_probs, expert_actions
        )
        self.optimizer.zero_grad()
        bc_loss.backward()
        self.optimizer.step()

        run_out = {"loss": bc_loss.detach().cpu().numpy()}
        return run_out
|
|
|
|
|
|
|