|
|
|
|
|
|
from mlagents.trainers.policy.tf_policy import TFPolicy |
|
|
|
from mlagents.trainers.optimizer.tf_optimizer import TFOptimizer |
|
|
|
from mlagents.trainers import __version__ |
|
|
|
from mlagents.tf_utils.globals import get_rank |
|
|
|
|
|
|
|
|
|
|
|
logger = get_logger(__name__) |
|
|
|
|
|
|
self.graph = None |
|
|
|
self.sess = None |
|
|
|
self.tf_saver = None |
|
|
|
self.rank = get_rank() |
|
|
|
|
|
|
|
def register(self, module: Union[TFPolicy, TFOptimizer]) -> None: |
|
|
|
if isinstance(module, TFPolicy): |
|
|
|