|
|
|
|
|
|
from mlagents.tf_utils import tf, generate_session_config |
|
|
|
from mlagents.tf_utils.globals import get_rank |
|
|
|
|
|
|
|
import wandb |
|
|
|
|
|
|
|
|
|
|
|
logger = get_logger(__name__) |
|
|
|
|
|
|
|
|
|
|
self.summary_writers: Dict[str, tf.summary.FileWriter] = {} |
|
|
|
self.base_dir: str = base_dir |
|
|
|
self._clear_past_data = clear_past_data |
|
|
|
|
|
|
|
with open('wandb_API', 'r') as f: |
|
|
|
api_key = f.readline().strip() |
|
|
|
os.environ['WANDB_API_KEY'] = api_key |
|
|
|
wandb.init(project="mlagent-cloud-profiling") |
|
|
|
wandb.tensorboard.patch(pytorch=True, tensorboardX=True) |
|
|
|
|
|
|
|
def write_stats( |
|
|
|
self, category: str, values: Dict[str, StatsSummary], step: int |
|
|
|