
Remove and rename tf_optimizer

/develop/nopreviousactions
Ervin Teng, 5 years ago
Commit 1cfc461a
4 changed files with 4 additions and 6 deletions
  1. ml-agents/mlagents/trainers/common/nn_policy.py (1 change)
  2. ml-agents/mlagents/trainers/common/tf_optimizer.py (2 changes)
  3. ml-agents/mlagents/trainers/sac/optimizer.py (6 changes)
  4. ml-agents/mlagents/trainers/tf_policy.py (1 change)

ml-agents/mlagents/trainers/common/nn_policy.py (1 change)


         :param resample: Whether we are using the resampling trick to update the policy in continuous output.
         """
         super().__init__(seed, brain, trainer_params, load)
-        self.tf_optimizer: Optional[tf.train.Optimizer] = None
         self.grads = None
         self.update_batch: Optional[tf.Operation] = None
         num_layers = trainer_params["num_layers"]

ml-agents/mlagents/trainers/common/tf_optimizer.py (2 changes)


         )
         self.update_dict.update(self.reward_signals[reward_signal].update_dict)
-    def create_tf_optimizer(
+    def create_optimizer_op(
         self, learning_rate: float, name: str = "Adam"
     ) -> tf.train.Optimizer:
         return tf.train.AdamOptimizer(learning_rate=learning_rate, name=name)
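
For context: the renamed create_optimizer_op is a thin wrapper over TF 1.x Adam. Below is a minimal standalone sketch of the same pattern, assuming the TF 1.x graph/session API that ml-agents used at the time; the toy variable, loss, and "demo_adam" name are illustrative, not part of this commit.

    import tensorflow as tf  # TF 1.x API, as used by ml-agents here

    def create_optimizer_op(learning_rate, name="Adam"):
        # Same body as the renamed method above, shown standalone.
        return tf.train.AdamOptimizer(learning_rate=learning_rate, name=name)

    # Illustrative usage: minimize a toy quadratic loss with the returned optimizer.
    x = tf.Variable(0.0)
    loss = tf.square(x - 3.0)
    update_op = create_optimizer_op(learning_rate=0.1, name="demo_adam").minimize(loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(200):
            sess.run(update_op)
        print(sess.run(x))  # converges toward 3.0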

ml-agents/mlagents/trainers/sac/optimizer.py (6 changes)


         Creates the Adam optimizers and update ops for SAC, including
         the policy, value, and entropy updates, as well as the target network update.
         """
-        policy_optimizer = self.create_tf_optimizer(
+        policy_optimizer = self.create_optimizer_op(
-        entropy_optimizer = self.create_tf_optimizer(
+        entropy_optimizer = self.create_optimizer_op(
-        value_optimizer = self.create_tf_optimizer(
+        value_optimizer = self.create_optimizer_op(
             learning_rate=self.learning_rate, name="sac_value_opt"
         )
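
The three renamed call sites give SAC one Adam instance per loss. A minimal sketch of that pattern, assuming TF 1.x; the parameters, losses, and learning rate below are stand-ins, not the real SAC graph:

    import tensorflow as tf  # TF 1.x

    # Stand-in parameters and losses (the real ones come from the SAC graph).
    policy_param = tf.Variable(0.5, name="policy_param")
    value_param = tf.Variable(0.5, name="value_param")
    log_ent_coef = tf.Variable(0.0, name="log_ent_coef")

    policy_loss = tf.square(policy_param - 1.0)
    value_loss = tf.square(value_param + 1.0)
    entropy_loss = tf.square(log_ent_coef - 0.2)

    # One optimizer per loss, mirroring the diff: each keeps its own Adam
    # moment estimates and only updates its own variables.
    policy_update = tf.train.AdamOptimizer(3e-4, name="sac_policy_opt").minimize(
        policy_loss, var_list=[policy_param]
    )
    entropy_update = tf.train.AdamOptimizer(3e-4, name="sac_entropy_opt").minimize(
        entropy_loss, var_list=[log_ent_coef]
    )
    value_update = tf.train.AdamOptimizer(3e-4, name="sac_value_opt").minimize(
        value_loss, var_list=[value_param]
    )

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run([policy_update, value_update, entropy_update])  # one step of each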

ml-agents/mlagents/trainers/tf_policy.py (1 change)


         )
         tf.set_random_seed(seed)
         self.saver = None
-        self.tf_optimizer = None
         if self.use_recurrent:
             self.m_size = trainer_parameters["memory_size"]
             self.sequence_length = trainer_parameters["sequence_length"]
