|
|
|
|
|
|
): |
|
|
|
seq_obs = [] |
|
|
|
for _obs in tensor_obs: |
|
|
|
start = seq_num * self.policy.sequence_length - leftover |
|
|
|
end = (seq_num + 1) * self.policy.sequence_length - leftover |
|
|
|
start = seq_num * self.policy.sequence_length - ( |
|
|
|
self.policy.sequence_length - leftover |
|
|
|
) |
|
|
|
end = (seq_num + 1) * self.policy.sequence_length - ( |
|
|
|
self.policy.sequence_length - leftover |
|
|
|
) |
|
|
|
assert _obs[start:end].shape[0] == self.policy.sequence_length |
|
|
|
values, _mem = self.critic.critic_pass( |
|
|
|
seq_obs, _mem, sequence_length=self.policy.sequence_length |
|
|
|
) |
|
|
|