|
|
|
|
|
|
if self.use_lstm: |
|
|
|
# Use only the back half of memories for critic |
|
|
|
actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, -1) |
|
|
|
all_net_inputs = [net_inputs] |
|
|
|
all_net_inputs = [] |
|
|
|
if critic_obs is not None: |
|
|
|
all_net_inputs.extend(critic_obs) |
|
|
|
value_outputs, critic_mem_out = self.critic( |
|
|
|
|
|
|
net_inputs, memories=actor_mem, sequence_length=sequence_length |
|
|
|
) |
|
|
|
log_probs, entropies = self.action_model.evaluate(encoding, masks, actions) |
|
|
|
all_net_inputs = [net_inputs] |
|
|
|
all_net_inputs = [] |
|
|
|
if critic_obs is not None: |
|
|
|
all_net_inputs.extend(critic_obs) |
|
|
|
value_outputs, critic_mem_outs = self.critic( |
|
|
|
|
|
|
critic_mem = None |
|
|
|
actor_mem = None |
|
|
|
|
|
|
|
all_net_inputs = [net_inputs] |
|
|
|
all_net_inputs = [] |
|
|
|
if critic_obs is not None: |
|
|
|
all_net_inputs.extend(critic_obs) |
|
|
|
|
|
|
|