浏览代码

adding rank to ml-agents

/MLA-1734-demo-provider
Anupam Bhatnagar 5 年前
当前提交
dbd5dc04
共有 2 个文件被更改,包括 42 次插入14 次删除
  1. 48
      ml-agents/mlagents/trainers/stats.py
  2. 8
      ml-agents/mlagents/trainers/trainer_controller.py

48
ml-agents/mlagents/trainers/stats.py


from mlagents_envs.logging_util import get_logger
from mlagents_envs.timers import set_gauge
from mlagents.tf_utils import tf, generate_session_config
from mlagents.tf_utils.globals import get_rank
logger = get_logger(__name__)

# If self-play, we want to print ELO as well as reward
self.self_play = False
self.self_play_team = -1
self.rank = get_rank()
def write_stats(
self, category: str, values: Dict[str, StatsSummary], step: int

stats_summary = stats_summary = values["Is Training"]
stats_summary = values["Is Training"]
logger.info(
"{}: Step: {}. "
"Time Elapsed: {:0.3f} s "
"Mean "
"Reward: {:0.3f}"
". Std of Reward: {:0.3f}. {}".format(
category,
step,
time.time() - self.training_start_time,
stats_summary.mean,
stats_summary.std,
is_training,
if self.rank is not None:
logger.info(
"Rank: {}."
"{}: Step: {}. "
"Time Elapsed: {:0.3f} s "
"Mean "
"Reward: {:0.3f}"
". Std of Reward: {:0.3f}. {}".format(
self.rank(),
category,
step,
time.time() - self.training_start_time,
stats_summary.mean,
stats_summary.std,
is_training,
)
)
else:
logger.info(
"{}: Step: {}. "
"Time Elapsed: {:0.3f} s "
"Mean "
"Reward: {:0.3f}"
". Std of Reward: {:0.3f}. {}".format(
category,
step,
time.time() - self.training_start_time,
stats_summary.mean,
stats_summary.std,
is_training,
)
)
if self.self_play and "Self-play/ELO" in values:
elo_stats = values["Self-play/ELO"]
logger.info(f"{category} ELO: {elo_stats.mean:0.3f}. ")

8
ml-agents/mlagents/trainers/trainer_controller.py


from mlagents.trainers.trainer_util import TrainerFactory
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.agent_processor import AgentManager
from mlagents.tf_utils.globals import get_rank
class TrainerController:

self.kill_trainers = False
np.random.seed(training_seed)
tf.set_random_seed(training_seed)
self.rank = get_rank()
@timed
def _save_models(self):

if self.rank is not None and self.rank != 0:
return
for brain_name in self.trainers.keys():
self.trainers[brain_name].save_model()
self.logger.info("Saved Model")

"""
Saves models for all trainers.
"""
if self.rank is not None and self.rank != 0:
return
for brain_name in self.trainers.keys():
self.trainers[brain_name].save_model()

正在加载...
取消
保存