# Imports assumed by this excerpt (module paths as used elsewhere in this codebase).
from contextlib import ExitStack
from typing import Dict, List, Optional, Tuple

import torch

from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.torch.utils import ModelUtils
def forward(
    self,
    vec_inputs: List[torch.Tensor],
    vis_inputs: List[torch.Tensor],
    net_inputs: List[torch.Tensor],
    actions: Optional[torch.Tensor] = None,
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
    q1_grad: bool = True,
    q2_grad: bool = True,
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
    """
    Performs a forward pass on the value network, which consists of a Q1 and Q2
    network. Optionally does not evaluate gradients for either the Q1, Q2, or both.
    :param vec_inputs: List of vector observation tensors.
    :param vis_inputs: List of visual observation tensors.
    :param net_inputs: List of observation tensors.
    :param actions: For a continuous Q function (has actions), tensor of actions.
        Otherwise, None.
    :param memories: Initial memories if using memory. Otherwise, None.
    :param sequence_length: Sequence length if using memory.
    :param q1_grad: Whether or not to compute gradients for the Q1 network.
    :param q2_grad: Whether or not to compute gradients for the Q2 network.
    :return: Tuple of two dictionaries mapping reward signal names to the Q1 and
        Q2 outputs, respectively.
    """
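    # contextlib.ExitStack lets us enter torch.no_grad() only when gradients are
    # disabled for a sub-network, instead of duplicating each Q call across an
    # if/else. When the flag is left True, the same code runs with gradients on.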
    with ExitStack() as stack:
        if not q1_grad:
            stack.enter_context(torch.no_grad())
        q1_out, _ = self.q1_network(
            vec_inputs,
            vis_inputs,
            net_inputs,
            actions=actions,
            memories=memories,
            sequence_length=sequence_length,
        )
    with ExitStack() as stack:
        if not q2_grad:
            stack.enter_context(torch.no_grad())
        q2_out, _ = self.q2_network(
            vec_inputs,
            vis_inputs,
            net_inputs,
            actions=actions,
            memories=memories,
            sequence_length=sequence_length,
        )
    return q1_out, q2_out


# (In the full source, forward above belongs to the value-network module that
# owns q1_network/q2_network, while update below is a method of the SAC
# optimizer that owns value_network, target_network, and the policy.)
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
    """
    Updates model using buffer.
    :param batch: Experience mini-batch.
    :param num_sequences: Number of trajectories in batch.
    :return: Output from update process.
    """
    rewards: Dict[str, torch.Tensor] = {}
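    # Each reward signal (e.g. extrinsic plus any curiosity/GAIL signals) stores
    # its rewards under its own buffer key; the Q and value networks likewise
    # produce one output stream per signal.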
    for name in self.reward_signals:
        rewards[name] = ModelUtils.list_to_tensor(batch[f"{name}_rewards"])
vec_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])] |
|
|
|
next_vec_obs = [ModelUtils.list_to_tensor(batch["next_vector_in"])] |
|
|
|
obs = ModelUtils.list_to_tensor_list( |
|
|
|
AgentBuffer.obs_list_to_obs_batch(batch["obs"]) |
|
|
|
) |
|
|
|
next_obs = ModelUtils.list_to_tensor_list( |
|
|
|
AgentBuffer.obs_list_to_obs_batch(batch["next_obs"]) |
|
|
|
) |
|
|
|
act_masks = ModelUtils.list_to_tensor(batch["action_mask"]) |
|
|
|
if self.policy.use_continuous_act: |
|
|
|
actions = ModelUtils.list_to_tensor(batch["actions"]).unsqueeze(-1) |
|
|
|
|
|
|
torch.zeros_like(next_memories) if next_memories is not None else None |
|
|
|
) |
|
|
|
|
|
|
|
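    # Visual observations arrive as one buffer entry per camera ("visual_obs0",
    # "visual_obs1", ...), one for each of the policy's visual processors.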
    vis_obs: List[torch.Tensor] = []
    next_vis_obs: List[torch.Tensor] = []
    if self.policy.use_vis_obs:
        for idx, _ in enumerate(
            self.policy.actor_critic.network_body.visual_processors
        ):
            vis_ob = ModelUtils.list_to_tensor(batch["visual_obs%d" % idx])
            vis_obs.append(vis_ob)
            next_vis_ob = ModelUtils.list_to_tensor(
                batch["next_visual_obs%d" % idx]
            )
            next_vis_obs.append(next_vis_ob)
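    # The policy normalizes vector observations online during training, so the
    # Q networks and the target network must mirror its normalizer statistics
    # before each update to see identically-scaled inputs.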
    # Copy normalizers from policy
    self.value_network.q1_network.network_body.copy_normalization(
        self.policy.actor_critic.network_body
    )
    self.value_network.q2_network.network_body.copy_normalization(
        self.policy.actor_critic.network_body
    )
    self.target_network.network_body.copy_normalization(
        self.policy.actor_critic.network_body
    )
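    # Sample fresh actions (and their log-probs) from the current policy; SAC's
    # policy and entropy losses are computed on these rather than on the actions
    # stored in the buffer.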
    (sampled_actions, _, log_probs, _, _) = self.policy.sample_actions(
        vec_obs,
        vis_obs,
        obs,
        masks=act_masks,
        memories=memories,
        seq_len=self.policy.sequence_length,
        all_log_probs=not self.policy.use_continuous_act,
    )
    value_estimates, _ = self.policy.actor_critic.critic_pass(
        vec_obs, vis_obs, obs, memories, sequence_length=self.policy.sequence_length
    )
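    # Continuous control evaluates the Q networks on action tensors directly;
    # discrete control evaluates them over all actions and condenses per-branch
    # values afterwards, so the two cases use different Q-network calls below.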
    if self.policy.use_continuous_act:
        squeezed_actions = actions.squeeze(-1)
        # Q estimates for the freshly-sampled actions, used by the policy loss.
        q1p_out, q2p_out = self.value_network(
            vec_obs,
            vis_obs,
            obs,
            sampled_actions,
            memories=q_memories,
            sequence_length=self.policy.sequence_length,
        )
        # Q estimates for the buffer actions, used by the Q (critic) loss.
        q1_out, q2_out = self.value_network(
            vec_obs,
            vis_obs,
            obs,
            squeezed_actions,
            memories=q_memories,
            sequence_length=self.policy.sequence_length,
        )
        q1_stream, q2_stream = q1_out, q2_out
    else:
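        # No actions are passed here: the discrete Q networks output a value for
        # every action, and the policy loss weights those values by the policy's
        # action probabilities, so no gradient has to flow through Q1/Q2.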
        # For discrete, you don't need to backprop through the Q for the policy
        q1p_out, q2p_out = self.value_network(
            vec_obs,
            vis_obs,
            obs,
            memories=q_memories,
            sequence_length=self.policy.sequence_length,
            q1_grad=False,
            q2_grad=False,
        )
        q1_out, q2_out = self.value_network(
            vec_obs,
            vis_obs,
            obs,
            memories=q_memories,
            sequence_length=self.policy.sequence_length,
        )
        # Condense the per-action Q outputs onto the actions actually taken.
        q1_stream = self._condense_q_streams(q1_out, actions)
        q2_stream = self._condense_q_streams(q2_out, actions)
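    # The TD target is built from the target value network on the *next*
    # observations, with gradients disabled: only the online networks are
    # trained, and the target network tracks them (e.g. via soft updates).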
    with torch.no_grad():
        target_values, _ = self.target_network(
            next_vec_obs,
            next_vis_obs,
            next_obs,
            memories=next_memories,
            sequence_length=self.policy.sequence_length,
        )