浏览代码

fixed order of load weight/create tf graph in add_policy

/internal-policy-ghost
Andrew Cohen 4 年前
当前提交
b9179f0f
共有 1 个文件被更改,包括 3 次插入3 次删除
  1. 6
      ml-agents/mlagents/trainers/ghost/trainer.py

6
ml-agents/mlagents/trainers/ghost/trainer.py


team_id = parsed_behavior_id.team_id
self.controller.subscribe_team_id(team_id, self)
policy = self.create_policy(brain_parameters)
policy.init_load_weights()
policy.init_load_weights()
self.policies[name_behavior_id] = policy
self._name_to_parsed_behavior_id[name_behavior_id] = parsed_behavior_id

internal_trainer_policy = self.trainer.get_policy(
parsed_behavior_id.brain_name
)
internal_trainer_policy.init_load_weights()
internal_trainer_policy.init_load_weights()
self.current_policy_snapshot[
parsed_behavior_id.brain_name

policy.load(internal_trainer_policy.get_weights())
policy.load_weights(internal_trainer_policy.get_weights())
self._save_snapshot() # Need to save after trainer initializes policy
self._learning_team = self.controller.get_learning_team

正在加载...
取消
保存