浏览代码

Fix save snapshot bug in ghost trainer (#3722)

/develop/add-fire
GitHub 4 年前
当前提交
9c8142c2
共有 1 个文件被更改,包括 11 次插入1 次删除
  1. 12
      ml-agents/mlagents/trainers/ghost/trainer.py

12
ml-agents/mlagents/trainers/ghost/trainer.py


def save_model(self, name_behavior_id: str) -> None:
"""
Forwarding call to wrapped trainers save_model
Loads the latest policy weights, saves it, then reloads
the current policy weights before resuming training.
parsed_behavior_id = self._name_to_parsed_behavior_id[name_behavior_id]
brain_name = parsed_behavior_id.brain_name
policy = self.trainer.get_policy(brain_name)
reload_weights = policy.get_weights()
# save current snapshot to policy
policy.load_weights(self.current_policy_snapshot[brain_name])
# reload
policy.load_weights(reload_weights)
First loads the current snapshot.
First loads the latest snapshot.
"""
parsed_behavior_id = self._name_to_parsed_behavior_id[name_behavior_id]
brain_name = parsed_behavior_id.brain_name

正在加载...
取消
保存