浏览代码

fix NNCheckpointManager for Torch

/MLA-1734-demo-provider
Ruo-Ping Dong 4 年前
当前提交
09c22679
共有 2 个文件被更改,包括 16 次插入8 次删除
  1. 18
      ml-agents/mlagents/trainers/tests/test_rl_trainer.py
  2. 6
      ml-agents/mlagents/trainers/trainer/rl_trainer.py

18
ml-agents/mlagents/trainers/tests/test_rl_trainer.py


from mlagents.trainers.trainer.rl_trainer import RLTrainer
from mlagents.trainers.tests.test_buffer import construct_fake_buffer
from mlagents.trainers.agent_processor import AgentManagerQueue
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.settings import TrainerSettings, FrameworkType
# Add concrete implementations of abstract methods

super()._process_trajectory(trajectory)
def create_rl_trainer():
def create_rl_trainer(framework=FrameworkType.TENSORFLOW):
TrainerSettings(max_steps=100, checkpoint_interval=10, summary_freq=20),
TrainerSettings(
max_steps=100, checkpoint_interval=10, summary_freq=20, framework=framework
),
True,
False,
"mock_model_path",

assert mocked_save_model.call_count == 0
@pytest.mark.parametrize(
"framework", [FrameworkType.TENSORFLOW, FrameworkType.PYTORCH], ids=["tf", "torch"]
)
def test_summary_checkpoint(mock_add_checkpoint, mock_write_summary):
trainer = create_rl_trainer()
def test_summary_checkpoint(mock_add_checkpoint, mock_write_summary, framework):
trainer = create_rl_trainer(framework)
mock_policy = mock.Mock()
trainer.add_policy("TestBrain", mock_policy)
trajectory_queue = AgentManagerQueue("testbrain")

)
calls = [mock.call(trainer.brain_name, step) for step in checkpoint_range]
trainer.saver.save_checkpoint.assert_has_calls(calls, any_order=True)
export_ext = "nn" if trainer.framework == FrameworkType.TENSORFLOW else "onnx"
add_checkpoint_calls = [
mock.call(

f"{trainer.saver.model_path}/{trainer.brain_name}-{step}.nn",
f"{trainer.saver.model_path}/{trainer.brain_name}-{step}.{export_ext}",
None,
mock.ANY,
),

6
ml-agents/mlagents/trainers/trainer/rl_trainer.py


"Trainer has multiple policies, but default behavior only saves the first."
)
checkpoint_path = self.saver.save_checkpoint(self.brain_name, self.step)
export_ext = "nn" if self.framework == FrameworkType.TENSORFLOW else "onnx"
f"{checkpoint_path}.nn",
f"{checkpoint_path}.{export_ext}",
self._policy_mean_reward(),
time.time(),
)

model_checkpoint = self._checkpoint()
self.saver.copy_final_model(model_checkpoint.file_path)
export_ext = "nn" if self.framework == FrameworkType.TENSORFLOW else "onnx"
model_checkpoint, file_path=f"{self.saver.model_path}.nn"
model_checkpoint, file_path=f"{self.saver.model_path}.{export_ext}"
)
NNCheckpointManager.track_final_checkpoint(self.brain_name, final_checkpoint)

正在加载...
取消
保存