|
|
|
|
|
|
from typing import Any, Dict, List, Optional |
|
|
|
from typing import Any, Dict, List, Optional, Tuple |
|
|
|
from distutils.version import LooseVersion |
|
|
|
|
|
|
|
from mlagents.tf_utils import tf |
|
|
|
from mlagents import tf_utils |
|
|
|
from mlagents_envs.exception import UnityException |
|
|
|
|
|
|
from mlagents.trainers.models import ModelUtils |
|
|
|
from mlagents.trainers.settings import TrainerSettings, NetworkSettings |
|
|
|
from mlagents.trainers.brain import BrainParameters |
|
|
|
from mlagents.trainers import __version__ |
|
|
|
|
|
|
|
|
|
|
|
# This is the version number of the inputs and outputs of the model, and |
|
|
|
# determines compatibility with inference in Barracuda. |
|
|
|
MODEL_FORMAT_VERSION = 2 |
|
|
|
|
|
|
|
|
|
|
|
class UnityPolicyException(UnityException): |
|
|
|
|
|
|
:param brain: The corresponding Brain for this policy. |
|
|
|
:param trainer_settings: The trainer parameters. |
|
|
|
""" |
|
|
|
self._version_number_ = 2 |
|
|
|
|
|
|
|
self.m_size = 0 |
|
|
|
self.trainer_settings = trainer_settings |
|
|
|
self.network_settings: NetworkSettings = trainer_settings.network_settings |
|
|
|
|
|
|
""" |
|
|
|
pass |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def _convert_version_string(version_string: str) -> Tuple[int, ...]: |
|
|
|
""" |
|
|
|
Converts the version string into a Tuple of ints (major_ver, minor_ver, patch_ver). |
|
|
|
:param version_string: The semantic-versioned version string (X.Y.Z). |
|
|
|
:return: A Tuple containing (major_ver, minor_ver, patch_ver). |
|
|
|
""" |
|
|
|
ver = LooseVersion(version_string) |
|
|
|
return tuple(map(int, ver.version[0:3])) |
|
|
|
|
|
|
|
def _check_model_version(self, version: str) -> None:
    """
    Checks whether the model being loaded was created with the same version of
    ML-Agents, and logs a warning if not.

    Does nothing when self.version_tensors is None.

    :param version: The current ML-Agents version string; it is converted to a
        tuple of ints and compared against the values of self.version_tensors
        evaluated in self.sess.
    """
    if self.version_tensors is None:
        return

    # Evaluate each version tensor in the policy's session to get the stored numbers.
    loaded_ver = tuple(num.eval(session=self.sess) for num in self.version_tensors)
    if loaded_ver != TFPolicy._convert_version_string(version):
        # Adjacent string literals are concatenated verbatim, so each piece
        # must end with its own trailing space.
        logger.warning(
            f"The model checkpoint you are loading from was saved with ML-Agents version "
            f"{loaded_ver[0]}.{loaded_ver[1]}.{loaded_ver[2]} but your current ML-Agents "
            f"version is {version}. Model may not behave properly."
        )
|
|
|
|
|
|
|
def _initialize_graph(self): |
|
|
|
with self.graph.as_default(): |
|
|
|
self.saver = tf.train.Saver(max_to_keep=self.keep_checkpoints) |
|
|
|
|
|
|
model_path |
|
|
|
) |
|
|
|
) |
|
|
|
self._check_model_version(__version__) |
|
|
|
if reset_global_steps: |
|
|
|
self._set_step(0) |
|
|
|
logger.info( |
|
|
|
|
|
|
self.prev_action: Optional[tf.Tensor] = None |
|
|
|
self.memory_in: Optional[tf.Tensor] = None |
|
|
|
self.memory_out: Optional[tf.Tensor] = None |
|
|
|
self.version_tensors: Optional[Tuple[tf.Tensor, tf.Tensor, tf.Tensor]] = None |
|
|
|
|
|
|
|
def create_input_placeholders(self): |
|
|
|
with self.graph.as_default(): |
|
|
|
|
|
|
trainable=False, |
|
|
|
dtype=tf.int32, |
|
|
|
) |
|
|
|
int_version = TFPolicy._convert_version_string(__version__) |
|
|
|
major_ver_t = tf.Variable( |
|
|
|
int_version[0], |
|
|
|
name="trainer_major_version", |
|
|
|
trainable=False, |
|
|
|
dtype=tf.int32, |
|
|
|
) |
|
|
|
minor_ver_t = tf.Variable( |
|
|
|
int_version[1], |
|
|
|
name="trainer_minor_version", |
|
|
|
trainable=False, |
|
|
|
dtype=tf.int32, |
|
|
|
) |
|
|
|
patch_ver_t = tf.Variable( |
|
|
|
int_version[2], |
|
|
|
name="trainer_patch_version", |
|
|
|
trainable=False, |
|
|
|
dtype=tf.int32, |
|
|
|
) |
|
|
|
self.version_tensors = (major_ver_t, minor_ver_t, patch_ver_t) |
|
|
|
self._version_number_, |
|
|
|
MODEL_FORMAT_VERSION, |
|
|
|
name="version_number", |
|
|
|
trainable=False, |
|
|
|
dtype=tf.int32, |
|
|
|