浏览代码

[refactor] Optimize buffer sample_minibatch (#4508)

* Optimize buffer sample_minibatch
/MLA-1734-demo-provider
GitHub 4 年前
当前提交
215b35c6
共有 1 个文件被更改,包括 6 次插入3 次删除
  1. 9
      ml-agents/mlagents/trainers/buffer.py

9
ml-agents/mlagents/trainers/buffer.py


import numpy as np
import h5py
from typing import List, BinaryIO
import itertools
from mlagents_envs.exception import UnityException

np.random.randint(num_sequences_in_buffer, size=num_seq_to_sample)
* sequence_length
) # Sample random sequence starts
for i in start_idxes:
for key in self:
mini_batch[key].extend(self[key][i : i + sequence_length])
for key in self:
mb_list = [self[key][i : i + sequence_length] for i in start_idxes]
# See comparison of ways to make a list from a list of lists here:
# https://stackoverflow.com/questions/952914/how-to-make-a-flat-list-out-of-list-of-lists
mini_batch[key].set(list(itertools.chain.from_iterable(mb_list)))
return mini_batch
def save_to_file(self, file_object: BinaryIO) -> None:

正在加载...
取消
保存