如何计算允许的最大batch_size?

问题描述

我有一个从数据帧创建数组的函数,在处理过程中需要更多空间。 因此,我实现了一个批量参数,该参数遍历数据框并选择batch_len = 1000 行左右的区域,对其进行预处理,并将它们连接到我的结果numpy数组中。

为了不间断地进行小批量生产,我想创建一个计算函数来确定一次允许处理多少行(batch_len)。

因此,我需要知道多少个numpy数组? 是否有(或多或少)恒定的大小?

我总是知道什么:行和列的数量

我主要使用float32或float64进行计算(但我愿意始终像其float64一样使该功能更简单)

解决方法

您需要获取数组大小(以字节为单位)。这可以通过数组的.nbytes属性来完成:

import numpy as np

x = np.zeros((10,10,10))
print('{} KB'.format(x.nbytes / 1024)) 
7.8125 KB

接下来,您需要定义批次,以便每个批次都可以放入您的RAM。请注意,某些处理函数有时可能会复制内存中的数组,因此当Python脚本使用top程序(在Linux之类的OS上)或htop执行Python脚本时,您可能想监视内存的使用。 >