|
|
|
|
|
|
GaussianDistribution, |
|
|
|
MultiCategoricalDistribution, |
|
|
|
) |
|
|
|
from mlagents.tf_utils.globals import get_rank, broadcast_variables |
|
|
|
|
|
|
|
|
|
|
|
logger = get_logger(__name__) |
|
|
|
|
|
|
self.grads = None |
|
|
|
self.update_batch: Optional[tf.Operation] = None |
|
|
|
self.trainable_variables: List[tf.Variable] = [] |
|
|
|
self.rank = get_rank() |
|
|
|
if create_tf_graph: |
|
|
|
self.create_tf_graph() |
|
|
|
|
|
|
|
|
|
|
self._load_graph(self.model_path, reset_global_steps=reset_steps) |
|
|
|
else: |
|
|
|
self._initialize_graph() |
|
|
|
# broadcast initial weights from worker-0 |
|
|
|
if broadcast_variables(): |
|
|
|
self.sess.run(hvd.broadcast_global_variables(0)) |
|
|
|
|
|
|
|
def get_weights(self): |
|
|
|
with self.graph.as_default(): |
|
|
|
|
|
|
:param output_filepath: path (without suffix) for the model file(s) |
|
|
|
:param settings: SerializationSettings for how to save the model. |
|
|
|
""" |
|
|
|
# Save the model if there is only one worker, or |
|
|
|
# only on worker-0 if there are multiple workers. |
|
|
|
if self.rank is not None and self.rank != 0: |
|
|
|
return |
|
|
|
export_policy_model(output_filepath, settings, self.graph, self.sess) |
|
|
|
|
|
|
|
def update_normalization(self, vector_obs: np.ndarray) -> None: |
|
|
|