浏览代码

Use numpy for random sample in buffer (#2524)

/develop-gpu-test
GitHub 5 年前
当前提交
876aca1e
共有 1 个文件被更改,包括 4 次插入4 次删除
  1. 8
      ml-agents/mlagents/trainers/buffer.py

8
ml-agents/mlagents/trainers/buffer.py


mini_batch = Buffer.AgentBuffer()
buff_len = len(next(iter(self.values())))
num_sequences_in_buffer = buff_len // sequence_length
start_idxes = [
random.randint(0, num_sequences_in_buffer - 1) * sequence_length
for _ in range(num_seq_to_sample)
] # Sample random sequence starts
start_idxes = (
np.random.randint(num_sequences_in_buffer, size=num_seq_to_sample)
* sequence_length
) # Sample random sequence starts
for i in start_idxes:
for key in self:
mini_batch[key].extend(self[key][i : i + sequence_length])

正在加载...
取消
保存