
add torch no_grad to coma LSTM value computation

Branch: /develop/coma2/fixgroup
Andrew Cohen, 4 years ago
Commit 21d7ab85
1 changed file with 27 additions and 22 deletions
ml-agents/mlagents/trainers/coma/optimizer_torch.py (49 lines changed: +27 −22)

         all_obs = [current_obs] + team_obs if team_obs is not None else [current_obs]
         all_next_value_mem: Optional[AgentBufferField] = None
         all_next_baseline_mem: Optional[AgentBufferField] = None
-        if self.policy.use_recurrent:
-            (
-                value_estimates,
-                baseline_estimates,
-                all_next_value_mem,
-                all_next_baseline_mem,
-                next_value_mem,
-                next_baseline_mem,
-            ) = self._evaluate_by_sequence_team(
-                current_obs, team_obs, team_actions, _init_value_mem, _init_baseline_mem
-            )
-        else:
-            value_estimates, next_value_mem = self.critic.critic_pass(
-                all_obs, _init_value_mem, sequence_length=batch.num_experiences
-            )
-            baseline_estimates, next_baseline_mem = self.critic.baseline(
-                [current_obs],
-                team_obs,
-                team_actions,
-                _init_baseline_mem,
-                sequence_length=batch.num_experiences,
-            )
+        with torch.no_grad():
+            if self.policy.use_recurrent:
+                (
+                    value_estimates,
+                    baseline_estimates,
+                    all_next_value_mem,
+                    all_next_baseline_mem,
+                    next_value_mem,
+                    next_baseline_mem,
+                ) = self._evaluate_by_sequence_team(
+                    current_obs,
+                    team_obs,
+                    team_actions,
+                    _init_value_mem,
+                    _init_baseline_mem,
+                )
+            else:
+                value_estimates, next_value_mem = self.critic.critic_pass(
+                    all_obs, _init_value_mem, sequence_length=batch.num_experiences
+                )
+                baseline_estimates, next_baseline_mem = self.critic.baseline(
+                    [current_obs],
+                    team_obs,
+                    team_actions,
+                    _init_baseline_mem,
+                    sequence_length=batch.num_experiences,
+                )
         # Store the memory for the next trajectory
         self.value_memory_dict[agent_id] = next_value_mem
         self.baseline_memory_dict[agent_id] = next_baseline_mem
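For context on why the wrap matters: the value and baseline estimates computed here are consumed as targets when a finished trajectory is processed, never backpropagated through, so tracking gradients for them only builds an unused autograd graph. A minimal sketch of the effect, using a hypothetical one-layer critic rather than the actual COMA critic:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in critic; the real code calls self.critic.critic_pass()
# and self.critic.baseline() on the trainer's network.
critic = nn.Linear(4, 1)
obs = torch.randn(8, 4)

# Outside no_grad, autograd records the forward pass: the result carries a
# grad_fn and keeps intermediate activations alive for a backward pass that
# value estimation never performs.
values = critic(obs)
assert values.requires_grad

# Inside no_grad, the same forward pass records nothing.
with torch.no_grad():
    values = critic(obs)
assert not values.requires_grad
```

The saving is largest on the recurrent path, since _evaluate_by_sequence_team runs the critic over the trajectory sequence by sequence and each LSTM step would otherwise contribute its own slice of the graph.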
