|
|
|
|
|
|
from mlagents.trainers.policy.tf_policy import TFPolicy |
|
|
|
from mlagents.trainers.optimizer.tf_optimizer import TFOptimizer |
|
|
|
from mlagents.trainers import __version__ |
|
|
|
from mlagents.tf_utils.globals import get_rank |
|
|
|
from mlagents.tf_utils import global_values |
|
|
|
|
|
|
|
|
|
|
|
logger = get_logger(__name__) |
|
|
|
|
|
|
self.graph = None |
|
|
|
self.sess = None |
|
|
|
self.tf_saver = None |
|
|
|
self.rank = get_rank() |
|
|
|
self.rank = global_values.get_rank() |
|
|
|
|
|
|
|
def register(self, module: Union[TFPolicy, TFOptimizer]) -> None: |
|
|
|
if isinstance(module, TFPolicy): |
|
|
|
|
|
|
self._load_graph(policy, self.model_path, reset_global_steps=reset_steps) |
|
|
|
else: |
|
|
|
policy.initialize() |
|
|
|
TFPolicy.broadcast_global_variables(0) |
|
|
|
TFPolicy.broadcast_global_variables |
|
|
|
|
|
|
|
def _load_graph( |
|
|
|
self, policy: TFPolicy, model_path: str, reset_global_steps: bool = False |
|
|
|