|
|
|
|
|
|
actor_mem, critic_mem = None, None |
|
|
|
if self.use_lstm: |
|
|
|
# Use only the back half of memories for critic |
|
|
|
actor_mem, critic_mem = torch.split(memories, self.memory_size, -1) |
|
|
|
actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, -1) |
|
|
|
value_outputs, critic_mem_out = self.critic( |
|
|
|
vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length |
|
|
|
) |
|
|
|
|
|
|
) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]: |
|
|
|
if self.use_lstm: |
|
|
|
# Use only the back half of memories for critic and actor |
|
|
|
actor_mem, critic_mem = torch.split(memories, self.memory_size, dim=-1) |
|
|
|
actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, dim=-1) |
|
|
|
else: |
|
|
|
critic_mem = None |
|
|
|
actor_mem = None |
|
|
|