浏览代码

Change update buffer to float32 instead of float64 (#2461)

- Reduces memory usage of buffer.
/develop-gpu-test
GitHub 5 年前
当前提交
bf375235
共有 1 个文件被更改,包括 5 次插入3 次删除
  1. 8
      ml-agents/mlagents/trainers/buffer.py

8
ml-agents/mlagents/trainers/buffer.py


if batch_size * training_length > len(self):
padding = np.array(self[-1]) * self.padding_value
return np.array(
[padding] * (training_length - leftover) + self[:]
[padding] * (training_length - leftover) + self[:],
dtype=np.float32,
self[len(self) - batch_size * training_length :]
self[len(self) - batch_size * training_length :],
dtype=np.float32,
)
else:
# The sequences will have overlapping elements

tmp_list = []
for end in range(len(self) - batch_size + 1, len(self) + 1):
tmp_list += self[end - training_length : end]
return np.array(tmp_list)
return np.array(tmp_list, dtype=np.float32)
def reset_field(self):
"""

正在加载...
取消
保存