浏览代码

Fix extra summary being written when loading from checkpoint (#3272)

* Load next summary properly

* Add tests for add_policy and get_policy
/asymm-envs
GitHub 5 年前
当前提交
329b23e0
共有 5 个文件被更改,包括 61 次插入6 次删除
  1. 2
      ml-agents/mlagents/trainers/ppo/trainer.py
  2. 2
      ml-agents/mlagents/trainers/sac/trainer.py
  3. 23
      ml-agents/mlagents/trainers/tests/test_ppo.py
  4. 24
      ml-agents/mlagents/trainers/tests/test_sac.py
  5. 16
      ml-agents/mlagents/trainers/trainer.py

2
ml-agents/mlagents/trainers/ppo/trainer.py


if not isinstance(policy, PPOPolicy):
raise RuntimeError("Non-PPOPolicy passed to PPOTrainer.add_policy()")
self.policy = policy
# Needed to resume loads properly
self.next_summary_step = self._get_next_summary_step()
def get_policy(self, name_behavior_id: str) -> TFPolicy:
"""

2
ml-agents/mlagents/trainers/sac/trainer.py


if not isinstance(policy, SACPolicy):
raise RuntimeError("Non-SACPolicy passed to SACTrainer.add_policy()")
self.policy = policy
# Needed to resume loads properly
self.next_summary_step = self._get_next_summary_step()
def get_policy(self, name_behavior_id: str) -> TFPolicy:
"""

23
ml-agents/mlagents/trainers/tests/test_ppo.py


assert trainer.stats_reporter.get_stats_summaries("Policy/Extrinsic Reward").num > 0
def test_add_get_policy(dummy_config):
    """Check that PPOTrainer stores an added policy and picks up its step count.

    A mocked PPOPolicy reporting a current step of 2000 is registered with the
    trainer. The test then verifies that the same policy object is returned by
    ``get_policy``, that the trainer's step matches the policy's step, and that
    the next summary is scheduled strictly after that step. Finally, a mock
    that is not a PPOPolicy must be rejected with ``RuntimeError``.
    """
    brain_params = make_brain_parameters(
        discrete_action=False, visual_inputs=0, vec_obs_size=6
    )
    dummy_config["summary_path"] = "./summaries/test_trainer_summary"
    dummy_config["model_path"] = "./models/test_trainer_models/TestModel"
    trainer = PPOTrainer(brain_params, 0, dummy_config, True, False, 0, "0", False)

    # spec=PPOPolicy makes the mock pass the trainer's isinstance() check.
    policy = mock.Mock(spec=PPOPolicy)
    policy.get_current_step.return_value = 2000

    trainer.add_policy(brain_params.brain_name, policy)
    assert trainer.get_policy(brain_params.brain_name) == policy

    # Make sure the summary steps were loaded properly
    assert trainer.get_step == 2000
    assert trainer.next_summary_step > 2000

    # Test incorrect class of policy. Use the same behavior name as the valid
    # call so that the policy class is the only thing that differs.
    policy = mock.Mock()
    with pytest.raises(RuntimeError):
        trainer.add_policy(brain_params.brain_name, policy)
def test_normalization(dummy_config):
brain_params = BrainParameters(
brain_name="test_brain",

24
ml-agents/mlagents/trainers/tests/test_sac.py


import pytest
from unittest import mock
import yaml
import numpy as np

policy = trainer2.create_policy(mock_brain)
trainer2.add_policy(mock_brain.brain_name, policy)
assert trainer2.update_buffer.num_experiences == buffer_len
def test_add_get_policy(dummy_config):
    """Check that SACTrainer stores an added policy and picks up its step count.

    A mocked SACPolicy reporting a current step of 2000 is registered with the
    trainer. The test then verifies that the same policy object is returned by
    ``get_policy``, that the trainer's step matches the policy's step, and that
    the next summary is scheduled strictly after that step. Finally, a mock
    that is not a SACPolicy must be rejected with ``RuntimeError``.
    """
    brain_params = make_brain_parameters(
        discrete_action=False, visual_inputs=0, vec_obs_size=6
    )
    dummy_config["summary_path"] = "./summaries/test_trainer_summary"
    dummy_config["model_path"] = "./models/test_trainer_models/TestModel"
    trainer = SACTrainer(brain_params, 0, dummy_config, True, False, 0, "0")

    # spec=SACPolicy makes the mock pass the trainer's isinstance() check.
    policy = mock.Mock(spec=SACPolicy)
    policy.get_current_step.return_value = 2000

    trainer.add_policy(brain_params.brain_name, policy)
    assert trainer.get_policy(brain_params.brain_name) == policy

    # Make sure the summary steps were loaded properly
    assert trainer.get_step == 2000
    assert trainer.next_summary_step > 2000

    # Test incorrect class of policy. Use the same behavior name as the valid
    # call so that the policy class is the only thing that differs.
    policy = mock.Mock()
    with pytest.raises(RuntimeError):
        trainer.add_policy(brain_params.brain_name, policy)
def test_process_trajectory(dummy_config):

16
ml-agents/mlagents/trainers/trainer.py


self.step: int = 0
self.training_start_time = time.time()
self.summary_freq = self.trainer_parameters["summary_freq"]
self.next_update_step = self.summary_freq
self.next_summary_step = self.summary_freq
def _check_param_keys(self):
for k in self.param_keys:

:param n_steps: number of steps to increment the step count by
"""
self.step += n_steps
self.next_update_step = self.step + (
self.summary_freq - self.step % self.summary_freq
)
self.next_summary_step = self._get_next_summary_step()
def _get_next_summary_step(self) -> int:
"""
Get the next step count that should result in a summary write.
"""
return self.step + (self.summary_freq - self.step % self.summary_freq)
def save_model(self, name_behavior_id: str) -> None:
"""

write the summary. This logic ensures summaries are written on the update step and not in between.
:param step_after_process: the step count after processing the next trajectory.
"""
if step_after_process >= self.next_update_step and self.get_step != 0:
self._write_summary(self.next_update_step)
if step_after_process >= self.next_summary_step and self.get_step != 0:
self._write_summary(self.next_summary_step)
@abc.abstractmethod
def end_episode(self):

正在加载...
取消
保存