
Don't pad when not needed

/develop/rear-pad
Ervin Teng, 4 years ago
Current commit
a9666a0b
1 file changed, with 9 insertions and 12 deletions
  1. 21
      ml-agents/mlagents/trainers/optimizer/torch_optimizer.py



"""
 num_experiences = tensor_obs[0].shape[0]
 all_next_memories = AgentBufferField()
-# The 1st sequence are the ones that are padded. So if seq_len = 3 and
+# In the buffer, the 1st sequence are the ones that are padded. So if seq_len = 3 and
 # trajectory is of length 10, the 1st sequence is [pad,pad,obs].
 # Compute the number of elements in this padded seq.
 leftover = num_experiences % self.policy.sequence_length

 # Compute values for the potentially truncated initial sequence
-first_seq_obs = _obs[0:leftover]
-# Pad
-padding_shape = list(_obs.shape)
-padding_shape[0] = self.policy.sequence_length - leftover
-padding = torch.zeros(padding_shape)
-padded_obs = torch.cat([padding, _obs[0:leftover]])
+first_seq_obs = _obs[0:leftover]
+first_seq_len = leftover
-padded_obs = _obs[0 : self.policy.sequence_length]
-seq_obs.append(padded_obs)
+first_seq_obs = _obs[0 : self.policy.sequence_length]
+first_seq_len = self.policy.sequence_length
+seq_obs.append(first_seq_obs)
-seq_obs, _mem, sequence_length=self.policy.sequence_length
+seq_obs, _mem, sequence_length=first_seq_len
-# Trim out padded part, i.e. get last leftover number of elements
-signal_name: [init_values[signal_name][-leftover:]]
+signal_name: [init_values[signal_name]]
 for signal_name in init_values.keys()
 }
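
Note: the change above can be read as "pass the short first sequence through with its real length instead of zero-padding it to a full sequence". The sketch below illustrates that idea on plain Python lists rather than tensors. The function names `split_first_sequence` and `padded_first_sequence` are hypothetical, and the `if leftover > 0` / `else` branching is an assumption, since those lines are not visible in the diff.

```python
from typing import List, Tuple


def split_first_sequence(
    obs: List[float], sequence_length: int
) -> Tuple[List[float], int]:
    """Return the first sequence of a trajectory and its true length.

    Assumes the buffer pads at the front: with sequence_length = 3 and a
    trajectory of 10 steps, the first buffered sequence is [pad, pad, obs],
    so only `leftover` real steps belong to it.
    """
    leftover = len(obs) % sequence_length
    if leftover > 0:
        # Short first sequence: keep it unpadded and report its real length.
        return obs[0:leftover], leftover
    # The trajectory divides evenly, so the first sequence is a full one.
    return obs[0:sequence_length], sequence_length


def padded_first_sequence(
    obs: List[float], sequence_length: int, pad: float = 0.0
) -> List[float]:
    """Behaviour before the change, for comparison: zero-pad at the front."""
    leftover = len(obs) % sequence_length
    if leftover > 0:
        return [pad] * (sequence_length - leftover) + obs[0:leftover]
    return obs[0:sequence_length]


trajectory = [float(i) for i in range(1, 11)]  # 10 steps
first_seq, first_len = split_first_sequence(trajectory, 3)
print(first_seq, first_len)                  # [1.0] 1
print(padded_first_sequence(trajectory, 3))  # [0.0, 0.0, 1.0]
```

Because the unpadded variant returns exactly `leftover` elements, the values computed from it should no longer need the `[-leftover:]` trim that the last hunk removes.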
