浏览代码

[bug-fix] Use correct memories for LSTM SAC (#5228)

* Use correct memories for LSTM SAC

* Add some comments

(cherry picked from commit 707730256a6797336ba749f05f7dbf10dadd8126)
/release_16_branch
Ervin Teng 3 年前
当前提交
9e2e2626
共有 1 个文件被更改,包括 13 次插入和 14 次删除
  1. 27
      ml-agents/mlagents/trainers/sac/optimizer_torch.py

27
ml-agents/mlagents/trainers/sac/optimizer_torch.py


0, len(batch[BufferKey.CRITIC_MEMORY]), self.policy.sequence_length
)
]
offset = 1 if self.policy.sequence_length > 1 else 0
next_value_memories_list = [
ModelUtils.list_to_tensor(
batch[BufferKey.CRITIC_MEMORY][i]
) # only pass value part of memory to target network
for i in range(
offset, len(batch[BufferKey.CRITIC_MEMORY]), self.policy.sequence_length
)
]
next_value_memories = torch.stack(next_value_memories_list).unsqueeze(0)
next_value_memories = None
torch.zeros_like(next_value_memories)
if next_value_memories is not None
else None
torch.zeros_like(value_memories) if value_memories is not None else None
)
# Copy normalizers from policy

q1_stream, q2_stream = q1_out, q2_out
with torch.no_grad():
# Since we didn't record the next value memories, evaluate one step in the critic to
# get them.
if value_memories is not None:
# Get the first observation in each sequence
just_first_obs = [
_obs[:: self.policy.sequence_length] for _obs in current_obs
]
_, next_value_memories = self._critic.critic_pass(
just_first_obs, value_memories, sequence_length=1
)
else:
next_value_memories = None
target_values, _ = self.target_network(
next_obs,
memories=next_value_memories,

正在加载...
取消
保存