
[bug-fix] Fix issue with SAC updating too much on resume (#4038)

Branch: /MLA-1734-demo-provider
Committed via GitHub, 5 years ago
Current commit: a7323393
3 files changed: 31 insertions(+), 7 deletions(-)
  1. com.unity.ml-agents/CHANGELOG.md (2 changes)
  2. ml-agents/mlagents/trainers/sac/trainer.py (20 changes)
  3. ml-agents/mlagents/trainers/tests/test_sac.py (16 changes)

com.unity.ml-agents/CHANGELOG.md (2 changes)


 - When trying to load/resume from a checkpoint created with an earlier version of ML-Agents,
   a warning will be thrown. (#4035)
 ### Bug Fixes
+- Fixed an issue where SAC would perform too many model updates when resuming from a
+  checkpoint, and too few when using `buffer_init_steps`. (#4038)
 #### com.unity.ml-agents (C#)
 #### ml-agents / ml-agents-envs / gym-unity (Python)
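As context for the entry above, here is a minimal sketch of the ratio-based update scheduler it refers to (simplified, not the ML-Agents implementation; only the names steps_per_update and update_steps come from the config and trainer fields shown in this commit). An update is owed whenever step / update_steps exceeds steps_per_update, so resuming with update_steps reset to 1 while step is restored from the checkpoint schedules a burst of catch-up updates:

    def count_updates(start_step, new_steps, steps_per_update, update_steps):
        """Count how many updates a ratio-based scheduler would run over new_steps env steps."""
        updates = 0
        step = start_step
        for _ in range(new_steps):
            step += 1
            while step / update_steps > steps_per_update:
                updates += 1
                update_steps += 1
        return updates

    # Fresh run: 100 steps at steps_per_update=20 -> 4 updates.
    print(count_updates(start_step=0, new_steps=100, steps_per_update=20, update_steps=1))
    # Resume at step 200 with update_steps reset to 1 -> 14 updates,
    # ~10 of which fire as a catch-up burst on the very first new step.
    print(count_updates(start_step=200, new_steps=100, steps_per_update=20, update_steps=1))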

ml-agents/mlagents/trainers/sac/trainer.py (20 changes)


         )
         self.step = 0
-        # Don't count buffer_init_steps in steps_per_update ratio, but also don't divide-by-0
-        self.update_steps = max(1, self.hyperparameters.buffer_init_steps)
-        self.reward_signal_update_steps = max(1, self.hyperparameters.buffer_init_steps)
+        # Don't divide by zero
+        self.update_steps = 1
+        self.reward_signal_update_steps = 1
         self.steps_per_update = self.hyperparameters.steps_per_update
         self.reward_signal_steps_per_update = (
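The removed initialization above is what produced the "too few updates when using buffer_init_steps" half of the bug: with update_steps seeded to buffer_init_steps, the old condition step / update_steps > steps_per_update only becomes true very late. The new code starts update_steps at 1 and instead subtracts buffer_init_steps in the update loop's condition (next hunk below). A rough comparison with made-up numbers (buffer_init_steps = 1000, steps_per_update = 20; not ML-Agents code):

    buffer_init_steps = 1000      # hypothetical config values for illustration
    steps_per_update = 20

    # Old bookkeeping: update_steps started at max(1, buffer_init_steps), so the
    # first update fires only once step exceeds buffer_init_steps * steps_per_update.
    first_update_old = buffer_init_steps * steps_per_update      # step > 20000

    # New bookkeeping: update_steps starts at 1 and buffer_init_steps is subtracted
    # from step in the while condition, so updates begin soon after the buffer is seeded.
    first_update_new = buffer_init_steps + steps_per_update      # step > 1020

    print(first_update_old, first_update_new)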

         )
         batch_update_stats: Dict[str, list] = defaultdict(list)
-        while self.step / self.update_steps > self.steps_per_update:
+        while (
+            self.step - self.hyperparameters.buffer_init_steps
+        ) / self.update_steps > self.steps_per_update:
             logger.debug("Updating SAC policy at step {}".format(self.step))
             buffer = self.update_buffer
             if self.update_buffer.num_experiences >= self.hyperparameters.batch_size:

         )
         batch_update_stats: Dict[str, list] = defaultdict(list)
         while (
-            self.step / self.reward_signal_update_steps
-            > self.reward_signal_steps_per_update
-        ):
+            self.step - self.hyperparameters.buffer_init_steps
+        ) / self.reward_signal_update_steps > self.reward_signal_steps_per_update:
             # Get minibatches for reward signal update if needed
             reward_signal_minibatches = {}
             for name, signal in self.optimizer.reward_signals.items():

             self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
         # Needed to resume loads properly
         self.step = policy.get_current_step()
+        # Assume steps were updated at the correct ratio before
+        self.update_steps = int(max(1, self.step / self.steps_per_update))
+        self.reward_signal_update_steps = int(
+            max(1, self.step / self.reward_signal_steps_per_update)
+        )
         self.next_summary_step = self._get_next_summary_step()

     def get_policy(self, name_behavior_id: str) -> TFPolicy:
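The added seeding above is the resume fix itself: update_steps is set as if updates had already happened at the configured ratio before the checkpoint was saved. A small sketch of the effect (hypothetical numbers, resuming at step 200 with steps_per_update = 20 and buffer_init_steps = 0; not ML-Agents code):

    step = 200                                   # restored via policy.get_current_step()
    steps_per_update = 20

    # Without seeding, update_steps is still 1 on resume, so the trainer believes it
    # owes roughly step / steps_per_update updates right away.
    backlog_without_seeding = (step / 1) / steps_per_update      # 10.0 catch-up updates

    # With the seeding, the ratio is already at its target, so there is no backlog.
    update_steps = int(max(1, step / steps_per_update))          # 10
    owes_update = step / update_steps > steps_per_update         # False

    print(backlog_without_seeding, update_steps, owes_update)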

ml-agents/mlagents/trainers/tests/test_sac.py (16 changes)


        discrete_action=False, visual_inputs=0, vec_obs_size=6
    )
    dummy_config.hyperparameters.steps_per_update = 20
    dummy_config.hyperparameters.reward_signal_steps_per_update = 20
    dummy_config.hyperparameters.buffer_init_steps = 0
    trainer = SACTrainer(brain_params, 0, dummy_config, True, False, 0, "0")
    policy = trainer.create_policy(brain_params.brain_name, brain_params)

    trainer.advance()
    with pytest.raises(AgentManagerQueue.Empty):
        policy_queue.get_nowait()

    # Call add_policy and check that we update the correct number of times.
    # This is to emulate a load from checkpoint.
    policy = trainer.create_policy(brain_params.brain_name, brain_params)
    policy.get_current_step = lambda: 200
    trainer.add_policy(brain_params.brain_name, policy)
    trainer.optimizer.update = mock.Mock()
    trainer.optimizer.update_reward_signals = mock.Mock()
    trainer.optimizer.update_reward_signals.return_value = {}
    trainer.optimizer.update.return_value = {}
    trajectory_queue.put(trajectory)
    trainer.advance()
    # Make sure we did exactly 1 update
    assert trainer.optimizer.update.call_count == 1
    assert trainer.optimizer.update_reward_signals.call_count == 1


if __name__ == "__main__":
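A walk-through of why the test expects exactly one update (the trajectory length here is an assumption; the other values mirror the mocked test above, with buffer_init_steps = 0): get_current_step() returns 200, so add_policy seeds update_steps to int(200 / 20) = 10, and only steps beyond 200 can push the ratio over steps_per_update again.

    step = 200
    steps_per_update = 20
    update_steps = int(max(1, step / steps_per_update))   # 10, as seeded by add_policy

    updates = 0
    for step in range(201, 216):        # assume the queued trajectory adds ~15 steps
        while step / update_steps > steps_per_update:
            updates += 1
            update_steps += 1

    print(updates)   # 1 -- anything up to 20 new steps yields exactly one update call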