|
|
|
|
|
|
from typing import Any, Dict, List

import numpy as np

from mlagents.tf_utils import tf

from mlagents.trainers.sac.network import SACPolicyNetwork, SACTargetNetwork
from mlagents.trainers.models import LearningRateSchedule, EncoderType, ModelUtils
from mlagents.trainers.common.tf_optimizer import TFOptimizer
from mlagents.trainers.tf_policy import TFPolicy
from mlagents.trainers.buffer import AgentBuffer
        )
        # The optimizer's m_size is 3 times the policy's (Q1, Q2, and Value).
        self.m_size = 3 * self.policy.m_size
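        # Illustrative sketch, not code from this file: because Q1, Q2, and the value
        # network each keep their own recurrent state, a concatenated optimizer memory
        # of size 3 * policy.m_size can be split into three equal chunks. The tensor
        # name `memory_in` below is hypothetical.
        #
        #   q1_mem, q2_mem, value_mem = tf.split(memory_in, num_or_size_splits=3, axis=1)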
|
|
|
        self._create_inputs_and_outputs()
        self.learning_rate = ModelUtils.create_learning_rate(
            lr_schedule, lr, self.policy.global_step, int(max_step)
        )
        self._create_losses(
            self.policy_network.q1_heads,
            self.policy_network.q2_heads,
            lr,
            int(max_step),
            stream_names,
            discrete=not self.policy.use_continuous_act,
        )
        self._create_sac_optimizer_ops()
        self.selected_actions = (
            self.policy.selected_actions
        )
"learning_rate": self.learning_rate, |
|
|
|
} |
|
|
|
|
|
|
|
    def _create_inputs_and_outputs(self) -> None:
        """
        Assign the higher-level SACModel's inputs and outputs to those of its policy or
        target network.
        """
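        # A minimal sketch of the kind of aliasing this method performs; the attribute
        # names below are assumptions for illustration, not read from this file:
        #
        #   self.vector_in = self.policy.vector_in
        #   self.visual_in = self.policy.visual_in
        #   self.next_vector_in = self.target_network.vector_in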
|
|
|
|
|
|
|
|
|
|
    def _create_losses(
        self, q1_streams, q2_streams, lr, max_step, stream_names, discrete=False
    ) -> None:
        for name in stream_names:
            if discrete:
                # Expected Q under the current policy, computed branch by branch.
                _branched_mpq1 = self._apply_as_branches(
                    self.policy_network.q1_pheads[name] * discrete_action_probs
                )
                branched_mpq1 = tf.stack(
                    [
                        tf.reduce_sum(_br, axis=1, keep_dims=True)
                        for _br in _branched_mpq1
                    ]
                )
                _q1_p_mean = tf.reduce_mean(branched_mpq1, axis=0)

                _branched_mpq2 = self._apply_as_branches(
                    self.policy_network.q2_pheads[name] * discrete_action_probs
                )
                branched_mpq2 = tf.stack(
                    [
                        tf.reduce_sum(_br, axis=1, keep_dims=True)
                        for _br in _branched_mpq2
                    ]
                )
            if discrete:
                # We need to break up the Q functions by branch, and update them individually.
                branched_q1_stream = self._apply_as_branches(
                    self.policy.action_oh * q1_streams[name]
                )
                branched_q2_stream = self._apply_as_branches(
                    self.policy.action_oh * q2_streams[name]
                )
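                # Worked example (illustrative, not from this file): with act_size == [3, 2]
                # the concatenated one-hot action has width 5, e.g. [0, 1, 0, 1, 0].
                # Splitting it into [0, 1, 0] and [1, 0], multiplying each piece by its
                # branch's Q outputs, and summing within the branch picks out the Q-value
                # of the chosen action separately for every branch.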
|
|
|
|
|
|
|
|
|
|
        self.ent_coef = tf.exp(self.log_ent_coef)
        if discrete:
            # We also have to do a different entropy and target_entropy per branch.
            branched_per_action_ent = self._apply_as_branches(per_action_entropy)
            branched_ent_sums = tf.stack(
                [
                    tf.reduce_sum(_lp, axis=1, keep_dims=True) + _te
                    for _lp, _te in zip(branched_per_action_ent, self.target_entropy)
                ],
                axis=1,
            )
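            # Illustrative sketch, not from this file: per-branch target entropies are
            # commonly set to a fraction of each branch's maximum entropy, e.g.
            #
            #   target_entropy = [0.2 * np.log(branch_size) for branch_size in act_size]
            #
            # The 0.2 scale here is an assumption for the sketch, not a value read from
            # this file.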
|
|
|
|
|
|
            # Same with policy loss, we have to do the loss per branch and average them,
            # so that larger branches don't get more weight.
            # The equivalent KL divergence from Eq 10 of Haarnoja et al. is also pi*log(pi) - Q.
            branched_q_term = self._apply_as_branches(
                discrete_action_probs * self.policy_network.q1_p
            )
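            # For reference, the per-branch objective the comment above refers to is
            #
            #   L_pi(branch b) = E_s[ sum_a pi_b(a|s) * (ent_coef_b * log pi_b(a|s) - Q1_b(s, a)) ]
            #
            # i.e. an exact expectation over each discrete branch, averaged across branches
            # so that branches with more actions do not dominate the loss.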
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        self.entropy = self.policy_network.entropy
    def _apply_as_branches(self, concat_logits: tf.Tensor) -> List[tf.Tensor]:
        """
        Takes in a concatenated set of logits and breaks it up into a list of
        non-concatenated logits, one per action branch.
        """
        # self.act_size holds the number of actions in each discrete branch.
        action_idx = [0] + list(np.cumsum(self.act_size))
        branches_logits = [
            concat_logits[:, action_idx[i] : action_idx[i + 1]]
            for i in range(len(self.act_size))
        ]
        return branches_logits
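    # Example: with self.act_size == [3, 2], a [batch, 5] tensor of concatenated logits
    # is returned as the list [logits[:, 0:3], logits[:, 3:5]].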
|
|
|
|
|
|
|
    def _create_sac_optimizer_ops(self) -> None:
        """
        Creates the Adam optimizers and update ops for SAC, including
        the policy, value, and entropy updates, as well as the target network update.
        """
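        # A minimal sketch of the pattern this method follows (TF1-style graph building);
        # the variable lists and op names below are illustrative assumptions:
        #
        #   policy_optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
        #   value_optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
        #   entropy_optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
        #   self.update_batch_policy = policy_optimizer.minimize(self.policy_loss, var_list=policy_vars)
        #   self.update_batch_value = value_optimizer.minimize(self.total_value_loss, var_list=critic_vars)
        #   self.update_batch_entropy = entropy_optimizer.minimize(self.entropy_loss, var_list=[self.log_ent_coef])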
|
|
|
|
|
|
    def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
        """
        Updates the model using the supplied mini-batch.
        :param batch: Experience mini-batch.
        :param num_sequences: Number of LSTM sequences in the batch.
        :param reward_signal_batches: Minibatches to use for updating the reward signals,
            indexed by name. If none, don't update the reward signals.
        :return: Output from update process.
        """
        feed_dict = self._construct_feed_dict(self.policy, batch, num_sequences)
        stats_needed = self.stats_name_to_update_name
        update_stats: Dict[str, float] = {}
        update_vals = self._execute_model(feed_dict, self.update_dict)
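        # The rest of an update typically maps the fetched values back to reporting
        # names; a sketch, assuming stats_needed maps stat name -> update name:
        #
        #   for stat_name, update_name in stats_needed.items():
        #       update_stats[stat_name] = update_vals[update_name]
        #   return update_stats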
|
|
|
|
|
|
            update_dict.update(self.reward_signals[name].update_dict)
            stats_needed.update(self.reward_signals[name].stats_name_to_update_name)
    def _construct_feed_dict(
        self, policy: TFPolicy, batch: AgentBuffer, num_sequences: int
    ) -> Dict[tf.Tensor, Any]:
        """