|
|
|
|
|
|
if self.network_settings.memory is not None: |
|
|
|
self.m_size = self.network_settings.memory.memory_size |
|
|
|
self.sequence_length = self.network_settings.memory.sequence_length |
|
|
|
if self.m_size == 0: |
|
|
|
raise UnityPolicyException( |
|
|
|
"The memory size for brain {0} is 0 even " |
|
|
|
"though the trainer uses recurrent.".format(brain.brain_name) |
|
|
|
) |
|
|
|
elif self.m_size % 2 != 0: |
|
|
|
raise UnityPolicyException( |
|
|
|
"The memory size for brain {0} is {1} " |
|
|
|
"but it must be divisible by 2.".format( |
|
|
|
brain.brain_name, self.m_size |
|
|
|
) |
|
|
|
) |
|
|
|
self._initialize_tensorflow_references() |
|
|
|
self.load = load |
|
|
|
|
|
|
|
|
|
|
def _load_graph(self, model_path: str, reset_global_steps: bool = False) -> None: |
|
|
|
with self.graph.as_default(): |
|
|
|
self.saver = tf.train.Saver(max_to_keep=self.keep_checkpoints) |
|
|
|
logger.info( |
|
|
|
"Loading model for brain {} from {}.".format( |
|
|
|
self.brain.brain_name, model_path |
|
|
|
) |
|
|
|
) |
|
|
|
logger.info(f"Loading model from {model_path}.") |
|
|
|
ckpt = tf.train.get_checkpoint_state(model_path) |
|
|
|
if ckpt is None: |
|
|
|
raise UnityPolicyException( |
|
|
|