from typing import Dict, List, Optional, Tuple
from contextlib import ExitStack

import torch
from torch import nn

from mlagents_envs.timers import timed
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.settings import TrainerSettings, SACSettings

EPSILON = 1e-6  # Small value to avoid divide by zero
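# Illustrative note (not taken from this file): EPSILON is typically added inside
# logarithms or denominators, e.g. torch.log(action_probs + EPSILON), so that zero
# probabilities do not produce -inf values or NaN gradients.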


class TorchSACOptimizer(TorchOptimizer):
    class PolicyValueNetwork(nn.Module):
        """
        Wrapper around the Q1 and Q2 networks (self.q1_network and self.q2_network)
        that together make up the SAC value network.
        """

        def forward(
            self,
            vec_inputs: List[torch.Tensor],
            vis_inputs: List[torch.Tensor],
            actions: Optional[torch.Tensor] = None,
            memories: Optional[torch.Tensor] = None,
            sequence_length: int = 1,
            q1_grad: bool = True,
            q2_grad: bool = True,
        ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
""" |
|
|
|
Performs a forward pass on the value network, which consists of a Q1 and Q2 |
|
|
|
network. Optionally does not evaluate gradients for either the Q1, Q2, or both. |
|
|
|
:param vec_inputs: List of vector observation tensors. |
|
|
|
:param vis_input: List of visual observation tensors. |
|
|
|
:param actions: For a continuous Q function (has actions), tensor of actions. |
|
|
|
Otherwise, None. |
|
|
|
:param memories: Initial memories if using memory. Otherwise, None. |
|
|
|
:param sequence_length: Sequence length if using memory. |
|
|
|
:param q1_grad: Whether or not to compute gradients for the Q1 network. |
|
|
|
:param q2_grad: Whether or not to compute gradients for the Q2 network. |
|
|
|
:return: Tuple of two dictionaries, which both map {reward_signal: Q} for Q1 and Q2, |
|
|
|
respectively. |
|
|
|
""" |
|
|
|
            # ExitStack allows us to enter the torch.no_grad() context conditionally
            with ExitStack() as stack:
                if not q1_grad:
                    stack.enter_context(torch.no_grad())
                q1_out, _ = self.q1_network(
                    vec_inputs,
                    vis_inputs,
                    actions=actions,
                    memories=memories,
                    sequence_length=sequence_length,
                )
            with ExitStack() as stack:
                if not q2_grad:
                    stack.enter_context(torch.no_grad())
                q2_out, _ = self.q2_network(
                    vec_inputs,
                    vis_inputs,
                    actions=actions,
                    memories=memories,
                    sequence_length=sequence_length,
                )
            return q1_out, q2_out
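
        # Usage sketch (illustrative only; `value_net`, `vec_obs`, and `cont_actions`
        # are hypothetical names, and "extrinsic" is just an example reward-signal key):
        #
        #     q1_heads, q2_heads = value_net(
        #         [vec_obs], [], actions=cont_actions, q2_grad=False
        #     )
        #     q1_for_policy = q1_heads["extrinsic"]  # gradients flow only through Q1
        #
        # Q2 is still evaluated, but inside torch.no_grad(), so no graph is built for it.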

    def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings):
        super().__init__(policy, trainer_params)
        # ... (rest of __init__ omitted: builds self.value_network and reads the SAC
        # hyperparameters from trainer_params)

    @timed
    def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
        # ... (batch unpacking omitted: defines vec_obs, vis_obs, actions, q_memories,
        # and sampled_actions drawn from the current policy)
        if self.policy.use_continuous_act:
            squeezed_actions = actions.squeeze(-1)
            # Only need grad for q1, as that is used for policy.
            q1p_out, q2p_out = self.value_network(
                vec_obs,
                vis_obs,
                sampled_actions,
                memories=q_memories,
                sequence_length=self.policy.sequence_length,
                q2_grad=False,
            )
            q1_out, q2_out = self.value_network(
                vec_obs,
                vis_obs,
                squeezed_actions,
                memories=q_memories,
                sequence_length=self.policy.sequence_length,
            )
            q1_stream, q2_stream = q1_out, q2_out
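            # Clarifying note: q1p_out/q2p_out evaluate Q at the policy's freshly
            # sampled actions (only Q1 needs a gradient here, since it drives the
            # policy update), while q1_out/q2_out evaluate Q at the buffer actions
            # (squeezed_actions) and become q1_stream/q2_stream for the Q losses.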
        else:
            # For discrete, you don't need to backprop through the Q for the policy
            q1p_out, q2p_out = self.value_network(
                vec_obs,
                vis_obs,
                memories=q_memories,
                sequence_length=self.policy.sequence_length,
                q1_grad=False,
                q2_grad=False,
            )
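            # Passing q1_grad=False and q2_grad=False makes PolicyValueNetwork.forward
            # evaluate both sub-networks inside torch.no_grad() (via its ExitStack
            # logic), so no graph is built for either Q estimate here.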
            q1_out, q2_out = self.value_network(
                vec_obs,
                vis_obs,
|