浏览代码

[skip ci] change self.rank to global_values.get_rank()

/trainer-plugin
Anupam Bhatnagar 5 年前
当前提交
1f60979f
共有 4 个文件被更改,包括 9 次插入10 次删除
  1. 2
      ml-agents/mlagents/trainers/policy/tf_policy.py
  2. 7
      ml-agents/mlagents/trainers/saver/tf_saver.py
  3. 5
      ml-agents/mlagents/trainers/stats.py
  4. 5
      ml-agents/mlagents/trainers/trainer_controller.py

2
ml-agents/mlagents/trainers/policy/tf_policy.py


GaussianDistribution,
MultiCategoricalDistribution,
)
from mlagents.tf_utils import global_values
logger = get_logger(__name__)

self.grads = None
self.update_batch: Optional[tf.Operation] = None
self.trainable_variables: List[tf.Variable] = []
self.rank = global_values.get_rank()
if create_tf_graph:
self.create_tf_graph()

7
ml-agents/mlagents/trainers/saver/tf_saver.py


self.graph = None
self.sess = None
self.tf_saver = None
self.rank = global_values.get_rank()
def register(self, module: Union[TFPolicy, TFOptimizer]) -> None:
if isinstance(module, TFPolicy):

def export(self, output_filepath: str, brain_name: str) -> None:
# save model if there is only one worker or
# only on worker-0 if there are multiple workers
if self.policy and self.rank is not None and self.rank != 0:
if (
self.policy
and global_values.get_rank() is not None
and global_values.get_rank() != 0
):
return
export_policy_model(
self.model_path, output_filepath, brain_name, self.graph, self.sess

5
ml-agents/mlagents/trainers/stats.py


# If self-play, we want to print ELO as well as reward
self.self_play = False
self.self_play_team = -1
self.rank = global_values.get_rank()
def write_stats(
self, category: str, values: Dict[str, StatsSummary], step: int

log_info.append(f"Time Elapsed: {elapsed_time:0.3f} s")
if "Environment/Cumulative Reward" in values:
stats_summary = values["Environment/Cumulative Reward"]
if self.rank is not None:
log_info.append(f"Rank: {self.rank}")
if global_values.get_rank() is not None:
log_info.append(f"Rank: {global_values.get_rank()}")
log_info.append(f"Mean Reward: {stats_summary.mean:0.3f}")
log_info.append(f"Std of Reward: {stats_summary.std:0.3f}")

5
ml-agents/mlagents/trainers/trainer_controller.py


self.kill_trainers = False
np.random.seed(training_seed)
tf.set_random_seed(training_seed)
self.rank = global_values.get_rank()
@timed
def _save_models(self):

if self.rank is not None and self.rank != 0:
if global_values.get_rank() is not None and global_values.get_rank() != 0:
return
for brain_name in self.trainers.keys():

"""
Saves models for all trainers.
"""
if self.rank is not None and self.rank != 0:
if global_values.get_rank() is not None and global_values.get_rank() != 0:
return
for brain_name in self.trainers.keys():

正在加载...
取消
保存