
add sac checkpoint

/develop/add-fire/ckpt-2
Ruo-Ping Dong, 4 years ago
Current commit
01e60921
4 files changed, 16 insertions and 3 deletions
  1. ml-agents/mlagents/trainers/policy/torch_policy.py (1 change)
  2. ml-agents/mlagents/trainers/sac/optimizer.py (2 changes)
  3. ml-agents/mlagents/trainers/sac/optimizer_torch.py (9 changes)
  4. ml-agents/mlagents/trainers/sac/trainer.py (7 changes)

ml-agents/mlagents/trainers/policy/torch_policy.py (1 change)


import os
from mlagents.trainers.action_info import ActionInfo
from mlagents.trainers.behavior_id_utils import get_global_agent_id
from mlagents.trainers.policy import Policy
from mlagents_envs.base_env import DecisionSteps, BehaviorSpec
from mlagents_envs.timers import timed

ml-agents/mlagents/trainers/sac/optimizer.py (2 changes)


            [self.policy.update_normalization_op, target_update_norm]
        )
        self.policy.initialize_or_load()

        self.stats_name_to_update_name = {
            "Losses/Value Loss": "value_loss",
            "Losses/Policy Loss": "policy_loss",

ml-agents/mlagents/trainers/sac/optimizer_torch.py (9 changes)


        self, reward_signal_minibatches: Mapping[str, AgentBuffer], num_sequences: int
    ) -> Dict[str, float]:
        return {}

    def get_modules(self):
        return {
            "Optimizer:value_network": self.value_network,
            "Optimizer:target_network": self.target_network,
            "Optimizer:policy_optimizer": self.policy_optimizer,
            "Optimizer:value_optimizer": self.value_optimizer,
            "Optimizer:entropy_optimizer": self.entropy_optimizer,
        }
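The new get_modules method exposes the SAC value/target networks and the three optimizers under string keys so that a checkpoint saver can persist and restore them together with the policy. As a rough sketch of how such a dict can be consumed (a hypothetical helper, not the actual ml-agents saver): every entry, whether an nn.Module or a torch.optim.Optimizer, supports state_dict()/load_state_dict(), so serialization can be uniform.

import os
import torch

def save_registered_modules(modules, checkpoint_dir, step):
    # Hypothetical helper: nn.Module and torch.optim.Optimizer both expose
    # state_dict(), so each entry from get_modules() serializes the same way.
    state = {name: module.state_dict() for name, module in modules.items()}
    path = os.path.join(checkpoint_dir, f"checkpoint-{step}.pt")
    torch.save(state, path)
    return path

def load_registered_modules(modules, path):
    # Restore each registered module in place from a saved checkpoint file.
    state = torch.load(path)
    for name, module in modules.items():
        if name in state:
            module.load_state_dict(state[name])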

ml-agents/mlagents/trainers/sac/trainer.py (7 changes)


        )  # type: ignore
        for _reward_signal in self.optimizer.reward_signals.keys():
            self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
        if self.saver is None:
            self.saver = self.create_saver(policy=policy)
        self.saver.register(self.policy)
        self.saver.register(self.optimizer)
        self.saver.maybe_load()
        # Needed to resume loads properly
        self.step = policy.get_current_step()
        # Assume steps were updated at the correct ratio before
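The trainer now routes both the policy and the SAC optimizer through one saver: each is registered, maybe_load restores an existing checkpoint if one is found, and the trainer's step counter is re-synced from the loaded policy via get_current_step(). Below is a minimal illustrative saver that follows this register/maybe_load contract, assuming registered objects expose a get_modules()-style dict like the optimizer above; the real ml-agents saver differs in detail.

import os
import torch

class SimpleSaver:
    # Illustrative stand-in, not the actual ml-agents saver implementation.
    def __init__(self, checkpoint_path=None):
        self.checkpoint_path = checkpoint_path
        self.registered = {}

    def register(self, obj):
        # Collect the networks/optimizers of any object exposing get_modules().
        self.registered.update(obj.get_modules())

    def maybe_load(self):
        # Only restore state if a checkpoint path was supplied and exists on disk.
        if self.checkpoint_path is None or not os.path.isfile(self.checkpoint_path):
            return
        state = torch.load(self.checkpoint_path)
        for name, module in self.registered.items():
            if name in state:
                module.load_state_dict(state[name])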
