|
|
|
|
|
|
from typing import Any, Dict, List, Optional, Tuple |
|
|
|
from typing import Any, Dict, List, Optional, Tuple, Callable |
|
|
|
import numpy as np |
|
|
|
from distutils.version import LooseVersion |
|
|
|
|
|
|
|
|
|
|
GaussianDistribution, |
|
|
|
MultiCategoricalDistribution, |
|
|
|
) |
|
|
|
from mlagents.tf_utils.globals import get_rank, broadcast_variables |
|
|
|
from mlagents.tf_utils.globals import get_rank |
|
|
|
|
|
|
|
|
|
|
|
logger = get_logger(__name__) |
|
|
|
|
|
|
Contains a learning model, and the necessary |
|
|
|
functions to save/load models and create the input placeholders. |
|
|
|
""" |
|
|
|
|
|
|
|
broadcast_global_variables: Callable[[int], None] = lambda x: None |
|
|
|
|
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
|
|
|
else: |
|
|
|
self._initialize_graph() |
|
|
|
# broadcast initial weights from worker-0 |
|
|
|
if broadcast_variables(): |
|
|
|
self.sess.run(hvd.broadcast_global_variables(0)) |
|
|
|
TFPolicy.broadcast_global_variables(0) |
|
|
|
|
|
|
|
def get_weights(self): |
|
|
|
with self.graph.as_default(): |
|
|
|