有没有办法构建一个以指定角度随机旋转的keras预处理层?

问题描述

我正在从事天文图像分类项目,目前正在使用 keras 构建 CNN。

我正在尝试构建一个预处理管道,以使用 keras/tensorflow 层来扩充我的数据集。为了简单起见,我想实现 dihedral group随机变换(即对于方形图像,90 度旋转和翻转),但似乎 tf.keras.preprocessing.image.random_rotation 只允许在连续服从均匀分布的选择范围。

我想知道是否有办法从指定的度数列表中进行选择,在我的例子中是 [0,90,180,270]。

解决方法

幸运的是,有一个 tensorflow 函数可以满足您的需求:tf.image.rot90。下一步是将该函数包装到自定义 PreprocessingLayer 中,使其随机执行。

import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.layers.experimental.preprocessing import PreprocessingLayer

class RandomRot90(PreprocessingLayer):
    def __init__(self,name=None,**kwargs) -> None:
        super(RandomRot90,self).__init__(name=name,**kwargs)
        self.input_spec = tf.keras.layers.InputSpec(ndim=4)
    
    def call(self,inputs,training=True):
        if training is None:
            training = K.learning_phase()
        
        def random_rot90():
            # random int between 0 and 3
            rot = tf.random.uniform((),4,dtype=tf.int32)
            return tf.image.rot90(inputs,k=rot)
        
        # if not training,do nothing
        outputs = tf.cond(training,random_rot90,lambda:inputs)
        outputs.set_shape(inputs.shape)
        return outputs
    
    def compute_output_shape(self,input_shape):
        return input_shape
  • 请注意,如果您希望能够使用该层保存和加载模型,您可能需要实现 get_config。 (见documentation
  • 另请注意,如果您的输入不是正方形(高度!= 宽度),则该层可能会失败。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...