浏览代码

Move memory validation to settings

/MLA-1734-demo-provider
Ervin Teng 5 年前
当前提交
510583d2
共有 3 个文件被更改,包括 16 次插入21 次删除
  1. 18
      ml-agents/mlagents/trainers/policy/tf_policy.py
  2. 17
      ml-agents/mlagents/trainers/settings.py
  3. 2
      ml-agents/mlagents/trainers/trainer/trainer.py

18
ml-agents/mlagents/trainers/policy/tf_policy.py


if self.network_settings.memory is not None:
self.m_size = self.network_settings.memory.memory_size
self.sequence_length = self.network_settings.memory.sequence_length
if self.m_size == 0:
raise UnityPolicyException(
"The memory size for brain {0} is 0 even "
"though the trainer uses recurrent.".format(brain.brain_name)
)
elif self.m_size % 2 != 0:
raise UnityPolicyException(
"The memory size for brain {0} is {1} "
"but it must be divisible by 2.".format(
brain.brain_name, self.m_size
)
)
self._initialize_tensorflow_references()
self.load = load

def _load_graph(self, model_path: str, reset_global_steps: bool = False) -> None:
with self.graph.as_default():
self.saver = tf.train.Saver(max_to_keep=self.keep_checkpoints)
logger.info(
"Loading model for brain {} from {}.".format(
self.brain.brain_name, model_path
)
)
logger.info(f"Loading model from {model_path}.")
ckpt = tf.train.get_checkpoint_state(model_path)
if ckpt is None:
raise UnityPolicyException(

17
ml-agents/mlagents/trainers/settings.py


@attr.s(auto_attribs=True)
class NetworkSettings:
@attr.s(auto_attribs=True)
@attr.s
sequence_length: int = 64
memory_size: int = 128
sequence_length: int = attr.ib(default=64)
memory_size: int = attr.ib(default=128)
@memory_size.validator
def _check_valid_memory_size(self, attribute, value):
if value <= 0:
raise TrainerConfigError(
"When using a recurrent network, memory size must be greater than 0."
)
elif value % 2 != 0:
raise TrainerConfigError(
"When using a recurrent network, memory size must be divisible by 2."
)
normalize: bool = False
hidden_units: int = 128

2
ml-agents/mlagents/trainers/trainer/trainer.py


Exports the model
"""
policy = self.get_policy(name_behavior_id)
settings = SerializationSettings(policy.model_path, policy.brain.brain_name)
settings = SerializationSettings(policy.model_path, self.brain_name)
export_policy_model(settings, policy.graph, policy.sess)
@abc.abstractmethod

正在加载...
取消
保存