|
|
|
|
|
|
def save_model(self, name_behavior_id: str) -> None: |
|
|
|
""" |
|
|
|
Forwarding call to wrapped trainers save_model |
|
|
|
Loads the latest policy weights, saves it, then reloads |
|
|
|
the current policy weights before resuming training. |
|
|
|
policy = self.trainer.get_policy(brain_name) |
|
|
|
reload_weights = policy.get_weights() |
|
|
|
# save current snapshot to policy |
|
|
|
policy.load_weights(self.current_policy_snapshot[brain_name]) |
|
|
|
self.trainer.save_model(name_behavior_id) |
|
|
|
# reload |
|
|
|
policy.load_weights(reload_weights) |
|
|
|
self.trainer.save_model(brain_name) |
|
|
|
First loads the latest snapshot. |
|
|
|
policy = self.trainer.get_policy(brain_name) |
|
|
|
policy.load_weights(self.current_policy_snapshot[brain_name]) |
|
|
|
def create_policy(self, brain_parameters: BrainParameters) -> TFPolicy: |
|
|
|
def create_policy( |
|
|
|
self, parsed_behavior_id: BehaviorIdentifiers, brain_parameters: BrainParameters |
|
|
|
) -> TFPolicy: |
|
|
|
""" |
|
|
|
return self.trainer.create_policy(brain_parameters) |
|
|
|
|
|
|
|
def add_policy( |
|
|
|
self, parsed_behavior_id: BehaviorIdentifiers, policy: TFPolicy |
|
|
|
) -> None: |
|
|
|
""" |
|
|
|
Adds policy to trainer. The first policy encountered sets the wrapped |
|
|
|
The first policy encountered sets the wrapped |
|
|
|
:param name_behavior_id: Behavior ID that the policy should belong to. |
|
|
|
:param policy: Policy to associate with name_behavior_id. |
|
|
|
name_behavior_id = parsed_behavior_id.behavior_id |
|
|
|
policy = self.trainer.create_policy(parsed_behavior_id, brain_parameters) |
|
|
|
policy.create_tf_graph() |
|
|
|
policy.init_load_weights() |
|
|
|
self.policies[name_behavior_id] = policy |
|
|
|
policy.create_tf_graph() |
|
|
|
|
|
|
|
self._name_to_parsed_behavior_id[name_behavior_id] = parsed_behavior_id |
|
|
|
# for saving/swapping snapshots |
|
|
|
policy.init_load_weights() |
|
|
|
internal_trainer_policy = self.trainer.create_policy( |
|
|
|
parsed_behavior_id, brain_parameters |
|
|
|
) |
|
|
|
internal_trainer_policy.create_tf_graph() |
|
|
|
internal_trainer_policy.init_load_weights() |
|
|
|
] = policy.get_weights() |
|
|
|
] = internal_trainer_policy.get_weights() |
|
|
|
policy.load_weights(internal_trainer_policy.get_weights()) |
|
|
|
self.trainer.add_policy(parsed_behavior_id, policy) |
|
|
|
self.trainer.add_policy(parsed_behavior_id, internal_trainer_policy) |
|
|
|
return policy |
|
|
|
|
|
|
|
def add_policy(
    self, parsed_behavior_id: BehaviorIdentifiers, policy: TFPolicy
) -> None:
    """
    Adds policy to GhostTrainer.

    Both the parsed identifier and the policy are stored on this trainer
    under the same key, ``parsed_behavior_id.behavior_id``.

    :param parsed_behavior_id: Behavior ID that the policy should belong to.
    :param policy: Policy to associate with name_behavior_id.
    """
    name_behavior_id = parsed_behavior_id.behavior_id
    # Index the parsed identifier by its plain behavior id.
    self._name_to_parsed_behavior_id[name_behavior_id] = parsed_behavior_id
    # Index the policy under that same key.
    self.policies[name_behavior_id] = policy
|
|
|
|
|
|
|
def get_policy(self, name_behavior_id: str) -> TFPolicy: |
|
|
|
""" |
|
|
|