如何将Keras Multiply与tf.Variable结合使用?

问题描述

如何将tf.keras.layerstf.Variable相乘?

上下文:我正在创建一个样本依赖的卷积滤波器,它由一个通用滤波器W组成,该滤波器通过样本依赖的移位+缩放进行转换。因此,卷积原始滤波器W转换为aW + b,其中a是样本依赖的缩放比例,而b是样本依赖的移位。此方法一个应用是训练一个自动编码器,其中样本相关性是标签,因此每个标签都会移动/缩放卷积滤波器。由于样本/标签相关的卷积,我使用tf.nn.Conv2d,它将实际过滤器作为输入(而不只是过滤器的数量/大小),并使用带有tf.map_fn的lambda层来应用不同的“每个样本的“转化过滤器”(基于标签)。尽管细节不同,但是本文讨论了这种依赖样本的卷积方法Tensorflow: Convolutions with different filter for each sample in the mini-batch

这就是我的想法:

input_img = keras.Input(shape=(28,28,1))  
label = keras.Input(shape=(10,)) # number of classes

num_filters = 32
shift = layers.Dense(num_filters,activation=None,name='shift')(label) # (32,)
scale = layers.Dense(num_filters,name='scale')(label) # (32,)

# filter is of shape (filter_h,filter_w,input channels,output filters)
filter = tf.Variable(tf.ones((3,3,input_img.shape[-1],num_filters)))
# Todo: need to shift and scale -> shift*(filter) + scale along each output filter dimension (32 filter dimensions)

我不确定如何实现Todo部分。我在考虑tf.keras.layers.Multiply()用于缩放,而tf.keras.layers.Add()用于移位,但是它们似乎不适用于tf。据我所知,这是可变的。我该如何解决?假设尺寸/形状广播可行,我想做这样的事情(注意:输出仍应与var相同,并且仅沿32个输出滤波器尺寸中的每个比例缩放)

output = tf.keras.layers.Multiply()([var,scale]) 

解决方法

这需要一些工作,并且需要一个自定义图层。例如,您cannot use tf.Variable with tf.keras.Lambda

class ConvNorm(layers.Layer):
    def __init__(self,height,width,n_filters):
        super(ConvNorm,self).__init__()
        self.height = height  
        self.width = width
        self.n_filters = n_filters

    def build(self,input_shape):              
        self.filter = self.add_weight(shape=(self.height,self.width,input_shape[-1],self.n_filters),initializer='glorot_uniform',trainable=True)        
        # TODO: Add bias too


    def call(self,x,scale,shift):
        shift_reshaped = tf.expand_dims(tf.expand_dims(shift,1),1)
        scale_reshaped = tf.expand_dims(tf.expand_dims(scale,1)

        norm_conv_out = tf.nn.conv2d(x,self.filter*scale + shift,strides=(1,1,padding='SAME')
                
        return norm_conv_out

使用图层

import tensorflow as tf
import tensorflow.keras.layers as layers

input_img = layers.Input(shape=(28,28,1))  
label = layers.Input(shape=(10,)) # number of classes

num_filters = 32
shift = layers.Dense(num_filters,activation=None,name='shift')(label) # (32,)
scale = layers.Dense(num_filters,name='scale')(label) # (32,)

conv_norm_out = ConvNorm(3,3,32)(input_img,shift)
print(norm_conv_out.shape)

注意:请注意,我没有添加偏见。对于卷积层,您也将需要偏见。但这很简单。