def _behavioral_cloning_loss(
    self,
    selected_actions: AgentAction,
    log_probs: ActionLogProbs,
    expert_actions: AgentAction,
) -> torch.Tensor:
    bc_loss = 0
    if self.policy.action_spec.continuous_size > 0:
        # Continuous part: mean squared error between the policy's sampled
        # actions and the expert's actions.
        bc_loss += torch.nn.functional.mse_loss(
            selected_actions.continuous_tensor, expert_actions.continuous_tensor
        )
    if self.policy.action_spec.discrete_size > 0:
        one_hot_expert_actions = ModelUtils.actions_to_onehot(
            expert_actions.discrete_tensor,
            self.policy.action_spec.discrete_branches,
        )
        log_prob_branches = ModelUtils.break_into_branches(
            log_probs.all_discrete_tensor,
            self.policy.action_spec.discrete_branches,
        )
        # Discrete part: per-branch cross-entropy between the expert's one-hot
        # actions and the policy's log probabilities, averaged over branches
        # and batch.
        bc_loss += torch.mean(
            torch.stack(
                [
                    torch.sum(
                        -torch.nn.functional.log_softmax(log_prob_branch, dim=1)
                        * expert_actions_branch,
                        dim=1,
                    )
                    for log_prob_branch, expert_actions_branch in zip(
                        log_prob_branches, one_hot_expert_actions
                    )
                ]
            )
        )
    return bc_loss


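# Illustrative sketch, not part of the module above: a standalone check of the
# per-branch term used for the discrete part of the BC loss. The helper name and
# the tensor shapes are invented for this example; only torch (already used by
# this module) is assumed.
def _discrete_bc_loss_example() -> torch.Tensor:
    # One discrete branch with 3 possible actions and a batch of 4 samples.
    log_prob_branch = torch.randn(4, 3)
    expert_actions_branch = torch.nn.functional.one_hot(
        torch.tensor([0, 2, 1, 2]), num_classes=3
    ).float()
    # Same expression as in _behavioral_cloning_loss: negative log-likelihood of
    # the expert's action under the policy's normalised log probabilities.
    per_sample = torch.sum(
        -torch.nn.functional.log_softmax(log_prob_branch, dim=1)
        * expert_actions_branch,
        dim=1,
    )
    return per_sample.mean()

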
def _update_batch(
    self, mini_batch_demo: Dict[str, np.ndarray], n_sequences: int
) -> Dict[str, float]:
    """
    Helper function for update_batch.
    """
    vec_obs = [ModelUtils.list_to_tensor(mini_batch_demo["vector_obs"])]
    act_masks = None
    # Expert actions come straight from the demonstration mini-batch; AgentAction
    # bundles the continuous and discrete parts into one structure.
    expert_actions = AgentAction.from_dict(mini_batch_demo)
    if self.policy.action_spec.discrete_size > 0:
        # Demonstrations carry no action masks, so mark every discrete action as
        # available with an all-ones mask of shape (batch, sum of branch sizes).
        act_masks = ModelUtils.list_to_tensor(
            np.ones(
                (
                    n_sequences * self.policy.sequence_length,
                    sum(self.policy.action_spec.discrete_branches),
                ),
                dtype=np.float32,
            )
        )

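    # For illustration only (not in the original module): with two invented
    # branches of sizes 3 and 2 and a batch of 4, the mask built above would be
    # np.ones((4, 5), dtype=np.float32), one column per action in every branch.
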
    if self.policy.use_vis_obs:
        vis_obs = []
        for idx, _ in enumerate(
            self.policy.actor_critic.network_body.visual_processors
        ):
            vis_ob = ModelUtils.list_to_tensor(mini_batch_demo["visual_obs%d" % idx])
            vis_obs.append(vis_ob)
    else:
        vis_obs = []

    selected_actions, log_probs, _, _ = self.policy.sample_actions(
        vec_obs,
        vis_obs,
        masks=act_masks,