浏览代码

Add some typing to optimizer

/develop/nopreviousactions
Ervin Teng 5 年前
当前提交
8e300036
共有 2 个文件被更改,包括 16 次插入和 10 次删除
  1. 8
      ml-agents/mlagents/trainers/sac/network.py
  2. 18
      ml-agents/mlagents/trainers/sac/optimizer.py

8
ml-agents/mlagents/trainers/sac/network.py


self.critic_vars = None
self.policy_vars = None
self.q1_heads: Optional[Dict[str, tf.Tensor]] = None
self.q2_heads: Optional[Dict[str, tf.Tensor]] = None
self.q1_pheads: Optional[Dict[str, tf.Tensor]] = None
self.q2_pheads: Optional[Dict[str, tf.Tensor]] = None
self.q1_heads: Dict[str, tf.Tensor] = None
self.q2_heads: Dict[str, tf.Tensor] = None
self.q1_pheads: Dict[str, tf.Tensor] = None
self.q2_pheads: Dict[str, tf.Tensor] = None
self.policy = policy

18
ml-agents/mlagents/trainers/sac/optimizer.py


self.policy_network.q1_heads,
self.policy_network.q2_heads,
lr,
max_step,
int(max_step),
stream_names,
discrete=not self.policy.use_continuous_act,
)

if self.policy.use_recurrent:
self.policy.inference_dict["optimizer_memory_out"] = self.memory_out
def create_inputs_and_outputs(self):
def create_inputs_and_outputs(self) -> None:
"""
Assign the higher-level SACModel's inputs and outputs to those of its policy or
target network.

self.next_memory_in = self.target_network.memory_in
def create_losses(
self, q1_streams, q2_streams, lr, max_step, stream_names, discrete=False
):
self,
q1_streams: Dict[str, tf.Tensor],
q2_streams: Dict[str, tf.Tensor],
lr: tf.Tensor,
max_step: int,
stream_names: List[str],
discrete: bool = False,
) -> None:
"""
Creates training-specific Tensorflow ops for SAC models.
:param q1_streams: Q1 streams from policy network

self.entropy = self.policy_network.entropy
def apply_as_branches(self, concat_logits):
def apply_as_branches(self, concat_logits: tf.Tensor) -> List[tf.Tensor]:
"""
Takes in a concatenated set of logits and breaks it up into a list of non-concatenated logits, one per
action branch

]
return branches_logits
def create_sac_optimizers(self):
def create_sac_optimizers(self) -> None:
"""
Creates the Adam optimizers and update ops for SAC, including
the policy, value, and entropy updates, as well as the target network update.

正在加载...
取消
保存