加载张量流图像并创建补丁

问题描述

我正在使用image_dataset_from_directory将非常大的RGB图像数据集从磁盘加载到Dataset中。例如,

dataset = tf.keras.preprocessing.image_dataset_from_directory(
    <directory>,label_mode=None,seed=1,subset='training',validation_split=0.1)

例如,数据集已将100000张图像分组为32个大小的批次,从而产生规格为tf.data.Dataset的{​​{1}}

我想从图像中提取补丁以创建图像空间尺寸为例如64x64的新(batch=32,width=256,height=256,channels=3)

因此,我想创建一个新数据集,其中仍旧包含32个批次的400000补丁,其中tf.data.Dataset的规格为tf.data.Dataset

我已经研究了window方法extract_patches函数,但是从文档中尚不清楚如何使用它们创建新的数据集,我需要开始对补丁进行培训。 (batch=32,width=64,height=64,channels=3)似乎适用于一维张量,window似乎适用于数组而不是数据集。

关于如何实现此目标的任何建议?

更新:

只是为了澄清我的需求。我试图避免在磁盘上手动创建补丁。一,这在磁盘上是站不住脚的。二,补丁大小不固定。实验将在几种补丁大小上进行。因此,我不想在磁盘上手动执行补丁创建或在内存中手动加载映像并执行补丁。我宁愿让tensorflow作为流水线工作流的一部分来处理补丁创建,以最大程度地减少磁盘和内存使用。

解决方法

我相信您可以使用python类生成器。您可以根据需要将此生成器传递给model.fit函数。我实际上曾经将其用于标签预处理。

我编写了以下数据集生成器,该数据集生成器从您的数据集中加载了一个批次,并根据tile_shape参数将该批次中的图像拆分为多个图像。如果有足够的图像,则返回下一批。

在示例中,我使用了一个简单的from_tensor_slices数据集进行简化。您当然可以用您的替换它。

import tensorflow as tf

class TileDatasetGenerator:
    
    def __init__(self,dataset,batch_size,tile_shape):
        self.dataset_iterator = iter(dataset)
        self.batch_size = batch_size
        self.tile_shape = tile_shape
        self.image_queue = None
    
    def __iter__(self):
        return self
    
    def __next__(self):
        if self._has_queued_enough_for_batch():
            return self._dequeue_batch()
        
        batch = next(self.dataset_iterator)
        self._split_images(batch)    
        return self.__next__()
            
    def _has_queued_enough_for_batch(self):
        return self.image_queue is not None and tf.shape(self.image_queue)[0] >= self.batch_size
    
    def _dequeue_batch(self):
        batch,remainder = tf.split(self.image_queue,[self.batch_size,-1],axis=0)
        self.image_queue = remainder
        return batch
        
    def _split_images(self,batch):
        batch_shape = tf.shape(batch)
        batch_splitted = tf.reshape(batch,shape=[-1,self.tile_shape[0],self.tile_shape[1],batch_shape[-1]])
        if self.image_queue is None:
            self.image_queue = batch_splitted
        else:
            self.image_queue = tf.concat([self.image_queue,batch_splitted],axis=0)
            


dataset = tf.data.Dataset.from_tensor_slices(tf.ones(shape=[128,64,3]))
dataset.batch(32)
generator = TileDatasetGenerator(dataset,batch_size = 16,tile_shape = [32,32])

for batch in generator:
    tf.print(tf.shape(batch))

修改: 如果需要,可以将生成器转换为tf.data.Dataset,但是它要求您向生成器添加__call__函数,以返回迭代器(在这种情况下为self)。

new_dataset = tf.data.Dataset.from_generator(generator,output_types=(tf.int64))
,

您要寻找的是tf.image.extract_patches。这是一个示例:

import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np

data = tfds.load('mnist',split='test',as_supervised=True)

get_patches = lambda x,y: (tf.reshape(
    tf.image.extract_patches(
        images=tf.expand_dims(x,0),sizes=[1,14,1],strides=[1,rates=[1,1,padding='VALID'),(4,1)),y)

data = data.map(get_patches)

fig = plt.figure()
plt.subplots_adjust(wspace=.1,hspace=.2)
images,labels = next(iter(data))
for index,image in enumerate(images):
    ax = plt.subplot(2,2,index + 1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(image)
plt.show()

enter image description here