import os import unittest import json from enum import Enum import time from mlagents.trainers.training_status import ( StatusType, StatusMetaData, GlobalTrainingStatus, ) from mlagents.trainers.policy.checkpoint_manager import ( ModelCheckpointManager, ModelCheckpoint, ) def test_globaltrainingstatus(tmpdir): path_dir = os.path.join(tmpdir, "test.json") GlobalTrainingStatus.set_parameter_state("Category1", StatusType.LESSON_NUM, 3) GlobalTrainingStatus.save_state(path_dir) with open(path_dir) as fp: test_json = json.load(fp) assert "Category1" in test_json assert StatusType.LESSON_NUM.value in test_json["Category1"] assert test_json["Category1"][StatusType.LESSON_NUM.value] == 3 assert "metadata" in test_json GlobalTrainingStatus.load_state(path_dir) restored_val = GlobalTrainingStatus.get_parameter_state( "Category1", StatusType.LESSON_NUM ) assert restored_val == 3 # Test unknown categories and status types (keys) unknown_category = GlobalTrainingStatus.get_parameter_state( "Category3", StatusType.LESSON_NUM ) class FakeStatusType(Enum): NOTAREALKEY = "notarealkey" unknown_key = GlobalTrainingStatus.get_parameter_state( "Category1", FakeStatusType.NOTAREALKEY ) assert unknown_category is None assert unknown_key is None def test_model_management(tmpdir): results_path = os.path.join(tmpdir, "results") brain_name = "Mock_brain" final_model_path = os.path.join(results_path, brain_name) test_checkpoint_list = [ { "steps": 1, "file_path": os.path.join(final_model_path, f"{brain_name}-1.nn"), "reward": 1.312, "creation_time": time.time(), }, { "steps": 2, "file_path": os.path.join(final_model_path, f"{brain_name}-2.nn"), "reward": 1.912, "creation_time": time.time(), }, { "steps": 3, "file_path": os.path.join(final_model_path, f"{brain_name}-3.nn"), "reward": 2.312, "creation_time": time.time(), }, ] GlobalTrainingStatus.set_parameter_state( brain_name, StatusType.CHECKPOINTS, test_checkpoint_list ) new_checkpoint_4 = ModelCheckpoint( 4, os.path.join(final_model_path, f"{brain_name}-4.nn"), 2.678, time.time() ) ModelCheckpointManager.add_checkpoint(brain_name, new_checkpoint_4, 4) assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4 new_checkpoint_5 = ModelCheckpoint( 5, os.path.join(final_model_path, f"{brain_name}-5.nn"), 3.122, time.time() ) ModelCheckpointManager.add_checkpoint(brain_name, new_checkpoint_5, 4) assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4 final_model_path = f"{final_model_path}.nn" final_model_time = time.time() current_step = 6 final_model = ModelCheckpoint( current_step, final_model_path, 3.294, final_model_time ) ModelCheckpointManager.track_final_checkpoint(brain_name, final_model) assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4 check_checkpoints = GlobalTrainingStatus.saved_state[brain_name][ StatusType.CHECKPOINTS.value ] assert check_checkpoints is not None final_model = GlobalTrainingStatus.saved_state[StatusType.FINAL_CHECKPOINT.value] assert final_model is not None class StatsMetaDataTest(unittest.TestCase): def test_metadata_compare(self): # Test write_stats with self.assertLogs("mlagents.trainers", level="WARNING") as cm: default_metadata = StatusMetaData() version_statsmetadata = StatusMetaData(mlagents_version="test") default_metadata.check_compatibility(version_statsmetadata) torch_version_statsmetadata = StatusMetaData(torch_version="test") default_metadata.check_compatibility(torch_version_statsmetadata) # Assert that 2 warnings have been thrown assert len(cm.output) == 2