在 keras 中实现自定义 GRU 方程

问题描述

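GRU Equations(原图缺失。根据下文 TF1 代码推断,应为如下 3D 卷积 GRU:$\mathcal{T}_\ast(x_t)$ 表示特征经全连接层后重排成的体素网格,$\ast$ 表示 3D 卷积,$\odot$ 表示逐元素乘;BatchNorm 与中间激活的细节已省略):

$$
\begin{aligned}
u_t &= \sigma\big(\mathcal{T}_u(x_t) + U_u \ast h_{t-1} + b_u\big)\\
r_t &= \sigma\big(\mathcal{T}_r(x_t) + U_r \ast h_{t-1} + b_r\big)\\
h_t &= (1 - u_t)\odot h_{t-1} + u_t \odot \tanh\big(\mathcal{T}_h(x_t) + U_h \ast (r_t \odot h_{t-1}) + b_h\big)
\end{aligned}
$$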

我尝试在自定义的 Keras GRU 单元中实现上述方程。我在网上查找过相关指南,但没有找到有用的资料。这是我第一次基于 Keras 现有的 GRU 实现自定义 GRU。以下是我的尝试:

class CGRUCell(Layer):
    """Custom Gated Recurrent Unit cell, adapted from a Keras 2.x ``GRUCell``.

    For an input ``x`` and previous state ``h_tm1`` one step computes::

        z  = recurrent_activation(x W_z + h_tm1 U_z + b_z)        # update gate
        r  = recurrent_activation(x W_r + h_tm1 U_r + b_r)        # reset gate
        hh = activation(x W_h + (r * h_tm1) U_h + b_h)            # reset_after=False
        hh = activation(x W_h + r * (h_tm1 U_h + b_h_rec))        # reset_after=True
        h  = (1 - z) * h_tm1 + z * hh

    i.e. the update gate weights the *candidate* state and ``1 - z`` weights
    the previous state.  The cell returns ``h`` both as output and new state,
    so it is meant to be wrapped in an ``RNN`` layer.

    Args:
        units: Size of the hidden state and of the output.
        activation: Activation for the candidate state ``hh``.
        recurrent_activation: Activation for the ``z`` and ``r`` gates.
        use_bias: Whether to add bias vectors.
        kernel_initializer, recurrent_initializer, bias_initializer,
        kernel_regularizer, recurrent_regularizer, bias_regularizer,
        kernel_constraint, recurrent_constraint, bias_constraint:
            Passed to ``add_weight`` for the input kernel, the recurrent
            kernel and the bias respectively.
        dropout: Fraction of input units to drop; clipped to ``[0, 1]``.
        recurrent_dropout: Fraction of recurrent units to drop; clipped to
            ``[0, 1]``.
        implementation: ``1`` computes one matmul per gate; any other value
            computes fused matmuls over all gates at once.
        reset_after: Apply the reset gate after the recurrent matmul, using
            separate input and recurrent biases.
    """

    def __init__(self, units, activation='tanh', recurrent_activation='sigmoid',
                 use_bias=True, kernel_initializer='glorot_uniform',
                 recurrent_initializer='orthogonal', bias_initializer='zeros',
                 kernel_regularizer=None, recurrent_regularizer=None,
                 bias_regularizer=None, kernel_constraint=None,
                 recurrent_constraint=None, bias_constraint=None, dropout=0.,
                 recurrent_dropout=0., implementation=2, reset_after=False,
                 **kwargs):
        super(CGRUCell, self).__init__(**kwargs)
        self.units = units
        self.activation = activations.get(activation)
        self.recurrent_activation = activations.get(recurrent_activation)
        self.use_bias = use_bias

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.bias_initializer = initializers.get(bias_initializer)

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

        # Both rates are clipped to [0, 1] the same way.
        self.dropout = min(1., max(0., dropout))
        self.recurrent_dropout = min(1., max(0., recurrent_dropout))
        self.implementation = implementation
        self.reset_after = reset_after
        self.state_size = self.units
        self.output_size = self.units
        # Dropout masks are created lazily on the first call() and cached.
        # NOTE(review): nothing in this class resets them to None; unless the
        # wrapping layer does so per batch, the same masks are reused for every
        # batch — confirm for your Keras version / RNN wrapper.
        self._dropout_mask = None
        self._recurrent_dropout_mask = None

    def build(self, input_shape):
        """Create the weights and slice them per gate (z, r, h order).

        Each of ``kernel`` (input_dim, 3*units) and ``recurrent_kernel``
        (units, 3*units) holds the three gates side by side.
        """
        input_dim = input_shape[-1]

        if isinstance(self.recurrent_initializer, initializers.Identity):
            # Tile an identity matrix across the three gate blocks of the
            # (units, 3 * units) recurrent kernel.
            def recurrent_identity(shape, gain=1., dtype=None):
                del dtype
                return gain * np.concatenate(
                    [np.identity(shape[0])] * (shape[1] // shape[0]), axis=1)

            self.recurrent_initializer = recurrent_identity

        self.kernel = self.add_weight(
            shape=(input_dim, self.units * 3),
            name='kernel',
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint)
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units * 3),
            name='recurrent_kernel',
            initializer=self.recurrent_initializer,
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint)

        if self.use_bias:
            if not self.reset_after:
                bias_shape = (3 * self.units,)
            else:
                # Separate biases for input and recurrent kernels.
                # The shape is intentionally different from CuDNNGRU biases
                # `(2 * 3 * self.units,)` so that the classes can be told apart
                # when loading and converting saved weights.
                bias_shape = (2, 3 * self.units)
            self.bias = self.add_weight(
                shape=bias_shape,
                name='bias',
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint)
            if not self.reset_after:
                self.input_bias, self.recurrent_bias = self.bias, None
            else:
                # Flatten because slicing in CNTK yields a 2D array.
                self.input_bias = K.flatten(self.bias[0])
                self.recurrent_bias = K.flatten(self.bias[1])
        else:
            self.bias = None

        # update gate
        self.kernel_z = self.kernel[:, :self.units]
        self.recurrent_kernel_z = self.recurrent_kernel[:, :self.units]
        # reset gate
        self.kernel_r = self.kernel[:, self.units: self.units * 2]
        self.recurrent_kernel_r = self.recurrent_kernel[:, self.units:
                                                        self.units * 2]
        # candidate ("new") gate
        self.kernel_h = self.kernel[:, self.units * 2:]
        self.recurrent_kernel_h = self.recurrent_kernel[:, self.units * 2:]

        if self.use_bias:
            # bias for inputs
            self.input_bias_z = self.input_bias[:self.units]
            self.input_bias_r = self.input_bias[self.units: self.units * 2]
            self.input_bias_h = self.input_bias[self.units * 2:]
            # bias for hidden state - only present when reset_after=True
            if self.reset_after:
                self.recurrent_bias_z = self.recurrent_bias[:self.units]
                self.recurrent_bias_r = (
                    self.recurrent_bias[self.units: self.units * 2])
                self.recurrent_bias_h = self.recurrent_bias[self.units * 2:]
        else:
            self.input_bias_z = None
            self.input_bias_r = None
            self.input_bias_h = None
            if self.reset_after:
                self.recurrent_bias_z = None
                self.recurrent_bias_r = None
                self.recurrent_bias_h = None
        self.built = True

    def call(self, inputs, states, training=None):
        """Run one time step.

        Args:
            inputs: Input for this step; last axis has size ``input_dim``.
            states: List whose first element is the previous state ``h_tm1``.
            training: Forwarded to the dropout-mask generator.

        Returns:
            ``(h, [h])`` — the new state is both the output and next state.
        """
        h_tm1 = states[0]  # previous memory

        if 0 < self.dropout < 1 and self._dropout_mask is None:
            self._dropout_mask = _generate_dropout_mask(
                K.ones_like(inputs), self.dropout, training=training, count=3)
        if (0 < self.recurrent_dropout < 1 and
                self._recurrent_dropout_mask is None):
            # Forward `training` exactly as for the input mask above.
            self._recurrent_dropout_mask = _generate_dropout_mask(
                K.ones_like(h_tm1), self.recurrent_dropout, training=training,
                count=3)

        # dropout matrices for input units
        dp_mask = self._dropout_mask
        # dropout matrices for recurrent units
        rec_dp_mask = self._recurrent_dropout_mask

        if self.implementation == 1:
            if 0. < self.dropout < 1.:
                inputs_z = inputs * dp_mask[0]
                inputs_r = inputs * dp_mask[1]
                inputs_h = inputs * dp_mask[2]
            else:
                inputs_z = inputs
                inputs_r = inputs
                inputs_h = inputs

            if 0. < self.recurrent_dropout < 1.:
                h_tm1_z = h_tm1 * rec_dp_mask[0]
                h_tm1_r = h_tm1 * rec_dp_mask[1]
                h_tm1_h = h_tm1 * rec_dp_mask[2]
            else:
                h_tm1_z = h_tm1
                h_tm1_r = h_tm1
                h_tm1_h = h_tm1

            # Input projections: inputs x input kernel.
            x_z = K.dot(inputs_z, self.kernel_z)
            x_r = K.dot(inputs_r, self.kernel_r)
            x_h = K.dot(inputs_h, self.kernel_h)
            if self.use_bias:
                x_z = K.bias_add(x_z, self.input_bias_z)
                x_r = K.bias_add(x_r, self.input_bias_r)
                x_h = K.bias_add(x_h, self.input_bias_h)

            # Recurrent projections: previous state x recurrent kernel.
            recurrent_z = K.dot(h_tm1_z, self.recurrent_kernel_z)
            recurrent_r = K.dot(h_tm1_r, self.recurrent_kernel_r)
            if self.reset_after and self.use_bias:
                recurrent_z = K.bias_add(recurrent_z, self.recurrent_bias_z)
                recurrent_r = K.bias_add(recurrent_r, self.recurrent_bias_r)
            z = self.recurrent_activation(x_z + recurrent_z)
            r = self.recurrent_activation(x_r + recurrent_r)

            # reset gate applied after/before matrix multiplication
            if self.reset_after:
                recurrent_h = K.dot(h_tm1_h, self.recurrent_kernel_h)
                if self.use_bias:
                    recurrent_h = K.bias_add(recurrent_h, self.recurrent_bias_h)
                recurrent_h = r * recurrent_h
            else:
                recurrent_h = K.dot(r * h_tm1_h, self.recurrent_kernel_h)

            hh = self.activation(x_h + recurrent_h)
        else:
            if 0. < self.dropout < 1.:
                inputs = inputs * dp_mask[0]

            # inputs projected by all gate matrices at once
            matrix_x = K.dot(inputs, self.kernel)
            if self.use_bias:
                # biases: bias_z_i, bias_r_i, bias_h_i
                matrix_x = K.bias_add(matrix_x, self.input_bias)
            x_z = matrix_x[:, :self.units]
            x_r = matrix_x[:, self.units: 2 * self.units]
            x_h = matrix_x[:, 2 * self.units:]

            # Mask a copy only: the carried state h_tm1 used in the final mix
            # below must stay unmasked, matching implementation 1.
            if 0. < self.recurrent_dropout < 1.:
                h_tm1_masked = h_tm1 * rec_dp_mask[0]
            else:
                h_tm1_masked = h_tm1

            if self.reset_after:
                # hidden state projected by all gate matrices at once
                matrix_inner = K.dot(h_tm1_masked, self.recurrent_kernel)
                if self.use_bias:
                    matrix_inner = K.bias_add(matrix_inner, self.recurrent_bias)
            else:
                # hidden state projected separately for update/reset and new
                matrix_inner = K.dot(h_tm1_masked,
                                     self.recurrent_kernel[:, :2 * self.units])

            # NOTE: this is where the custom gate equations are meant to go
            # (the original author marked "Changes Expected Here").
            recurrent_z = matrix_inner[:, :self.units]
            recurrent_r = matrix_inner[:, self.units: 2 * self.units]

            z = self.recurrent_activation(x_z + recurrent_z)
            r = self.recurrent_activation(x_r + recurrent_r)

            if self.reset_after:
                recurrent_h = r * matrix_inner[:, 2 * self.units:]
            else:
                recurrent_h = K.dot(r * h_tm1_masked,
                                    self.recurrent_kernel[:, 2 * self.units:])

            hh = self.activation(x_h + recurrent_h)

        # previous and candidate state mixed by update gate
        h = (1 - z) * h_tm1 + z * hh
        if 0 < self.dropout + self.recurrent_dropout:
            if training is None:
                h._uses_learning_phase = True
        return h, [h]

    def get_config(self):
        """Return the constructor arguments needed to re-create this cell."""
        config = {
            'units': self.units,
            'activation': activations.serialize(self.activation),
            'recurrent_activation':
                activations.serialize(self.recurrent_activation),
            'use_bias': self.use_bias,
            'kernel_initializer':
                initializers.serialize(self.kernel_initializer),
            'recurrent_initializer':
                initializers.serialize(self.recurrent_initializer),
            'bias_initializer': initializers.serialize(self.bias_initializer),
            'kernel_regularizer':
                regularizers.serialize(self.kernel_regularizer),
            'recurrent_regularizer':
                regularizers.serialize(self.recurrent_regularizer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'kernel_constraint': constraints.serialize(self.kernel_constraint),
            'recurrent_constraint':
                constraints.serialize(self.recurrent_constraint),
            'bias_constraint': constraints.serialize(self.bias_constraint),
            'dropout': self.dropout,
            'recurrent_dropout': self.recurrent_dropout,
            'implementation': self.implementation,
            'reset_after': self.reset_after,
        }
        base_config = super(CGRUCell, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

请帮忙确认这样实现是否正确。

下面是 TF1 中的类似代码:

from keras.layers import (Activation, Add, BatchNormalization, Conv3D, Dense,
                          Lambda, LeakyReLU, Multiply, ReLU, Reshape, RNN)

def fcconv3d_layer(h_t, feature_x, filters, n_gru_vox):
    """Combine a feature vector with a 3D hidden state (one gate pre-activation).

    Two branches are summed:
      * ``feature_x`` -> Dense(n_gru_vox**3 * filters) -> ReLU, reshaped onto an
        ``n_gru_vox``^3 grid with ``filters`` channels;
      * ``h_t`` -> Conv3D(filters, 3, same padding) -> BatchNormalization ->
        LeakyReLU.

    Args:
        h_t: Hidden-state tensor. Assumed (batch, n_gru_vox, n_gru_vox,
            n_gru_vox, channels) so the two branches line up in ``Add`` —
            TODO confirm against the caller.
        feature_x: Feature tensor for this step, fed to the Dense projection.
        filters: Number of output channels of both branches.
        n_gru_vox: Side length of the cubic voxel grid.

    Returns:
        Tensor of shape (batch, n_gru_vox, n_gru_vox, n_gru_vox, filters).
    """
    grid_shape = (n_gru_vox, n_gru_vox, n_gru_vox, filters)

    fc_output = Dense(n_gru_vox * n_gru_vox * n_gru_vox * filters)(feature_x)
    fc_output = ReLU()(fc_output)
    # A Reshape layer keeps the batch axis symbolic. The previous
    # tf.reshape(x, h_t.shape) passed a shape whose batch dimension is
    # normally unknown (None), which tf.reshape cannot accept.
    fc_output = Reshape(grid_shape)(fc_output)

    conv_output = Conv3D(filters=filters, kernel_size=3, padding='same')(h_t)
    conv_output = BatchNormalization()(conv_output)
    conv_output = LeakyReLU()(conv_output)

    return Add()([fc_output, conv_output])

def recurrence(h_t, feature_x, filters, index=None, n_gru_vox=None):
    """Build one step of the 3D convolutional GRU from Keras layers.

    Computes::

        u_t = sigmoid(F_u(h_t, x))
        r_t = sigmoid(F_r(h_t, x))
        h'  = (1 - u_t) * h_t + u_t * tanh(F_h(r_t * h_t, x))

    where each ``F_*`` is an independent ``fcconv3d_layer`` (own weights).
    Every operation is a Keras layer, so the result can be used directly in
    a functional model without wrapping it in a ``Lambda``.

    Args:
        h_t: Previous hidden state. Assumed (batch, n, n, n, filters) —
            TODO confirm.
        feature_x: Feature tensor for this step.
        filters: Channel count of the hidden state.
        index: Step index; unused, accepted for call-site compatibility.
        n_gru_vox: Grid side length. Inferred from ``h_t.shape[1]`` when None.

    Returns:
        The new hidden state, same shape as ``h_t``.
    """
    if n_gru_vox is None:
        n_gru_vox = int(h_t.shape[1])

    u_t = Activation('sigmoid')(
        fcconv3d_layer(h_t, feature_x, filters, n_gru_vox))
    r_t = Activation('sigmoid')(
        fcconv3d_layer(h_t, feature_x, filters, n_gru_vox))

    reset_h = Multiply()([r_t, h_t])
    candidate = Activation('tanh')(
        fcconv3d_layer(reset_h, feature_x, filters, n_gru_vox))

    keep_old = Multiply()([Lambda(lambda u: 1.0 - u)(u_t), h_t])
    take_new = Multiply()([u_t, candidate])
    return Add()([keep_old, take_new])

def build_3dgru(features, n_gru_vox=4, filters=None):
    """Unroll the 3D convolutional GRU over the step axis of ``features``.

    Args:
        features: Tensor of shape (batch, steps, feature_dim); the original
            code notes [bs, 4, 128].
        n_gru_vox: Side length of the hidden-state voxel grid. Placeholder
            default — TODO set to the grid size your decoder expects.
        filters: Hidden-state channels. Defaults to the module-level
            ``n_deconv_filters[0]`` used by the original code.

    Returns:
        Final hidden state, (batch, n_gru_vox, n_gru_vox, n_gru_vox, filters).
    """
    if filters is None:
        filters = n_deconv_filters[0]

    # Steps are indexed on axis 1 (features[:, i, ...]), so unroll over
    # shape[1], not shape[2] (which is the feature size).
    n_steps = int(features.shape[1])

    # Zero initial state whose batch size follows `features` at run time.
    h = Lambda(
        lambda x, n, c: tf.zeros([tf.shape(x)[0], n, n, n, c], dtype=x.dtype),
        output_shape=(n_gru_vox, n_gru_vox, n_gru_vox, filters),
        arguments={'n': n_gru_vox, 'c': filters})(features)

    for i in range(n_steps):
        # `arguments` binds the step index when the layer is created. This
        # avoids the late-binding closure a bare `lambda x: x[:, i]` creates.
        step_features = Lambda(lambda x, idx: x[:, idx, ...],
                               arguments={'idx': i})(features)
        # recurrence() builds real Keras layers, so call it directly. Layers
        # created inside a Lambda would not have their weights tracked.
        h = recurrence(h, step_features, filters, i)
    return h

把上面的代码翻译成 Keras 也可以作为答案。

解决方法

暂未找到可以解决该问题的有效方法,小编正在努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)