
Convert checkpoints to .nn format

Fixed style
Fixed more style
Nit changes
Fixed signature
Fixed tests, checkpoint management and style
Check checkpoint management
Modify statement on artifacts
Branch: develop/checkout-conversion-rebase
Jonathan Harper, 5 years ago
Current commit: 80127232
9 files changed, with 165 insertions and 13 deletions
1. docs/Training-ML-Agents.md (4 changes)
2. ml-agents/mlagents/model_serialization.py (17 changes)
3. ml-agents/mlagents/trainers/ghost/trainer.py (4 changes)
4. ml-agents/mlagents/trainers/tests/test_rl_trainer.py (9 changes)
5. ml-agents/mlagents/trainers/tests/test_training_status.py (54 changes)
6. ml-agents/mlagents/trainers/trainer/rl_trainer.py (2 changes)
7. ml-agents/mlagents/trainers/trainer/trainer.py (33 changes)
8. ml-agents/mlagents/trainers/trainer_controller.py (1 change)
9. ml-agents/mlagents/trainers/training_status.py (54 changes)

docs/Training-ML-Agents.md (4 changes)

  blocks. See [Profiling in Python](Profiling-Python.md) for more information
  on the timers generated.
- These artifacts (except the `.nn` file) are updated throughout the training
- process and finalized when training completes or is interrupted.
+ These artifacts are updated throughout the training
+ process and finalized when training is completed or is interrupted.
  #### Stopping and Resuming Training

ml-agents/mlagents/model_serialization.py (17 changes)

  class SerializationSettings(NamedTuple):
      model_path: str
      brain_name: str
+     checkpoint_path: str = ""
      convert_to_barracuda: bool = True
      convert_to_onnx: bool = True
      onnx_opset: int = 9

-     settings: SerializationSettings, graph: tf.Graph, sess: tf.Session
+     settings: SerializationSettings,
+     graph: tf.Graph,
+     sess: tf.Session,
+     is_checkpoint: bool = False,
  ) -> None:
      """
      Exports latest saved model to .nn format for Unity embedding.

      # Convert to barracuda
      if settings.convert_to_barracuda:
-         tf2bc.convert(frozen_graph_def_path, settings.model_path + ".nn")
-         logger.info(f"Exported {settings.model_path}.nn file")
+         if is_checkpoint:
+             tf2bc.convert(
+                 frozen_graph_def_path,
+                 os.path.join(settings.model_path, f"{settings.checkpoint_path}.nn"),
+             )
+             logger.info(f"Exported {settings.checkpoint_path}.nn file")
+         else:
+             tf2bc.convert(frozen_graph_def_path, settings.model_path + ".nn")
+             logger.info(f"Exported {settings.model_path}.nn file")
      # Save to onnx too (if we were able to import it)
      if ONNX_EXPORT_ENABLED:
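The only behavioural change here is where the converted .nn file is written. As a quick illustration, the sketch below (a hypothetical helper, not part of the diff) reproduces just the path-selection logic with a minimal stand-in for SerializationSettings: checkpoint exports land inside the model directory under the checkpoint name, while the final export keeps the existing `<model_path>.nn` location.

import os
from typing import NamedTuple


class DemoSettings(NamedTuple):
    # Minimal stand-in for SerializationSettings, mirroring only the fields used here.
    model_path: str
    brain_name: str
    checkpoint_path: str = ""


def nn_output_path(settings: DemoSettings, is_checkpoint: bool) -> str:
    # Checkpoints go inside the model directory as <checkpoint_path>.nn;
    # the final export keeps the original <model_path>.nn behaviour.
    if is_checkpoint:
        return os.path.join(settings.model_path, f"{settings.checkpoint_path}.nn")
    return settings.model_path + ".nn"


print(nn_output_path(DemoSettings("results/3DBall", "3DBall", "3DBall-50000"), True))
# results/3DBall/3DBall-50000.nn (on POSIX paths)
print(nn_output_path(DemoSettings("results/3DBall", "3DBall"), False))
# results/3DBall.nn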

ml-agents/mlagents/trainers/ghost/trainer.py (4 changes)

          brain_name = parsed_behavior_id.brain_name
          self.trainer.save_model(brain_name)
-     def export_model(self, name_behavior_id: str) -> None:
+     def export_model(self, name_behavior_id: str, is_checkpoint: bool = False) -> None:
-         self.trainer.export_model(brain_name)
+         self.trainer.export_model(brain_name, is_checkpoint)
      def create_policy(
          self, parsed_behavior_id: BehaviorIdentifiers, brain_parameters: BrainParameters
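The ghost trainer only wraps another trainer, so the new keyword just needs to be forwarded unchanged to the wrapped instance. A minimal sketch of that forwarding pattern, using made-up class names rather than the real GhostTrainer:

class BaseTrainer:
    def export_model(self, name_behavior_id: str, is_checkpoint: bool = False) -> None:
        kind = "checkpoint" if is_checkpoint else "final model"
        print(f"Exporting {kind} for {name_behavior_id}")


class WrappingTrainer:
    def __init__(self, trainer: BaseTrainer) -> None:
        self.trainer = trainer

    def export_model(self, name_behavior_id: str, is_checkpoint: bool = False) -> None:
        # Keep the wrapper's signature in sync with the wrapped trainer and
        # forward the flag unchanged so checkpoint exports behave identically.
        self.trainer.export_model(name_behavior_id, is_checkpoint)


WrappingTrainer(BaseTrainer()).export_model("3DBall?team=0", is_checkpoint=True)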

ml-agents/mlagents/trainers/tests/test_rl_trainer.py (9 changes)

      assert len(arr) == 0

+ @mock.patch("mlagents.trainers.trainer.trainer.Trainer.export_model")
+ @mock.patch("mlagents.trainers.trainer.trainer.Trainer.save_model")
- def test_advance(mocked_clear_update_buffer):
+ def test_advance(mocked_clear_update_buffer, mocked_save_model, mocked_export_model):
      trainer = create_rl_trainer()
      trajectory_queue = AgentManagerQueue("testbrain")
      policy_queue = AgentManagerQueue("testbrain")

      # Check that the buffer has been cleared
      assert not trainer.should_still_train
      assert mocked_clear_update_buffer.call_count > 0
+     assert mocked_save_model.call_count == mocked_export_model.call_count

+ @mock.patch("mlagents.trainers.trainer.trainer.Trainer.export_model")
- def test_summary_checkpoint(mock_write_summary, mock_save_model):
+ def test_summary_checkpoint(mock_write_summary, mock_save_model, mock_export_model):
      trainer = create_rl_trainer()
      trajectory_queue = AgentManagerQueue("testbrain")
      policy_queue = AgentManagerQueue("testbrain")

          )
      ]
      mock_save_model.assert_has_calls(calls, any_order=True)
+     assert mock_save_model.call_count == mock_export_model.call_count
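The extra mock.patch decorators inject additional mock arguments in bottom-up decorator order, which is why mocked_save_model comes before mocked_export_model in the new signatures. A self-contained sketch of that mechanism, with a throwaway Trainer class rather than the real one and only the two new patches modelled:

from unittest import mock


class Trainer:
    def save_model(self) -> None:
        raise RuntimeError("patched out in tests")

    def export_model(self) -> None:
        raise RuntimeError("patched out in tests")


@mock.patch(f"{__name__}.Trainer.export_model")
@mock.patch(f"{__name__}.Trainer.save_model")
def check_export_per_save(mocked_save_model, mocked_export_model):
    # The bottom-most decorator is applied first, so its mock is the first argument.
    trainer = Trainer()
    trainer.save_model()
    trainer.export_model()
    assert mocked_save_model.call_count == mocked_export_model.call_count


check_export_per_save()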

ml-agents/mlagents/trainers/tests/test_training_status.py (54 changes)

      assert unknown_category is None
      assert unknown_key is None

+     # Test checkpoint info
+     class CheckpointTypes(Enum):
+         CHECKPOINTS = "checkpoints"
+         FINAL_MODEL = "final_model_path"
+
+     check_checkpoints = GlobalTrainingStatus.saved_state[
+         CheckpointTypes.CHECKPOINTS.value
+     ]
+     assert check_checkpoints is not None
+
+     final_model = GlobalTrainingStatus.saved_state[CheckpointTypes.FINAL_MODEL.value]
+     assert final_model is not None
+
+
+ def test_model_management(tmpdir):
+     results_path = os.path.join(tmpdir, "results")
+     brain_name = "Mock_brain"
+     final_model_path = os.path.join(results_path, brain_name)
+     test_checkpoint_list = [
+         {"steps": 1, "file_path": os.path.join(final_model_path, f"{brain_name}-1.nn")},
+         {"steps": 2, "file_path": os.path.join(final_model_path, f"{brain_name}-2.nn")},
+         {"steps": 3, "file_path": os.path.join(final_model_path, f"{brain_name}-3.nn")},
+     ]
+     GlobalTrainingStatus.set_parameter_state(
+         brain_name, StatusType.CHECKPOINT, test_checkpoint_list
+     )
+
+     new_checkpoint_4 = {
+         "steps": 4,
+         "file_path": os.path.join(final_model_path, f"{brain_name}-4.nn"),
+     }
+     GlobalTrainingStatus.track_checkpoint_info(brain_name, new_checkpoint_4, 4)
+     assert (
+         len(GlobalTrainingStatus.saved_state[brain_name][StatusType.CHECKPOINT.value])
+         == 4
+     )
+
+     new_checkpoint_5 = {
+         "steps": 5,
+         "file_path": os.path.join(final_model_path, f"{brain_name}-5.nn"),
+     }
+     GlobalTrainingStatus.track_checkpoint_info(brain_name, new_checkpoint_5, 4)
+     assert (
+         len(GlobalTrainingStatus.saved_state[brain_name][StatusType.CHECKPOINT.value])
+         == 4
+     )
+
+     final_model_path = f"{final_model_path}.nn"
+     GlobalTrainingStatus.track_final_model_info(brain_name, final_model_path, 4)
+     assert (
+         len(GlobalTrainingStatus.saved_state[brain_name][StatusType.CHECKPOINT.value])
+         == 3
+     )

  class StatsMetaDataTest(unittest.TestCase):
      def test_metadata_compare(self):

ml-agents/mlagents/trainers/trainer/rl_trainer.py (2 changes)

          if step_after_process >= self._next_save_step and self.get_step != 0:
              logger.info(f"Checkpointing model for {self.brain_name}.")
              self.save_model(self.brain_name)
+             logger.info(f"Exporting a checkpoint model for {self.brain_name}.")
+             self.export_model(self.brain_name, is_checkpoint=True)

      def advance(self) -> None:
          """

ml-agents/mlagents/trainers/trainer/trainer.py (33 changes)

  # # Unity ML-Agents Toolkit
- from typing import List, Deque
+ from typing import List, Deque, Union, Dict
+ import os
  import abc
  from collections import deque

  from mlagents.trainers.policy import Policy
  from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
  from mlagents.trainers.settings import TrainerSettings
+ from mlagents.trainers.training_status import GlobalTrainingStatus

  logger = get_logger(__name__)

          """
          self.get_policy(name_behavior_id).save_model(self.get_step)
-     def export_model(self, name_behavior_id: str) -> None:
+     def export_model(self, name_behavior_id: str, is_checkpoint: bool = False) -> None:
-         settings = SerializationSettings(policy.model_path, policy.brain.brain_name)
-         export_policy_model(settings, policy.graph, policy.sess)
+         if is_checkpoint:
+             checkpoint_path = f"{name_behavior_id}-{self.get_step}"
+             settings = SerializationSettings(
+                 policy.model_path, policy.brain.brain_name, checkpoint_path
+             )
+             new_checkpoint: Dict[str, Union[int, str]] = {}
+             # Store steps and file_path
+             new_checkpoint["steps"] = int(self.get_step)
+             new_checkpoint["file_path"] = os.path.join(
+                 settings.model_path, f"{settings.checkpoint_path}.nn"
+             )
+             # Record checkpoint information
+             GlobalTrainingStatus.track_checkpoint_info(
+                 name_behavior_id, new_checkpoint, policy.keep_checkpoints
+             )
+         else:
+             # Extracting brain name for consistent name_behavior_id
+             brain_name = policy.model_path.split("/")[-1]
+             settings = SerializationSettings(policy.model_path, policy.brain.brain_name)
+             # Record final model information
+             GlobalTrainingStatus.track_final_model_info(
+                 brain_name, f"{settings.model_path}.nn", policy.keep_checkpoints
+             )
+         export_policy_model(settings, policy.graph, policy.sess, is_checkpoint)

      @abc.abstractmethod
      def end_episode(self):
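Each checkpoint is tracked as a small dict holding the step count and the path of the exported .nn file. The helper below is a hypothetical illustration of how such a record is assembled from the naming convention used above (`<model_path>/<name_behavior_id>-<step>.nn`); the real code builds the dict inline as shown in the diff.

import os
from typing import Dict, Union


def build_checkpoint_record(
    model_path: str, name_behavior_id: str, step: int
) -> Dict[str, Union[int, str]]:
    # <model_path>/<name_behavior_id>-<step>.nn, matching the checkpoint_path
    # naming used in export_model above.
    checkpoint_name = f"{name_behavior_id}-{step}"
    return {
        "steps": int(step),
        "file_path": os.path.join(model_path, f"{checkpoint_name}.nn"),
    }


print(build_checkpoint_record("results/3DBall", "3DBall?team=0", 50000))
# {'steps': 50000, 'file_path': 'results/3DBall/3DBall?team=0-50000.nn'} (on POSIX paths)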

ml-agents/mlagents/trainers/trainer_controller.py (1 change)

              "Learning was interrupted. Please wait while the graph is generated."
          )
          self._save_model()
+         self._export_graph()

      def _export_graph(self):
          """

ml-agents/mlagents/trainers/training_status.py (54 changes)

  from typing import Dict, Any
+ import os
  from enum import Enum
  from collections import defaultdict
  import json

  class StatusType(Enum):
      LESSON_NUM = "lesson_num"
      STATS_METADATA = "metadata"
+     CHECKPOINT = "checkpoint"
+     STEPS = "steps"
+     FINAL_PATH = "final_model_path"

  @attr.s(auto_attribs=True)

          :param value: The value.
          """
          GlobalTrainingStatus.saved_state[category][key.value] = value

+     @staticmethod
+     def append_to_parameter_state(category: str, key: StatusType, value: Any) -> None:
+         """
+         Appends an arbitrary-named parameter in the global saved state.
+         :param category: The category (usually behavior name) of the parameter.
+         :param key: The parameter, e.g. lesson number.
+         :param value: The value.
+         """
+         GlobalTrainingStatus.saved_state[category][key.value].append(value)
+
+     @staticmethod
+     def remove_checkpoint(checkpoint: Dict[str, Any]) -> None:
+         file_path: str = checkpoint["file_path"]
+         if os.path.exists(file_path):
+             os.remove(file_path)
+             logger.info(f"Removed checkpoint model {file_path}.")
+         else:
+             logger.info(f"Checkpoint at {file_path} could not be found.")
+         return
+
+     @staticmethod
+     def manage_checkpoint_list(category: str, keep_checkpoints: int) -> None:
+         key = StatusType.CHECKPOINT.value
+         if key not in GlobalTrainingStatus.saved_state[category]:
+             GlobalTrainingStatus.saved_state[category][key] = []
+         checkpoint_list = GlobalTrainingStatus.saved_state[category][key]
+         num_checkpoints = len(checkpoint_list)
+         while num_checkpoints >= keep_checkpoints:
+             if keep_checkpoints <= 0:
+                 break
+             GlobalTrainingStatus.remove_checkpoint(checkpoint_list.pop(0))
+             num_checkpoints = len(checkpoint_list)
+         return
+
+     @staticmethod
+     def track_checkpoint_info(category: str, value: Any, keep_checkpoints: int) -> None:
+         GlobalTrainingStatus.manage_checkpoint_list(category, keep_checkpoints)
+         GlobalTrainingStatus.append_to_parameter_state(
+             category, StatusType.CHECKPOINT, value
+         )
+         return
+
+     @staticmethod
+     def track_final_model_info(
+         category: str, value: str, keep_checkpoints: int
+     ) -> None:
+         GlobalTrainingStatus.manage_checkpoint_list(category, keep_checkpoints)
+         GlobalTrainingStatus.set_parameter_state(category, StatusType.FINAL_PATH, value)
+         return
+
      @staticmethod
      def get_parameter_state(category: str, key: StatusType) -> Any:
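The retention policy in manage_checkpoint_list is FIFO: before a new entry is recorded, the oldest entries are popped (and their files deleted) until the list is below keep_checkpoints. The sketch below reproduces just that pruning arithmetic on an in-memory list, with no filesystem access and no GlobalTrainingStatus, so the numbers line up with the assertions in test_model_management.

from typing import Any, Dict, List


def prune_then_append(
    checkpoints: List[Dict[str, Any]],
    new_checkpoint: Dict[str, Any],
    keep_checkpoints: int,
) -> List[Dict[str, Any]]:
    # Drop the oldest entries while the list is at or over the limit, then append,
    # so the list never grows beyond keep_checkpoints entries.
    while keep_checkpoints > 0 and len(checkpoints) >= keep_checkpoints:
        checkpoints.pop(0)
    checkpoints.append(new_checkpoint)
    return checkpoints


history: List[Dict[str, Any]] = []
for step in (1, 2, 3, 4, 5):
    prune_then_append(history, {"steps": step, "file_path": f"Mock_brain-{step}.nn"}, 4)
print([c["steps"] for c in history])  # [2, 3, 4, 5]

track_final_model_info applies the same pruning without appending a new entry, which is why test_model_management ends with three checkpoint entries after the final model is recorded.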
