|
|
|
|
|
|
import numpy as np |
|
|
|
import h5py |
|
|
|
from typing import List, BinaryIO |
|
|
|
import itertools |
|
|
|
|
|
|
|
from mlagents_envs.exception import UnityException |
|
|
|
|
|
|
|
|
|
|
np.random.randint(num_sequences_in_buffer, size=num_seq_to_sample) |
|
|
|
* sequence_length |
|
|
|
) # Sample random sequence starts |
|
|
|
for i in start_idxes: |
|
|
|
for key in self: |
|
|
|
mini_batch[key].extend(self[key][i : i + sequence_length]) |
|
|
|
for key in self: |
|
|
|
mb_list = [self[key][i : i + sequence_length] for i in start_idxes] |
|
|
|
# See comparison of ways to make a list from a list of lists here: |
|
|
|
# https://stackoverflow.com/questions/952914/how-to-make-a-flat-list-out-of-list-of-lists |
|
|
|
mini_batch[key].set(list(itertools.chain.from_iterable(mb_list))) |
|
|
|
return mini_batch |
|
|
|
|
|
|
|
def save_to_file(self, file_object: BinaryIO) -> None: |
|
|
|