浏览代码

[skip ci] save model on worker zero only

/MLA-1734-demo-provider
Anupam Bhatnagar 5 年前
当前提交
87bdf353
共有 2 个文件被更改,包括 13 次插入0 次删除
  1. 4
      ml-agents/mlagents/tf_utils/globals.py
  2. 9
      ml-agents/mlagents/trainers/policy/tf_policy.py

4
ml-agents/mlagents/tf_utils/globals.py


def get_rank() -> Optional[int]:
return _rank
def broadcast_variables() -> bool:
return True if _rank is not None else False

9
ml-agents/mlagents/trainers/policy/tf_policy.py


GaussianDistribution,
MultiCategoricalDistribution,
)
from mlagents.tf_utils.globals import get_rank, broadcast_variables
logger = get_logger(__name__)

self.grads = None
self.update_batch: Optional[tf.Operation] = None
self.trainable_variables: List[tf.Variable] = []
self.rank = get_rank()
if create_tf_graph:
self.create_tf_graph()

self._load_graph(self.model_path, reset_global_steps=reset_steps)
else:
self._initialize_graph()
# broadcast initial weights from worker-0
if broadcast_variables():
self.sess.run(hvd.broadcast_global_variables(0))
def get_weights(self):
with self.graph.as_default():

:param output_filepath: path (without suffix) for the model file(s)
:param settings: SerializationSettings for how to save the model.
"""
# save model if there is only one worker or
# only on worker-0 if there are multiple workers
if self.rank is not None and self.rank != 0:
return
export_policy_model(output_filepath, settings, self.graph, self.sess)
def update_normalization(self, vector_obs: np.ndarray) -> None:

正在加载...
取消
保存