浏览代码

renaming file from globals.py to global_values.py

/trainer-plugin
Anupam Bhatnagar 4 年前
当前提交
5e8aa485
共有 6 个文件被更改,包括 11 次插入9 次删除
  1. 4
      ml-agents/mlagents/trainers/learn.py
  2. 2
      ml-agents/mlagents/trainers/policy/tf_policy.py
  3. 6
      ml-agents/mlagents/trainers/saver/tf_saver.py
  4. 4
      ml-agents/mlagents/trainers/stats.py
  5. 4
      ml-agents/mlagents/trainers/trainer_controller.py
  6. 0
      /ml-agents/mlagents/tf_utils/global_values.py

4
ml-agents/mlagents/trainers/learn.py


logger.info(f"The following plugins are available {discovered_plugins}")
new_initializers = set(get_all_subclasses(Initializer))
if len(new_initializers) == 1:
if len(new_initializers) <= 0:
return
elif len(new_initializers) == 1:
# load the initializer
logger.info("Registering new initializer")
distributed_init = list(new_initializers)[0]()

2
ml-agents/mlagents/trainers/policy/tf_policy.py


GaussianDistribution,
MultiCategoricalDistribution,
)
from mlagents.tf_utils.globals import get_rank
from mlagents.tf_utils.global_values import get_rank
logger = get_logger(__name__)

6
ml-agents/mlagents/trainers/saver/tf_saver.py


from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.optimizer.tf_optimizer import TFOptimizer
from mlagents.trainers import __version__
from mlagents.tf_utils.globals import get_rank
from mlagents.tf_utils import global_values
logger = get_logger(__name__)

self.graph = None
self.sess = None
self.tf_saver = None
self.rank = get_rank()
self.rank = global_values.get_rank()
def register(self, module: Union[TFPolicy, TFOptimizer]) -> None:
if isinstance(module, TFPolicy):

self._load_graph(policy, self.model_path, reset_global_steps=reset_steps)
else:
policy.initialize()
TFPolicy.broadcast_global_variables(0)
TFPolicy.broadcast_global_variables
def _load_graph(
self, policy: TFPolicy, model_path: str, reset_global_steps: bool = False

4
ml-agents/mlagents/trainers/stats.py


from mlagents_envs.logging_util import get_logger
from mlagents_envs.timers import set_gauge
from mlagents.tf_utils import tf, generate_session_config
from mlagents.tf_utils.globals import get_rank
from mlagents.tf_utils import global_values
logger = get_logger(__name__)

# If self-play, we want to print ELO as well as reward
self.self_play = False
self.self_play_team = -1
self.rank = get_rank()
self.rank = global_values.get_rank()
def write_stats(
self, category: str, values: Dict[str, StatsSummary], step: int

4
ml-agents/mlagents/trainers/trainer_controller.py


from mlagents.trainers.trainer_util import TrainerFactory
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.agent_processor import AgentManager
from mlagents.tf_utils.globals import get_rank
from mlagents.tf_utils import global_values
class TrainerController:

self.kill_trainers = False
np.random.seed(training_seed)
tf.set_random_seed(training_seed)
self.rank = get_rank()
self.rank = global_values.get_rank()
@timed
def _save_models(self):

/ml-agents/mlagents/tf_utils/globals.py → /ml-agents/mlagents/tf_utils/global_values.py

正在加载...
取消
保存