问题描述
我目前正在使用生成器来使用tf.data.Dataset.from_generator
来生成训练和验证数据集。我有一个可以为我解决这个问题的类方法:
def build_dataset(self,batch_size=16,shuffle=16,validation=None):
train_dataset = tf.data.Dataset.from_generator(import_images(validation=validation),(tf.float32,tf.float32))
self.train_dataset = train_dataset.shuffle(shuffle).repeat(-1).batch(batch_size).prefetch(1)
if validation is not None:
val_dataset = tf.data.Dataset.from_generator(import_images(validation=validation),tf.float32))
self.val_dataset = val_dataset.repeat(1).batch(batch_size).prefetch(1)
问题是将(validation=validation)
传递给我的import_images
生成器创建了Tensorflow不需要的生成器对象,它给了我错误:
TypeError: `generator` must be callable.
因为我必须传递validation
来告诉生成器生成单独的培训和验证版本,所以我需要创建同一生成器的两个版本。它还不允许我传入其他参数来控制训练和验证示例的百分比-这意味着生成器必须是静态的。有什么建议吗?
解决方法
我最近遇到了类似的问题,但是我是初学者,所以不确定是否有帮助。
尝试在您的课程中添加一个通话函数。
下面是引发TypeError: `generator` must be callable.
class DataGen:
def __init__(self,files,data_path):
self.i = 0
self.files=files
self.data_path=data_path
def __load__(self,files_name):
data_path = os.path.join(self.data_path,files_name)
arr_img,arr_mask = load_patch(data_path)
return arr_img,arr_mask
def getitem(self,index):
_img,_mask = self.__load__(self.files[index])
return _img,_mask
def __iter__(self):
return self
def __next__(self):
if self.i < len(self.files):
img_arr,mask_arr = self.getitem(self.i)
self.i += 1
else:
raise StopIteration()
return img_arr,mask_arr
然后我修改了以下代码,它对我有用。
class DataGen:
def __init__(self,mask_arr
def __call__(self):
self.i = 0
return self