
[refactor] Don't compute grad for q2_p in SAC Optimizer (#4509)

/MLA-1734-demo-provider
GitHub committed 4 years ago
Current commit
05fc088d
1 file changed, 49 insertions and 21 deletions
ml-agents/mlagents/trainers/sac/optimizer_torch.py (70 lines changed)


@@ ... @@
 from mlagents_envs.timers import timed
 from mlagents.trainers.exception import UnityTrainerException
 from mlagents.trainers.settings import TrainerSettings, SACSettings
+from contextlib import ExitStack

 EPSILON = 1e-6  # Small value to avoid divide by zero
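An aside on the context line above: EPSILON is the usual numerical-stability guard against log(0) or division by zero. A tiny illustrative example (not part of this commit):

    import torch

    EPSILON = 1e-6
    probs = torch.tensor([0.0, 0.5])
    print(torch.log(probs + EPSILON))  # finite everywhere, even where probs == 0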

@@ ... @@
             actions: Optional[torch.Tensor] = None,
             memories: Optional[torch.Tensor] = None,
             sequence_length: int = 1,
+            q1_grad: bool = True,
+            q2_grad: bool = True,
         ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
-            q1_out, _ = self.q1_network(
-                vec_inputs,
-                vis_inputs,
-                actions=actions,
-                memories=memories,
-                sequence_length=sequence_length,
-            )
-            q2_out, _ = self.q2_network(
-                vec_inputs,
-                vis_inputs,
-                actions=actions,
-                memories=memories,
-                sequence_length=sequence_length,
-            )
+            """
+            Performs a forward pass on the value network, which consists of a Q1 and Q2
+            network. Optionally skips evaluating gradients for Q1, Q2, or both.
+            :param vec_inputs: List of vector observation tensors.
+            :param vis_inputs: List of visual observation tensors.
+            :param actions: For a continuous Q function (has actions), tensor of actions.
+                Otherwise, None.
+            :param memories: Initial memories if using memory. Otherwise, None.
+            :param sequence_length: Sequence length if using memory.
+            :param q1_grad: Whether or not to compute gradients for the Q1 network.
+            :param q2_grad: Whether or not to compute gradients for the Q2 network.
+            :return: Tuple of two dictionaries, which both map {reward_signal: Q} for
+                Q1 and Q2, respectively.
+            """
+            # ExitStack allows us to enter the torch.no_grad() context conditionally
+            with ExitStack() as stack:
+                if not q1_grad:
+                    stack.enter_context(torch.no_grad())
+                q1_out, _ = self.q1_network(
+                    vec_inputs,
+                    vis_inputs,
+                    actions=actions,
+                    memories=memories,
+                    sequence_length=sequence_length,
+                )
+            with ExitStack() as stack:
+                if not q2_grad:
+                    stack.enter_context(torch.no_grad())
+                q2_out, _ = self.q2_network(
+                    vec_inputs,
+                    vis_inputs,
+                    actions=actions,
+                    memories=memories,
+                    sequence_length=sequence_length,
+                )
             return q1_out, q2_out

     def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings):
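Since the conditional no_grad trick above is easy to miss, here is a minimal, self-contained sketch of the pattern (the module and names are illustrative stand-ins, not ml-agents code):

    import torch
    from contextlib import ExitStack

    lin = torch.nn.Linear(4, 2)

    def run(x: torch.Tensor, compute_grad: bool = True) -> torch.Tensor:
        # An empty ExitStack is a no-op, so one code path serves both the
        # grad and no-grad cases without duplicating the forward call.
        with ExitStack() as stack:
            if not compute_grad:
                stack.enter_context(torch.no_grad())
            return lin(x)

    print(run(torch.ones(1, 4)).requires_grad)                      # True
    print(run(torch.ones(1, 4), compute_grad=False).requires_grad)  # False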

@@ ... @@
         )
         if self.policy.use_continuous_act:
             squeezed_actions = actions.squeeze(-1)
+            # Only need grad for q1, as that is used for policy.
             q1p_out, q2p_out = self.value_network(
                 vec_obs,
                 vis_obs,
                 ...
+                q2_grad=False,
             )
             q1_out, q2_out = self.value_network(
                 vec_obs,
                 ...
             )
             q1_stream, q2_stream = q1_out, q2_out
         else:
-            with torch.no_grad():
-                q1p_out, q2p_out = self.value_network(
-                    vec_obs,
-                    vis_obs,
-                    memories=q_memories,
-                    sequence_length=self.policy.sequence_length,
-                )
+            # For discrete, you don't need to backprop through the Q for the policy
+            q1p_out, q2p_out = self.value_network(
+                vec_obs,
+                vis_obs,
+                memories=q_memories,
+                sequence_length=self.policy.sequence_length,
+                q1_grad=False,
+                q2_grad=False,
+            )
             q1_out, q2_out = self.value_network(
                 vec_obs,
                 vis_obs,
                 ...