
Fix SAC

/develop/critic-op-lstm-currentmem
Ervin Teng, 4 years ago
Current commit bb452ffd
2 files changed, 26 insertions and 12 deletions
  1. ml-agents/mlagents/trainers/sac/optimizer_torch.py (29 changes)
  2. ml-agents/mlagents/trainers/sac/trainer.py (9 changes)

ml-agents/mlagents/trainers/sac/optimizer_torch.py (29 changes)


             for i in range(0, len(batch[BufferKey.MEMORY]), self.policy.sequence_length)
         ]
         # LSTM shouldn't have sequence length <1, but stop it from going out of the index if true.
+        value_memories_list = [
+            ModelUtils.list_to_tensor(batch[BufferKey.CRITIC_MEMORY][i])
+            for i in range(
+                0, len(batch[BufferKey.CRITIC_MEMORY]), self.policy.sequence_length
+            )
+        ]
-        next_memories_list = [
+        next_value_memories_list = [
-                batch[BufferKey.MEMORY][i]
+                batch[BufferKey.CRITIC_MEMORY][i]
-                offset, len(batch[BufferKey.MEMORY]), self.policy.sequence_length
+                offset, len(batch[BufferKey.CRITIC_MEMORY]), self.policy.sequence_length
-            next_memories = torch.stack(next_memories_list).unsqueeze(0)
+            value_memories = torch.stack(value_memories_list).unsqueeze(0)
+            next_value_memories = torch.stack(next_value_memories_list).unsqueeze(0)
-            next_memories = None
+            value_memories = None
+            next_value_memories = None
-            torch.zeros_like(next_memories) if next_memories is not None else None
-        )
-        v_memories = (
-            torch.zeros_like(next_memories) if next_memories is not None else None
+            torch.zeros_like(next_value_memories)
+            if next_value_memories is not None
+            else None
         )
         # Copy normalizers from policy

             sequence_length=self.policy.sequence_length,
         )
         value_estimates, _ = self.value_network.critic_pass(
-            current_obs, v_memories, sequence_length=self.policy.sequence_length
+            current_obs, value_memories, sequence_length=self.policy.sequence_length
         )
         cont_sampled_actions = sampled_actions.continuous_tensor

         with torch.no_grad():
             target_values, _ = self.target_network(
                 next_obs,
-                memories=next_memories,
+                memories=next_value_memories,
                 sequence_length=self.policy.sequence_length,
             )
         masks = ModelUtils.list_to_tensor(batch[BufferKey.MASKS], dtype=torch.bool)
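The striding in the hunks above is the core of the optimizer change: the buffer holds one critic memory per step under CRITIC_MEMORY, and the update only needs the memory at the start of each sequence, plus the memory one step later to seed the target network. A minimal sketch of that pattern follows; the sizes and variable names are assumptions for illustration, not the ml-agents implementation.

# Illustrative sketch only: names and sizes below are assumptions, not ml-agents code.
import torch

sequence_length = 4
num_steps = 12      # flat batch of per-step critic memories, a multiple of sequence_length
memory_size = 6

# One stored LSTM memory per step, as the buffer holds under CRITIC_MEMORY.
critic_memories = [torch.randn(memory_size) for _ in range(num_steps)]

# Initial critic memory for each sequence: every sequence_length-th entry.
value_memories_list = [
    critic_memories[i] for i in range(0, num_steps, sequence_length)
]

# "Next" memories start one step later, so the target network is seeded from t+1;
# the offset guard mirrors the sequence-length comment in the diff.
offset = 1 if sequence_length > 1 else 0
next_value_memories_list = [
    critic_memories[i] for i in range(offset, num_steps, sequence_length)
]

# Stack to (1, num_sequences, memory_size), matching the unsqueeze(0) above.
value_memories = torch.stack(value_memories_list).unsqueeze(0)
next_value_memories = torch.stack(next_value_memories_list).unsqueeze(0)
print(value_memories.shape, next_value_memories.shape)  # both (1, 3, 6)

The range with stride sequence_length picks the first step of each sequence, and the offset of 1 shifts those same picks to step t+1 for the target network pass.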

ml-agents/mlagents/trainers/sac/trainer.py (9 changes)


             self.collected_rewards[name][agent_id] += np.sum(evaluate_result)
         # Get all value estimates for reporting purposes
-        value_estimates, _, _ = self.optimizer.get_trajectory_value_estimates(
+        (
+            value_estimates,
+            _,
+            value_memories,
+        ) = self.optimizer.get_trajectory_value_estimates(
+        if value_memories is not None:
+            agent_buffer_trajectory[BufferKey.CRITIC_MEMORY].set(value_memories)
         for name, v in value_estimates.items():
             self._stats_reporter.add_stat(
                 f"Policy/{self.optimizer.reward_signals[name].name.capitalize()} Value",
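The trainer change supplies what the optimizer reads back: the third value returned by get_trajectory_value_estimates, previously discarded, is kept as value_memories and written into the trajectory buffer under CRITIC_MEMORY. A rough sketch of that handoff, using a plain dict in place of the real AgentBuffer; the helper name and sizes are placeholders, not the ml-agents API.

# Illustration only: a dict stands in for AgentBuffer, and store_critic_memories
# is a hypothetical helper; only the "write per-step memories under CRITIC_MEMORY"
# idea mirrors the diff above.
import numpy as np

def store_critic_memories(agent_buffer_trajectory, value_memories):
    # One critic LSTM memory row per trajectory step; skip if the critic has no LSTM.
    if value_memories is not None:
        agent_buffer_trajectory["CRITIC_MEMORY"] = [
            np.asarray(mem, dtype=np.float32) for mem in value_memories
        ]

buffer = {}
per_step_memories = [np.zeros(6, dtype=np.float32) for _ in range(8)]
store_critic_memories(buffer, per_step_memories)
print(len(buffer["CRITIC_MEMORY"]))  # 8 entries, one per step, for the optimizer to stride over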
