|
|
|
|
|
|
net_inputs: List[torch.Tensor], |
|
|
|
masks: Optional[torch.Tensor] = None, |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
critic_obs: Optional[List[List[torch.Tensor]]] = None, |
|
|
|
critic_obs: Optional[List[List[torch.Tensor]]] = None, |
|
|
|
) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]: |
|
|
|
""" |
|
|
|
Returns distributions, from which actions can be sampled, and value estimates. |
|
|
|
|
|
|
net_inputs: List[torch.Tensor], |
|
|
|
masks: Optional[torch.Tensor] = None, |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
critic_obs: Optional[List[List[torch.Tensor]]] = None, |
|
|
|
critic_obs: Optional[List[List[torch.Tensor]]] = None, |
|
|
|
) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]: |
|
|
|
encoding, memories = self.network_body( |
|
|
|
net_inputs, memories=memories, sequence_length=sequence_length |
|
|
|
|
|
|
self, |
|
|
|
net_inputs: List[torch.Tensor], |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
critic_obs: List[List[torch.Tensor]] = None, |
|
|
|
critic_obs: List[List[torch.Tensor]] = None, |
|
|
|
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: |
|
|
|
actor_mem, critic_mem = None, None |
|
|
|
if self.use_lstm: |
|
|
|
|
|
|
net_inputs: List[torch.Tensor], |
|
|
|
masks: Optional[torch.Tensor] = None, |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
critic_obs: Optional[List[List[torch.Tensor]]] = None, |
|
|
|
critic_obs: Optional[List[List[torch.Tensor]]] = None, |
|
|
|
) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]: |
|
|
|
if self.use_lstm: |
|
|
|
# Use only the back half of memories for critic and actor |
|
|
|