
Fix padding issues

/develop/critic-op-lstm-currentmem
Ervin Teng · 4 years ago
Commit 21e9785a
1 changed file with 4 additions and 4 deletions
  1. ml-agents/mlagents/trainers/optimizer/torch_optimizer.py  (+4 -4)

ml-agents/mlagents/trainers/optimizer/torch_optimizer.py  (+4 -4)

 for _obs in tensor_obs:
     if leftover > 0:
         # Pad
-        # _obs will always be bigger than leftover
-        padding = torch.zeros_like(
-            _obs[0 : self.policy.sequence_length - leftover]
-        )
+        padding_shape = list(_obs.shape)
+        padding_shape[0] = self.policy.sequence_length - leftover
+        padding = torch.zeros(padding_shape)
         padded_obs = torch.cat([padding, _obs[0:leftover]])
     else:
         padded_obs = _obs[0 : self.policy.sequence_length]

     start = seq_num * self.policy.sequence_length - leftover
     end = (seq_num + 1) * self.policy.sequence_length - leftover
     seq_obs.append(_obs[start:end])
+    assert _obs[start:end].shape[0] == self.policy.sequence_length
 values, _mem = self.critic.critic_pass(
     seq_obs, _mem, sequence_length=self.policy.sequence_length
 )
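
The likely failure mode behind "Fix padding issues": torch.zeros_like(_obs[0 : self.policy.sequence_length - leftover]) can never be longer than _obs itself, so when the trajectory holds fewer than sequence_length - leftover experiences the padding comes out too short and the padded sequence no longer has sequence_length steps. Building the padding shape explicitly from _obs.shape avoids depending on the trajectory length. A minimal sketch of the difference, with made-up tensor sizes and assuming leftover is the count of experiences left over after the full sequences (as in the surrounding optimizer code):

import torch

# Hypothetical sizes: sequence length 8, a short trajectory of 3 experiences,
# each a 5-dimensional vector observation.
sequence_length = 8
num_experiences = 3
leftover = num_experiences % sequence_length  # == 3
_obs = torch.ones(num_experiences, 5)

# Old approach: zeros_like of a slice. The slice is capped at len(_obs) == 3,
# so the padding has 3 rows instead of the 5 (= 8 - 3) rows actually needed.
old_padding = torch.zeros_like(_obs[0 : sequence_length - leftover])
print(old_padding.shape)                                 # torch.Size([3, 5])
print(torch.cat([old_padding, _obs[0:leftover]]).shape)  # torch.Size([6, 5]), not 8

# New approach: build the padding shape explicitly, independent of len(_obs).
padding_shape = list(_obs.shape)
padding_shape[0] = sequence_length - leftover
padding = torch.zeros(padding_shape)
padded_obs = torch.cat([padding, _obs[0:leftover]])
print(padded_obs.shape)                                  # torch.Size([8, 5]) == sequence_length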
