浏览代码

[bug-fix] Fix memory leak when using LSTMs (#5048)

* Detach memory before storing

* Add test

* Evaluate with no_grad
/develop/gail-srl-hack
GitHub 4 年前
当前提交
d24b0966
共有 2 个文件被更改,包括 19 次插入和 11 次删除
  1. 25
      ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
  2. 5
      ml-agents/mlagents/trainers/tests/torch/test_ppo.py

25
ml-agents/mlagents/trainers/optimizer/torch_optimizer.py


# If we're using LSTM, we want to get all the intermediate memories.
all_next_memories: Optional[AgentBufferField] = None
if self.policy.use_recurrent:
(
value_estimates,
all_next_memories,
next_memory,
) = self._evaluate_by_sequence(current_obs, memory)
else:
value_estimates, next_memory = self.critic.critic_pass(
current_obs, memory, sequence_length=batch.num_experiences
)
# To prevent memory leak and improve performance, evaluate with no_grad.
with torch.no_grad():
if self.policy.use_recurrent:
(
value_estimates,
all_next_memories,
next_memory,
) = self._evaluate_by_sequence(current_obs, memory)
else:
value_estimates, next_memory = self.critic.critic_pass(
current_obs, memory, sequence_length=batch.num_experiences
)
# Store the memory for the next trajectory
# Store the memory for the next trajectory. This should NOT have a gradient.
self.critic_memory_dict[agent_id] = next_memory
next_value_estimate, _ = self.critic.critic_pass(

5
ml-agents/mlagents/trainers/tests/torch/test_ppo.py


run_out, final_value_out, all_memories = optimizer.get_trajectory_value_estimates(
trajectory.to_agentbuffer(), trajectory.next_obs, done=False
)
if rnn:
# Check that memories don't have a Torch gradient
for mem in optimizer.critic_memory_dict.values():
assert not mem.requires_grad
for key, val in run_out.items():
assert type(key) is str
assert len(val) == 15

正在加载...
取消
保存