|
|
|
|
|
|
for _obs in tensor_obs: |
|
|
|
if leftover > 0: |
|
|
|
# Pad |
|
|
|
# _obs will always be bigger than leftover |
|
|
|
padding = torch.zeros_like( |
|
|
|
_obs[0 : self.policy.sequence_length - leftover] |
|
|
|
) |
|
|
|
padding_shape = list(_obs.shape) |
|
|
|
padding_shape[0] = self.policy.sequence_length - leftover |
|
|
|
padding = torch.zeros(padding_shape) |
|
|
|
padded_obs = torch.cat([padding, _obs[0:leftover]]) |
|
|
|
else: |
|
|
|
padded_obs = _obs[0 : self.policy.sequence_length] |
|
|
|
|
|
|
start = seq_num * self.policy.sequence_length - leftover |
|
|
|
end = (seq_num + 1) * self.policy.sequence_length - leftover |
|
|
|
seq_obs.append(_obs[start:end]) |
|
|
|
assert _obs[start:end].shape[0] == self.policy.sequence_length |
|
|
|
values, _mem = self.critic.critic_pass( |
|
|
|
seq_obs, _mem, sequence_length=self.policy.sequence_length |
|
|
|
) |
|
|
|