浏览代码

removing horovod from tf policy

/MLA-1734-demo-provider
Anupam Bhatnagar 4 年前
当前提交
d3e8f124
共有 2 个文件被更改,包括 5 次插入8 次删除
  1. 4
      ml-agents/mlagents/tf_utils/globals.py
  2. 9
      ml-agents/mlagents/trainers/policy/tf_policy.py

4
ml-agents/mlagents/tf_utils/globals.py


def get_rank() -> Optional[int]:
return _rank
def broadcast_variables() -> bool:
return True if _rank is not None else False

9
ml-agents/mlagents/trainers/policy/tf_policy.py


from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Callable
import numpy as np
from distutils.version import LooseVersion

GaussianDistribution,
MultiCategoricalDistribution,
)
from mlagents.tf_utils.globals import get_rank, broadcast_variables
from mlagents.tf_utils.globals import get_rank
logger = get_logger(__name__)

Contains a learning model, and the necessary
functions to save/load models and create the input placeholders.
"""
broadcast_global_variables: Callable[[int], None] = lambda x: None
def __init__(
self,

else:
self._initialize_graph()
# broadcast initial weights from worker-0
if broadcast_variables():
self.sess.run(hvd.broadcast_global_variables(0))
TFPolicy.broadcast_global_variables(0)
def get_weights(self):
with self.graph.as_default():

正在加载...
取消
保存