import numpy as np
from typing import Dict, List, Mapping, NamedTuple, cast, Tuple, Optional
from contextlib import ExitStack

from mlagents.torch_utils import torch, nn, default_device

from mlagents_envs.base_env import ActionSpec
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.torch.networks import ValueNetwork
from mlagents.trainers.torch.agent_action import AgentAction
from mlagents.trainers.torch.action_log_probs import ActionLogProbs
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.settings import NetworkSettings, TrainerSettings, SACSettings


class TorchSACOptimizer(TorchOptimizer):
    class PolicyValueNetwork(nn.Module):
        def __init__(
            self,
            stream_names: List[str],
            observation_shapes: List[Tuple[int, ...]],
            network_settings: NetworkSettings,
            action_spec: ActionSpec,
        ):
            super().__init__()
            num_value_outs = max(sum(action_spec.discrete_branches), 1)
            num_action_ins = int(action_spec.continuous_size)
            self.q1_network = ValueNetwork(
                stream_names,
                observation_shapes,
                network_settings,
                num_action_ins,
                num_value_outs,
            )
            self.q2_network = ValueNetwork(
                stream_names,
                observation_shapes,
                network_settings,
                num_action_ins,
                num_value_outs,
            )

        def forward(
            self,
            vec_inputs: List[torch.Tensor],
            vis_inputs: List[torch.Tensor],
            actions: Optional[torch.Tensor] = None,
            memories: Optional[torch.Tensor] = None,
            sequence_length: int = 1,
            q1_grad: bool = True,
            q2_grad: bool = True,
        ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:

            return q1_out, q2_out
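
    # Rough intuition for the sizing above, with illustrative numbers rather than
    # anything taken from a real behavior spec: for continuous_size=2 and
    # discrete_branches=(3, 2), each Q network consumes the 2 continuous actions as
    # extra input (num_action_ins=2) and emits 3 + 2 = 5 values per reward stream,
    # one per possible discrete choice (num_value_outs=5). Two independent Q networks
    # are kept so the losses can take min(Q1, Q2), the usual clipped double-Q device
    # for curbing overestimation.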

    class TargetEntropy(NamedTuple):

        discrete: List[float] = []  # One per branch
        continuous: float = 0.0

    class LogEntCoef(nn.Module):
        def __init__(self, discrete, continuous):
            super().__init__()
            self.discrete = discrete
            self.continuous = continuous
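
    # A note on the two small containers above: TargetEntropy is a NamedTuple because
    # the entropy targets are fixed constants, while LogEntCoef subclasses nn.Module so
    # that the discrete and continuous log-coefficients register as torch Parameters
    # and can be handed to an optimizer through .parameters() (see the entropy
    # optimizer created below).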

    def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings):
        super().__init__(policy, trainer_params)
        hyperparameters: SACSettings = cast(SACSettings, trainer_params.hyperparameters)

        self.policy = policy
        policy_network_settings = policy.network_settings

        self.tau = hyperparameters.tau

        self.use_dones_in_backup = {
            name: int(not self.reward_signals[name].ignore_done)
            for name in self.stream_names
        }
        self._action_spec = self.policy.behavior_spec.action_spec

        self.value_network = TorchSACOptimizer.PolicyValueNetwork(
            self.stream_names,
            self.policy.behavior_spec.observation_shapes,
            policy_network_settings,
            self._action_spec,
        )

        self.target_network = ValueNetwork(
            self.stream_names,
            self.policy.behavior_spec.observation_shapes,
            policy_network_settings,
        )
        ModelUtils.soft_update(
            self.policy.actor_critic.critic, self.target_network, 1.0
        )

        # We create one entropy coefficient per action, whether discrete or continuous.
        _disc_log_ent_coef = torch.nn.Parameter(
            torch.log(
                torch.as_tensor(
                    [self.init_entcoef] * len(self._action_spec.discrete_branches)
                )
            ),
            requires_grad=True,
        )
        _cont_log_ent_coef = torch.nn.Parameter(
            torch.log(
                torch.as_tensor([self.init_entcoef] * self._action_spec.continuous_size)
            ),
            requires_grad=True,
        )
        self._log_ent_coef = TorchSACOptimizer.LogEntCoef(
            discrete=_disc_log_ent_coef, continuous=_cont_log_ent_coef
        )
        _cont_target = (
            -1
            * self.continuous_target_entropy_scale
            * np.prod(self._action_spec.continuous_size).astype(np.float32)
        )
        _disc_target = [
            self.discrete_target_entropy_scale * np.log(i).astype(np.float32)
            for i in self._action_spec.discrete_branches
        ]
        self.target_entropy = TorchSACOptimizer.TargetEntropy(
            continuous=_cont_target, discrete=_disc_target
        )
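
        # For intuition, with illustrative scales rather than values read from this
        # file: a discrete scale of 0.2 gives a 3-way branch a target entropy of about
        # 0.2 * ln(3) ~= 0.22 nats, while a continuous scale of 1.0 gives a
        # 2-dimensional continuous action a target of -2.0, the familiar -dim(A)
        # heuristic from SAC.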

        policy_params = list(self.policy.actor_critic.network_body.parameters()) + list(
            self.policy.actor_critic.action_model.parameters()
        )
        value_params = list(self.value_network.parameters()) + list(
            self.policy.actor_critic.critic.parameters()
        )

        self.value_optimizer = torch.optim.Adam(
            value_params, lr=hyperparameters.learning_rate
        )
        self.entropy_optimizer = torch.optim.Adam(
            self._log_ent_coef.parameters(), lr=hyperparameters.learning_rate
        )
        self._move_to_device(default_device())
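
        # The entropy coefficients are stored as log-values and exponentiated where
        # they are used, which keeps the effective coefficients strictly positive
        # while letting Adam take unconstrained steps on the underlying parameters.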

    def sac_value_loss(
        self,
        log_probs: ActionLogProbs,
        values: Dict[str, torch.Tensor],
        q1p_out: Dict[str, torch.Tensor],
        q2p_out: Dict[str, torch.Tensor],
        loss_masks: torch.Tensor,
    ) -> torch.Tensor:
        min_policy_qs = {}
        with torch.no_grad():
            _cont_ent_coef = self._log_ent_coef.continuous.exp()
            _disc_ent_coef = self._log_ent_coef.discrete.exp()
            for name in values.keys():
                if self._action_spec.discrete_size <= 0:
                    min_policy_qs[name] = torch.min(q1p_out[name], q2p_out[name])
                else:
                    disc_action_probs = log_probs.all_discrete_tensor.exp()
                    _branched_q1p = ModelUtils.break_into_branches(
                        q1p_out[name] * disc_action_probs,
                        self._action_spec.discrete_branches,
                    )
                    _branched_q2p = ModelUtils.break_into_branches(
                        q2p_out[name] * disc_action_probs,
                        self._action_spec.discrete_branches,
                    )
                    _q1p_mean = torch.mean(
                        torch.stack(
                            [torch.sum(_br, dim=1, keepdim=True) for _br in _branched_q1p]
                        ),
                        dim=0,
                    )
                    _q2p_mean = torch.mean(
                        torch.stack(
                            [torch.sum(_br, dim=1, keepdim=True) for _br in _branched_q2p]
                        ),
                        dim=0,
                    )

                    min_policy_qs[name] = torch.min(_q1p_mean, _q2p_mean)

        value_losses = []
        if self._action_spec.discrete_size <= 0:
            for name in values.keys():
                with torch.no_grad():
                    v_backup = min_policy_qs[name] - torch.sum(
                        _cont_ent_coef * log_probs.continuous_tensor, dim=1
                    )
                value_loss = 0.5 * ModelUtils.masked_mean(
                    torch.nn.functional.mse_loss(values[name], v_backup), loss_masks
                )
                value_losses.append(value_loss)
        else:
            disc_log_probs = log_probs.all_discrete_tensor
            branched_per_action_ent = ModelUtils.break_into_branches(
                disc_log_probs * disc_log_probs.exp(),
                self._action_spec.discrete_branches,
            )
            # We have to do entropy bonus per action branch
            branched_ent_bonus = torch.stack(
                [
                    torch.sum(_disc_ent_coef[i] * _lp, dim=1, keepdim=True)
                    for i, _lp in enumerate(branched_per_action_ent)
                ]
            )
            for name in values.keys():
                with torch.no_grad():
                    v_backup = min_policy_qs[name] - torch.mean(
                        branched_ent_bonus, axis=0
                    )
                    # Add continuous entropy bonus to minimum Q
                    if self._action_spec.continuous_size > 0:
                        v_backup += torch.sum(
                            _cont_ent_coef * log_probs.continuous_tensor,
                            dim=1,
                            keepdim=True,
                        )
                value_loss = 0.5 * ModelUtils.masked_mean(
                    torch.nn.functional.mse_loss(values[name], v_backup.squeeze()),
                    loss_masks,
                )
                value_losses.append(value_loss)
        value_loss = torch.mean(torch.stack(value_losses))
        if torch.isinf(value_loss).any() or torch.isnan(value_loss).any():
            raise UnityTrainerException("Inf found")
        return value_loss
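
    # Informally, the backup above is the SAC value target: V(s) is regressed toward
    # min(Q1, Q2) evaluated at actions drawn from the current policy, plus the entropy
    # bonus -alpha * log pi(a|s). For discrete branches the expectation over actions is
    # computed exactly from the branch probabilities instead of being sampled, and the
    # per-branch results are averaged into a single estimate per reward stream.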

    def sac_policy_loss(
        self,
        log_probs: ActionLogProbs,
        q1p_outs: Dict[str, torch.Tensor],
        loss_masks: torch.Tensor,
    ) -> torch.Tensor:
        _cont_ent_coef, _disc_ent_coef = (
            self._log_ent_coef.continuous,
            self._log_ent_coef.discrete,
        )
        _cont_ent_coef = _cont_ent_coef.exp()
        _disc_ent_coef = _disc_ent_coef.exp()

        mean_q1 = torch.mean(torch.stack(list(q1p_outs.values())), axis=0)
        batch_policy_loss = 0
        if self._action_spec.discrete_size > 0:
            disc_log_probs = log_probs.all_discrete_tensor
            disc_action_probs = disc_log_probs.exp()
            branched_per_action_ent = ModelUtils.break_into_branches(
                disc_log_probs * disc_action_probs, self._action_spec.discrete_branches
            )
            branched_q_term = ModelUtils.break_into_branches(
                mean_q1 * disc_action_probs, self._action_spec.discrete_branches
            )
            branched_policy_loss = torch.stack(
                [
                    torch.sum(_disc_ent_coef[i] * _lp - _qt, dim=1, keepdim=False)
                    for i, (_lp, _qt) in enumerate(
                        zip(branched_per_action_ent, branched_q_term)
                    )
                ],
                dim=1,
            )
            batch_policy_loss += torch.sum(branched_policy_loss, dim=1)
            all_mean_q1 = torch.sum(disc_action_probs * mean_q1, dim=1)
        else:
            all_mean_q1 = mean_q1
        if self._action_spec.continuous_size > 0:
            cont_log_probs = log_probs.continuous_tensor
            batch_policy_loss += torch.mean(
                _cont_ent_coef * cont_log_probs - all_mean_q1.unsqueeze(1), dim=1
            )
        policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)

        return policy_loss
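
    # Informally, the objective above is the SAC policy loss
    # E[alpha * log pi(a|s) - Q(s, a)]: minimizing it pushes the policy toward high-Q
    # actions while keeping entropy up. The discrete part forms an exact expectation
    # per branch by weighting each action's log-prob and Q-value with its probability;
    # the continuous part uses the sampled action's log-prob against the Q estimate
    # marginalized over the discrete branches (all_mean_q1).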

    def sac_entropy_loss(
        self, log_probs: ActionLogProbs, loss_masks: torch.Tensor
    ) -> torch.Tensor:
        _cont_ent_coef, _disc_ent_coef = (
            self._log_ent_coef.continuous,
            self._log_ent_coef.discrete,
        )
        entropy_loss = 0
        if self._action_spec.discrete_size > 0:
            with torch.no_grad():
                # Break the discrete log probs into one tensor per branch
                disc_log_probs = log_probs.all_discrete_tensor
                branched_per_action_ent = ModelUtils.break_into_branches(
                    disc_log_probs * disc_log_probs.exp(),
                    self._action_spec.discrete_branches,
                )
                target_current_diff_branched = torch.stack(
                    [
                        torch.sum(_lp, axis=1, keepdim=True) + _te
                        for _lp, _te in zip(
                            branched_per_action_ent, self.target_entropy.discrete
                        )
                    ],
                    axis=1,
                )
                target_current_diff = torch.squeeze(
                    target_current_diff_branched, axis=2
                )
            entropy_loss += -1 * ModelUtils.masked_mean(
                torch.mean(_disc_ent_coef * target_current_diff, axis=1), loss_masks
            )
        if self._action_spec.continuous_size > 0:
            with torch.no_grad():
                cont_log_probs = log_probs.continuous_tensor
                target_current_diff = torch.sum(
                    cont_log_probs + self.target_entropy.continuous, dim=1
                )
            # We update all the _cont_ent_coef as one block
            entropy_loss += -1 * ModelUtils.masked_mean(
                torch.mean(_cont_ent_coef) * target_current_diff, loss_masks
            )

        return entropy_loss
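
    # Informally, the coefficient update above minimizes
    #     -E[log_alpha * (log pi(a|s) + target_entropy)]
    # so a coefficient rises when the policy's entropy falls below its target and
    # shrinks when it overshoots. The log-probs sit inside torch.no_grad(), so only the
    # coefficients receive gradients; each discrete branch gets its own coefficient,
    # while the continuous dimensions are updated together as one block.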

    def _condense_q_streams(
        self, q_output: Dict[str, torch.Tensor], discrete_actions: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        condensed_q_output = {}
        onehot_actions = ModelUtils.actions_to_onehot(
            discrete_actions, self._action_spec.discrete_branches
        )
        for key, item in q_output.items():
            branched_q = ModelUtils.break_into_branches(
                item, self._action_spec.discrete_branches
            )
            only_action_qs = torch.stack(
                [
                    torch.sum(_act * _q, dim=1, keepdim=True)
                    for _act, _q in zip(onehot_actions, branched_q)
                ]
            )

            condensed_q_output[key] = torch.mean(only_action_qs, dim=0)
        return condensed_q_output
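
    # For example, with illustrative branch sizes (3, 2): each Q head emits 5 values
    # per observation, multiplying by the one-hot chosen actions and summing within a
    # branch picks out Q(s, a_chosen) for that branch, and the mean over branches
    # condenses everything to a single Q estimate per reward stream for the Q loss.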

    def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:

        value_estimates, _ = self.policy.actor_critic.critic_pass(
            vec_obs, vis_obs, memories, sequence_length=self.policy.sequence_length
        )

        cont_sampled_actions = sampled_actions.continuous_tensor
        cont_actions = actions.continuous_tensor
        q1p_out, q2p_out = self.value_network(
            vec_obs,
            vis_obs,
            cont_sampled_actions,
            memories=q_memories,
            sequence_length=self.policy.sequence_length,
        )
        q1_out, q2_out = self.value_network(
            vec_obs,
            vis_obs,
            cont_actions,
            memories=q_memories,
            sequence_length=self.policy.sequence_length,
        )

        if self._action_spec.discrete_size > 0:
            disc_actions = actions.discrete_tensor
            q1_stream = self._condense_q_streams(q1_out, disc_actions)
            q2_stream = self._condense_q_streams(q2_out, disc_actions)
        else:
            q1_stream = q1_out
            q2_stream = q2_out

        with torch.no_grad():
            target_values, _ = self.target_network(
                next_vec_obs,
                next_vis_obs,
                memories=next_memories,
                sequence_length=self.policy.sequence_length,
            )
        masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
        dones = ModelUtils.list_to_tensor(batch["done"])

        q1_loss, q2_loss = self.sac_q_loss(
            q1_stream, q2_stream, target_values, dones, rewards, masks
        )
        value_loss = self.sac_value_loss(
            log_probs, value_estimates, q1p_out, q2p_out, masks
        )
        policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks)
        entropy_loss = self.sac_entropy_loss(log_probs, masks)

        total_value_loss = q1_loss + q2_loss + value_loss
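
        # In standard SAC fashion, sac_q_loss regresses each Q stream toward
        # r + gamma * (1 - done) * V_target(s'), while the value loss above chases the
        # entropy-regularized minimum of the two policy-action Q estimates. Each loss
        # group would then be stepped by its own Adam optimizer as set up in __init__.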

        update_stats = {
            "Losses/Value Loss": value_loss.item(),
            "Losses/Q1 Loss": q1_loss.item(),
            "Losses/Q2 Loss": q2_loss.item(),
            "Policy/Discrete Entropy Coeff": torch.mean(
                torch.exp(self._log_ent_coef.discrete)
            ).item(),
            "Policy/Continuous Entropy Coeff": torch.mean(
                torch.exp(self._log_ent_coef.continuous)
            ).item(),
            "Policy/Learning Rate": decay_lr,
        }

        return update_stats