
change brain_name to behavior_name

/develop/add-fire/ckpt-2
Ruo-Ping Dong, 4 years ago
Current commit: 1f7b57e0
4 files changed, 24 insertions, 24 deletions
  1. ml-agents/mlagents/trainers/saver/saver.py (10 changes)
  2. ml-agents/mlagents/trainers/saver/tf_saver.py (10 changes)
  3. ml-agents/mlagents/trainers/saver/torch_saver.py (8 changes)
  4. ml-agents/mlagents/trainers/tf/model_serialization.py (20 changes)

ml-agents/mlagents/trainers/saver/saver.py (10 changes)

 pass
 @abc.abstractmethod
-def save_checkpoint(self, brain_name: str, step: int) -> str:
+def save_checkpoint(self, behavior_name: str, step: int) -> str:
-:param brain_name: Brain name of brain to be trained
+:param behavior_name: Behavior name of behavior to be trained
-def export(self, output_filepath: str, brain_name: str) -> None:
+def export(self, output_filepath: str, behavior_name: str) -> None:
-Saves the serialized model, given a path and brain name.
+Saves the serialized model, given a path and behavior name.
-:param brain_name: Brain name of brain to be trained.
+:param behavior_name: Behavior name of behavior to be trained.
 """
 pass
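Read together, the rename leaves the abstract saver interface keyed entirely by behavior name. A minimal sketch of the resulting base class, assuming the class is named BaseSaver and reconstructing the surrounding structure from the fragments above:

import abc

class BaseSaver(abc.ABC):  # class name assumed; only the method fragments appear in the diff
    """Abstract class for checkpointing and exporting a trained policy."""

    @abc.abstractmethod
    def save_checkpoint(self, behavior_name: str, step: int) -> str:
        """
        Checkpoints the policy on disk and returns the checkpoint path.
        :param behavior_name: Behavior name of behavior to be trained
        :param step: current training step
        """
        pass

    @abc.abstractmethod
    def export(self, output_filepath: str, behavior_name: str) -> None:
        """
        Saves the serialized model, given a path and behavior name.
        :param output_filepath: path to write the exported model (without file suffix)
        :param behavior_name: Behavior name of behavior to be trained.
        """
        pass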

ml-agents/mlagents/trainers/saver/tf_saver.py (10 changes)

 with self.policy.graph.as_default():
 self.tf_saver = tf.train.Saver(max_to_keep=self._keep_checkpoints)
-def save_checkpoint(self, brain_name: str, step: int) -> str:
-checkpoint_path = os.path.join(self.model_path, f"{brain_name}-{step}")
+def save_checkpoint(self, behavior_name: str, step: int) -> str:
+checkpoint_path = os.path.join(self.model_path, f"{behavior_name}-{step}")
 # Save the TF checkpoint and graph definition
 if self.graph:
 with self.graph.as_default():

 self.graph, self.model_path, "raw_graph_def.pb", as_text=False
 )
 # also save the policy so we have optimized model files for each checkpoint
-self.export(checkpoint_path, brain_name)
+self.export(checkpoint_path, behavior_name)
-def export(self, output_filepath: str, brain_name: str) -> None:
+def export(self, output_filepath: str, behavior_name: str) -> None:
-self.model_path, output_filepath, brain_name, self.graph, self.sess
+self.model_path, output_filepath, behavior_name, self.graph, self.sess
 )
 def initialize_or_load(self, policy: Optional[TFPolicy] = None) -> None:
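Putting the TF-side hunks back together, both the checkpoint file and the exported model are now keyed by behavior_name. A hedged sketch of the renamed method, assuming the tf.train.Saver created above handles the variable checkpoint and that the "raw_graph_def.pb" arguments shown in the diff belong to a tf.train.write_graph call; the .ckpt suffix is an assumption:

import os
from mlagents.tf_utils import tf  # ml-agents' TensorFlow shim; `import tensorflow as tf` works equivalently

def save_checkpoint(self, behavior_name: str, step: int) -> str:
    # Checkpoint files are now named <behavior_name>-<step> rather than <brain_name>-<step>.
    checkpoint_path = os.path.join(self.model_path, f"{behavior_name}-{step}")
    # Save the TF checkpoint and graph definition
    if self.graph:
        with self.graph.as_default():
            self.tf_saver.save(self.sess, f"{checkpoint_path}.ckpt")  # file suffix assumed
            tf.train.write_graph(
                self.graph, self.model_path, "raw_graph_def.pb", as_text=False
            )
    # also save the policy so we have optimized model files for each checkpoint
    self.export(checkpoint_path, behavior_name)
    return checkpoint_path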

ml-agents/mlagents/trainers/saver/torch_saver.py (8 changes)

 self.policy = module
 self.exporter = ModelSerializer(self.policy)
-def save_checkpoint(self, brain_name: str, step: int) -> str:
+def save_checkpoint(self, behavior_name: str, step: int) -> str:
-checkpoint_path = os.path.join(self.model_path, f"{brain_name}-{step}")
+checkpoint_path = os.path.join(self.model_path, f"{behavior_name}-{step}")
-self.export(checkpoint_path, brain_name)
+self.export(checkpoint_path, behavior_name)
-def export(self, output_filepath: str, brain_name: str) -> None:
+def export(self, output_filepath: str, behavior_name: str) -> None:
 if self.exporter is not None:
 self.exporter.export_policy_model(output_filepath)
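On the torch side the change is symmetric. A minimal sketch under the assumption that the saver persists the wrapped module's state dict next to the exported model; only the renamed parameters and the exporter calls come from the diff:

import os
import torch

def save_checkpoint(self, behavior_name: str, step: int) -> str:
    checkpoint_path = os.path.join(self.model_path, f"{behavior_name}-{step}")
    # Hypothetical: write the wrapped module's weights for this checkpoint.
    torch.save(self.policy.state_dict(), f"{checkpoint_path}.pt")
    # Export the serialized (ONNX) model alongside the raw checkpoint.
    self.export(checkpoint_path, behavior_name)
    return checkpoint_path

def export(self, output_filepath: str, behavior_name: str) -> None:
    # behavior_name is kept for interface parity with the TF saver; the
    # exporter only needs the output path here.
    if self.exporter is not None:
        self.exporter.export_policy_model(output_filepath)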

ml-agents/mlagents/trainers/tf/model_serialization.py (20 changes)

 def export_policy_model(
 model_path: str,
 output_filepath: str,
-brain_name: str,
+behavior_name: str,
 graph: tf.Graph,
 sess: tf.Session,
 ) -> None:

 :param output_filepath: file path to output the model (without file suffix)
-:param brain_name: brain name of the trained model
+:param behavior_name: behavior name of the trained model
-frozen_graph_def = _make_frozen_graph(brain_name, graph, sess)
+frozen_graph_def = _make_frozen_graph(behavior_name, graph, sess)
 if not os.path.exists(output_filepath):
 os.makedirs(output_filepath)
 # Save frozen graph

 if ONNX_EXPORT_ENABLED:
 if SerializationSettings.convert_to_onnx:
 try:
-onnx_graph = convert_frozen_to_onnx(brain_name, frozen_graph_def)
+onnx_graph = convert_frozen_to_onnx(behavior_name, frozen_graph_def)
 onnx_output_path = f"{output_filepath}.onnx"
 with open(onnx_output_path, "wb") as f:
 f.write(onnx_graph.SerializeToString())

 def _make_frozen_graph(
-brain_name: str, graph: tf.Graph, sess: tf.Session
+behavior_name: str, graph: tf.Graph, sess: tf.Session
-target_nodes = ",".join(_process_graph(brain_name, graph))
+target_nodes = ",".join(_process_graph(behavior_name, graph))
 graph_def = graph.as_graph_def()
 output_graph_def = graph_util.convert_variables_to_constants(
 sess, graph_def, target_nodes.replace(" ", "").split(",")

-def convert_frozen_to_onnx(brain_name: str, frozen_graph_def: tf.GraphDef) -> Any:
+def convert_frozen_to_onnx(behavior_name: str, frozen_graph_def: tf.GraphDef) -> Any:
 # This is basically https://github.com/onnx/tensorflow-onnx/blob/master/tf2onnx/convert.py
 inputs = _get_input_node_names(frozen_graph_def)

 )
 onnx_graph = optimizer.optimize_graph(g)
-model_proto = onnx_graph.make_model(brain_name)
+model_proto = onnx_graph.make_model(behavior_name)
 return model_proto

 return names
-def _process_graph(brain_name: str, graph: tf.Graph) -> List[str]:
+def _process_graph(behavior_name: str, graph: tf.Graph) -> List[str]:
 """
 Gets the list of the output nodes present in the graph for inference
 :return: list of node names

-logger.info("List of nodes to export for brain :" + brain_name)
+logger.info("List of nodes to export for behavior :" + behavior_name)
 for n in nodes:
 logger.info("\t" + n)
 return nodes
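For callers, the serialization entry point now takes the behavior name, which flows into the output-node selection (_process_graph), the frozen graph, and the model name embedded in the exported ONNX proto. A hypothetical call site; all path and name values below are invented for illustration:

export_policy_model(
    model_path="./results/run-1/3DBall",                     # where raw_graph_def.pb and checkpoints live
    output_filepath="./results/run-1/3DBall/3DBall-50000",   # suffixless path for the exported model files
    behavior_name="3DBall",                                  # was brain_name before this commit
    graph=policy.graph,
    sess=policy.sess,
)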
