|
|
|
|
|
|
init_values, _mem = self.critic.critic_pass( |
|
|
|
seq_obs, _mem, sequence_length=self.policy.sequence_length |
|
|
|
) |
|
|
|
# Trim out padded part |
|
|
|
# Trim out padded part, i.e. get last leftover number of elements |
|
|
|
signal_name: [init_values[signal_name][leftover:]] |
|
|
|
signal_name: [init_values[signal_name][-leftover:]] |
|
|
|
for signal_name in init_values.keys() |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
next_value_estimate[k] = 0.0 |
|
|
|
if agent_id in self.critic_memory_dict: |
|
|
|
self.critic_memory_dict.pop(agent_id) |
|
|
|
assert len(value_estimates["extrinsic"]) == batch.num_experiences |
|
|
|
return value_estimates, next_value_estimate, all_next_memories |