浏览代码

Move learning rate reporting

/develop/nopreviousactions
Ervin Teng 5 年前
当前提交
ff607162
共有 3 个文件被更改,包括 4 次插入 和 8 次删除
  1. 3
      ml-agents/mlagents/trainers/agent_processor.py
  2. 4
      ml-agents/mlagents/trainers/ppo/optimizer.py
  3. 5
      ml-agents/mlagents/trainers/sac/optimizer.py

3
ml-agents/mlagents/trainers/agent_processor.py


if take_action_outputs:
for _entropy in take_action_outputs["entropy"]:
self.stats_reporter.add_stat("Policy/Entropy", _entropy)
self.stats_reporter.add_stat(
"Policy/Learning Rate", take_action_outputs["learning_rate"]
)
terminated_agents: Set[str] = set()
# Make unique agent_ids that are global across workers

4
ml-agents/mlagents/trainers/ppo/optimizer.py


self.stats_name_to_update_name = {
"Losses/Value Loss": "value_loss",
"Losses/Policy Loss": "policy_loss",
"Policy/Learning Rate": "learning_rate",
}
if self.policy.use_recurrent:
self.m_size = self.policy.m_size

"value_loss": self.value_loss,
"policy_loss": self.abs_policy_loss,
"update_batch": self.update_batch,
"learning_rate": self.learning_rate,
# Add some stuff to inference dict from optimizer
self.policy.inference_dict["learning_rate"] = self.learning_rate
self.policy.initialize_or_load()
def create_cc_critic(

5
ml-agents/mlagents/trainers/sac/optimizer.py


"Losses/Q1 Loss": "q1_loss",
"Losses/Q2 Loss": "q2_loss",
"Policy/Entropy Coeff": "entropy_coef",
"Policy/Learning Rate": "learning_rate",
}
self.update_dict = {

"update_batch": self.update_batch_policy,
"update_value": self.update_batch_value,
"update_entropy": self.update_batch_entropy,
"learning_rate": self.learning_rate,
# Add some stuff to inference dict from optimizer
self.policy.inference_dict["learning_rate"] = self.learning_rate
def create_inputs_and_outputs(self) -> None:
"""

正在加载...
取消
保存