浏览代码

add wb

/develop/wb
Ruo-Ping Dong 4 年前
当前提交
fb50b0ec
共有 3 个文件被更改，包括 13 次插入和 0 次删除
  1. 3
      ml-agents/mlagents/trainers/learn.py
  2. 2
      ml-agents/mlagents/trainers/ppo/optimizer_torch.py
  3. 8
      ml-agents/mlagents/trainers/stats.py

3
ml-agents/mlagents/trainers/learn.py


)
from mlagents_envs import logging_util
import wandb
logger = logging_util.get_logger(__name__)
TRAINING_STATUS_FILE_NAME = "training_status.json"

:param run_seed: Random seed used for training.
:param run_options: Command line arguments for training.
"""
#wandb.init(project="mlagent-cloud-profiling", sync_tensorboard=True)
with hierarchical_timer("run_training.setup"):
checkpoint_settings = options.checkpoint_settings
env_settings = options.env_settings

2
ml-agents/mlagents/trainers/ppo/optimizer_torch.py


from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
from mlagents.trainers.settings import TrainerSettings, PPOSettings
from mlagents.trainers.torch.utils import ModelUtils
import wandb
class TorchPPOOptimizer(TorchOptimizer):

+ 0.5 * value_loss
- decay_bet * ModelUtils.masked_mean(entropy, loss_masks)
)
wandb.log({"loss": loss})
# Set optimizer learning rate
ModelUtils.update_learning_rate(self.optimizer, decay_lr)

8
ml-agents/mlagents/trainers/stats.py


from mlagents.tf_utils import tf, generate_session_config
from mlagents.tf_utils.globals import get_rank
import wandb
logger = get_logger(__name__)

self.summary_writers: Dict[str, tf.summary.FileWriter] = {}
self.base_dir: str = base_dir
self._clear_past_data = clear_past_data
with open('wandb_API', 'r') as f:
api_key = f.readline().strip()
os.environ['WANDB_API_KEY'] = api_key
wandb.init(project="mlagent-cloud-profiling")
wandb.tensorboard.patch(pytorch=True, tensorboardX=True)
def write_stats(
self, category: str, values: Dict[str, StatsSummary], step: int

正在加载...
取消
保存