问题描述
我遇到了同样的问题，并通过定义一个 __next__ 方法解决了它：
class My_Generator(Sequence):
    """Batch provider that reads image files and resizes them to 200x200.

    Batches can be fetched by index (``gen[idx]``) or one after another with
    ``next(gen)``; the latter cycles back to the first batch after the last.
    """

    def __init__(self, image_filenames, labels, batch_size):
        """Store the inputs and reset the ``__next__`` cursor.

        Args:
            image_filenames: Sliceable collection of image file paths.
            labels: Sliceable collection of labels, aligned with
                ``image_filenames`` by position.
            batch_size: Maximum number of samples per batch.
        """
        self.image_filenames, self.labels = image_filenames, labels
        self.batch_size = batch_size
        # Index of the batch the next call to __next__ will return.
        self.n = 0
        # Total number of batches; __next__ wraps around once n reaches it.
        self.max = len(self)

    def __len__(self):
        """Return the number of batches, counting a trailing partial batch."""
        # np.ceil yields a float, but len() requires __len__ to return an int
        # (a float raises TypeError), so convert explicitly.
        return int(np.ceil(len(self.image_filenames) / float(self.batch_size)))

    def __getitem__(self, idx):
        """Load batch number ``idx``.

        Returns:
            A tuple ``(images, labels)`` of NumPy arrays: every file of the
            batch read with ``imread`` and resized to ``(200, 200)``, and the
            matching slice of ``labels``.
        """
        batch_x = self.image_filenames[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.labels[idx * self.batch_size:(idx + 1) * self.batch_size]
        return np.array([
            resize(imread(file_name), (200, 200))
            for file_name in batch_x]), np.array(batch_y)

    def __next__(self):
        """Return the batch under the cursor and advance the cursor.

        Defining this method lets ``next(gen)`` work on the object. After the
        last batch the cursor is reset, so the sequence repeats indefinitely.
        """
        if self.n >= self.max:
            self.n = 0
        result = self.__getitem__(self.n)
        self.n += 1
        return result
解决方法
由于RAM内存的限制，我按照这些说明构建了一个生成器，用于抽取小批量数据并将其传递给Keras的fit_generator。但是，即使我继承了Sequence，Keras也无法使用多进程（multiprocessing）来准备队列。
这是我用于多进程的生成器。
class My_Generator(Sequence):
    """Index-addressable batch provider for image files.

    ``gen[idx]`` reads the files of batch ``idx`` from disk and resizes each
    image to 200x200.
    """

    def __init__(self, image_filenames, labels, batch_size):
        """Store the inputs.

        Args:
            image_filenames: Sliceable collection of image file paths.
            labels: Sliceable collection of labels, aligned with
                ``image_filenames`` by position.
            batch_size: Maximum number of samples per batch.
        """
        self.image_filenames, self.labels = image_filenames, labels
        self.batch_size = batch_size

    def __len__(self):
        """Return the number of batches, counting a trailing partial batch."""
        # np.ceil yields a float, but len() requires __len__ to return an int
        # (a float raises TypeError), so convert explicitly.
        return int(np.ceil(len(self.image_filenames) / float(self.batch_size)))

    def __getitem__(self, idx):
        """Load batch number ``idx``.

        Returns:
            A tuple ``(images, labels)`` of NumPy arrays: every file of the
            batch read with ``imread`` and resized to ``(200, 200)``, and the
            matching slice of ``labels``.
        """
        batch_x = self.image_filenames[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.labels[idx * self.batch_size:(idx + 1) * self.batch_size]
        return np.array([
            resize(imread(file_name), (200, 200))
            for file_name in batch_x]), np.array(batch_y)
主函数：
# Training configuration.
batch_size = 100
num_epochs = 10

# Paths of the input images and of their mask ("label") files.
# The directory walk and the generator construction below refer to
# ``training_filenames`` / ``validation_filenames``, so those names must
# exist; the original short names are kept as aliases of the same lists.
training_filenames = train_fnames = []
mask_training = []
validation_filenames = val_fnames = []
mask_validation = []
我希望生成器按ID在不同线程中分别读取文件夹中的各个批次（ID的形式为：{number}.csv 对应原始图像，{number}_label.csv 对应掩码图像）。最初，我编写了另一个更优雅的类，把所有数据存储在一个.h5文件中，而不是存放在目录里，但也卡在了同样的问题上。因此，如果您有能实现这一点的代码，我也很乐意采用。
def _collect_files(root):
    """Walk ``root`` and split the files found into images and masks.

    A file whose name contains ``'label'`` is treated as a mask; every other
    file is treated as an input image.

    Returns:
        A tuple ``(image_paths, mask_paths)`` of lists of absolute paths.
    """
    image_paths, mask_paths = [], []
    # os.walk yields (dirpath, dirnames, filenames) triples; the directory
    # names are not needed here.
    for dirpath, _, fnames in os.walk(root):
        for fname in fnames:
            path = os.path.abspath(os.path.join(dirpath, fname))
            if 'label' not in fname:
                image_paths.append(path)
            else:
                mask_paths.append(path)
    return image_paths, mask_paths


# NOTE(review): images and masks are collected independently in walk order, so
# element i of one list is not guaranteed to belong with element i of the
# other -- confirm the pairing before training.
training_filenames, mask_training = _collect_files('./train/')
validation_filenames, mask_validation = _collect_files('./validation/')

my_training_batch_generator = My_Generator(training_filenames, mask_training, batch_size)
my_validation_batch_generator = My_Generator(validation_filenames, mask_validation, batch_size)

num_training_samples = len(training_filenames)
num_validation_samples = len(validation_filenames)
模型本身不在此处讨论的范围内。我相信问题不出在模型上，所以就不把它贴出来了。
# compile() configures the model in place and returns None, so its result must
# not be used as the model; keep ``mdl`` as a name for the compiled model.
model.compile(...)
mdl = model

# Train from the Sequence-based generator with 4 worker processes.
# Validation is currently switched off; the disabled arguments are kept as
# comments on their own lines so they cannot swallow the rest of the call.
mdl.fit_generator(
    generator=my_training_batch_generator,
    steps_per_epoch=(num_training_samples // batch_size),
    epochs=num_epochs,
    verbose=1,
    validation_data=None,  # my_validation_batch_generator,
    # validation_steps=(num_validation_samples // batch_size),
    use_multiprocessing=True,
    workers=4,
    max_queue_size=2)
该错误表明我创建的类不是Iterator:
Traceback (most recent call last):
File "test.py",line 141,in <module> max_queue_size=2)
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py",line 2177,in fit_generator
initial_epoch=initial_epoch)
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_generator.py",line 147,in fit_generator
generator_output = next(output_generator)
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/keras/utils/data_utils.py",line 831,in get six.reraise(value.__class__,value,value.__traceback__)
File "/anaconda3/lib/python3.6/site-packages/six.py",line 693,in reraise
raise value
TypeError: 'My_Generator' object is not an iterator