Dataset.from_generator:TypeError:`generator`必须是可调用的

问题描述

我目前正在使用生成器来使用tf.data.Dataset.from_generator来生成训练和验证数据集。我有一个可以为我解决这个问题的类方法:

def build_dataset(self,batch_size=16,shuffle=16,validation=None):
    
    train_dataset = tf.data.Dataset.from_generator(import_images(validation=validation),(tf.float32,tf.float32))
    self.train_dataset = train_dataset.shuffle(shuffle).repeat(-1).batch(batch_size).prefetch(1)
    
    if validation is not None:
        val_dataset = tf.data.Dataset.from_generator(import_images(validation=validation),tf.float32))
        self.val_dataset = val_dataset.repeat(1).batch(batch_size).prefetch(1)

问题是将(validation=validation)传递给我的import_images生成器创建了Tensorflow不需要的生成器对象,它给了我错误:

TypeError: `generator` must be callable.

因为我必须传递validation来告诉生成器生成单独的培训和验证版本,所以我需要创建同一生成器的两个版本。它还不允许我传入其他参数来控制训练和验证示例的百分比-这意味着生成器必须是静态的。有什么建议吗?

解决方法

我最近遇到了类似的问题,但是我是初学者,所以不确定是否有帮助。

尝试在您的课程中添加一个通话函数。

下面是引发TypeError: `generator` must be callable.

的原始课程
class DataGen:
  def __init__(self,files,data_path):
    self.i = 0
    self.files=files
    self.data_path=data_path
  
  def __load__(self,files_name):
    data_path = os.path.join(self.data_path,files_name)
    arr_img,arr_mask = load_patch(data_path)
    return arr_img,arr_mask

  def getitem(self,index):
    _img,_mask = self.__load__(self.files[index])
    return _img,_mask

  def __iter__(self):
    return self

  def __next__(self):
    if self.i < len(self.files):
      img_arr,mask_arr = self.getitem(self.i)
      self.i += 1
    else:
      raise StopIteration()
    return img_arr,mask_arr

然后我修改了以下代码,它对我有用。

class DataGen:
  def __init__(self,mask_arr
  
  def __call__(self):
    self.i = 0
    return self

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...