浏览代码

return copy of state_dict

/develop/add-fire/ghost
Andrew Cohen 4 年前
当前提交
effdec13
共有 1 个文件被更改,包括 2 次插入1 次删除
  1. 3
      ml-agents/mlagents/trainers/policy/torch_policy.py

3
ml-agents/mlagents/trainers/policy/torch_policy.py


from typing import Any, Dict, List
import numpy as np
import torch
import copy
from mlagents.trainers.action_info import ActionInfo
from mlagents.trainers.behavior_id_utils import get_global_agent_id

pass
def get_weights(self) -> List[np.ndarray]:
return self.actor_critic.state_dict()
return copy.deepcopy(self.actor_critic.state_dict())
def get_modules(self):
return {"Policy": self.actor_critic, "global_step": self.global_step}
正在加载...
取消
保存