        self.create_losses(
            self.policy_network.q1_heads,
            self.policy_network.q2_heads,
            lr,
            int(max_step),
            stream_names,
            discrete=not self.policy.use_continuous_act,
        )
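        # Q1 and Q2 are SAC's twin critics: SAC takes the minimum of the two
        # estimates (clipped double-Q) to curb value overestimation.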
        if self.policy.use_recurrent:
            # Surface the optimizer's recurrent memory output alongside the
            # policy's inference outputs.
            self.policy.inference_dict["optimizer_memory_out"] = self.memory_out
    def create_inputs_and_outputs(self) -> None:
        """
        Assign the higher-level SACModel's inputs and outputs to those of its policy or
        target network.
        """
        self.next_memory_in = self.target_network.memory_in
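        # Note: "next_"-prefixed inputs are aliased from the target network,
        # which evaluates next-step values for the Bellman backup; with a
        # recurrent policy it therefore needs its own memory input.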
    def create_losses(
        self,
        q1_streams: Dict[str, tf.Tensor],
        q2_streams: Dict[str, tf.Tensor],
        lr: tf.Tensor,
        max_step: int,
        stream_names: List[str],
        discrete: bool = False,
    ) -> None:
        """
        Creates training-specific TensorFlow ops for SAC models.
        :param q1_streams: Q1 streams from policy network
        :param q2_streams: Q2 streams from policy network
        :param lr: Learning rate
        :param max_step: Total number of training steps
        :param stream_names: List of reward stream names
        :param discrete: Whether or not the action space is discrete
        """
        self.entropy = self.policy_network.entropy
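        # `entropy` is surfaced for stat reporting. For reference, SAC with a
        # learned temperature typically minimizes, per step (a sketch of the
        # standard formulation, not necessarily this file's exact ops):
        #
        #     entropy_loss = -mean(log_alpha * stop_gradient(log_prob + target_entropy))
        #
        # which raises alpha whenever the policy's entropy drops below target.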
    def apply_as_branches(self, concat_logits: tf.Tensor) -> List[tf.Tensor]:
        """
        Takes in a concatenated set of logits and breaks it up into a list of non-concatenated logits, one per
        action branch.
        """
        # Cumulative branch sizes give the slice boundaries of each branch
        # within the concatenated logits.
        action_idx = [0] + list(np.cumsum(self.act_size))
        branches_logits = [
            concat_logits[:, action_idx[i] : action_idx[i + 1]]
            for i in range(len(self.act_size))
        ]
        return branches_logits
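        # Example: with self.act_size == [3, 2], a [batch, 5] logits tensor is
        # split into a [batch, 3] tensor and a [batch, 2] tensor.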
    def create_sac_optimizers(self) -> None:
        """
        Creates the Adam optimizers and update ops for SAC, including
        the policy, value, and entropy updates, as well as the target network update.
        """
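        # A minimal sketch of the usual TF1 pattern (names such as self.tau,
        # self.learning_rate, and value_vars are assumptions, not taken from
        # this file):
        #
        #     policy_optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
        #     entropy_optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
        #     value_optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
        #
        #     # Soft (Polyak) target update: target <- (1 - tau) * target + tau * source
        #     self.target_update_op = [
        #         tf.assign(target, (1.0 - self.tau) * target + self.tau * source)
        #         for target, source in zip(
        #             self.target_network.value_vars, self.policy_network.value_vars
        #         )
        #     ]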