如何从经过培训的VGG16网络中获取编码器

问题描述

我正在使用Python 3.7.7。和Tensorflow 2.1.0。

我有一个经过预先训练的VGG16网络,我想获得第一层,即从conv1层到conv5层。

在下图中:

enter image description here

您可以看到卷积编码器-解码器体系结构。我要获取编码器部分,即出现在图像左侧的层:

enter image description here

这只是一个例子,但是如果我从此函数获取VGG16:

def vgg16_encoder_decoder(input_size = (200,200,1)):
    #################################
    # Encoder
    #################################
    inputs = Input(input_size,name = 'input')

    conv1 = Conv2D(64,(3,3),activation = 'relu',padding = 'same',name ='conv1_1')(inputs)
    conv1 = Conv2D(64,name ='conv1_2')(conv1)
    pool1 = MaxPooling2D(pool_size = (2,2),strides = (2,name = 'pool_1')(conv1)

    conv2 = Conv2D(128,name ='conv2_1')(pool1)
    conv2 = Conv2D(128,name ='conv2_2')(conv2)
    pool2 = MaxPooling2D(pool_size = (2,name = 'pool_2')(conv2)
    
    conv3 = Conv2D(256,name ='conv3_1')(pool2)
    conv3 = Conv2D(256,name ='conv3_2')(conv3)
    conv3 = Conv2D(256,name ='conv3_3')(conv3)
    pool3 = MaxPooling2D(pool_size = (2,name = 'pool_3')(conv3)
    
    conv4 = Conv2D(512,name ='conv4_1')(pool3)
    conv4 = Conv2D(512,name ='conv4_2')(conv4)
    conv4 = Conv2D(512,name ='conv4_3')(conv4)
    pool4 = MaxPooling2D(pool_size = (2,name = 'pool_4')(conv4)

    conv5 = Conv2D(512,name ='conv5_1')(pool4)
    conv5 = Conv2D(512,name ='conv5_2')(conv5)
    conv5 = Conv2D(512,name ='conv5_3')(conv5)
    pool5 = MaxPooling2D(pool_size = (2,name = 'pool_5')(conv5)

    #################################
    # Decoder
    #################################
    #conv1 = Conv2DTranspose(512,(2,strides = 2,name = 'conv1')(pool5)

    upsp1 = UpSampling2D(size = (2,name = 'upsp1')(pool5)
    conv6 = Conv2D(512,3,name = 'conv6_1')(upsp1)
    conv6 = Conv2D(512,name = 'conv6_2')(conv6)
    conv6 = Conv2D(512,name = 'conv6_3')(conv6)

    upsp2 = UpSampling2D(size = (2,name = 'upsp2')(conv6)
    conv7 = Conv2D(512,name = 'conv7_1')(upsp2)
    conv7 = Conv2D(512,name = 'conv7_2')(conv7)
    conv7 = Conv2D(512,name = 'conv7_3')(conv7)
    zero1 = ZeroPadding2D(padding =  ((1,0),(1,0)),data_format = 'channels_last',name='zero1')(conv7)

    upsp3 = UpSampling2D(size = (2,name = 'upsp3')(zero1)
    conv8 = Conv2D(256,name = 'conv8_1')(upsp3)
    conv8 = Conv2D(256,name = 'conv8_2')(conv8)
    conv8 = Conv2D(256,name = 'conv8_3')(conv8)

    upsp4 = UpSampling2D(size = (2,name = 'upsp4')(conv8)
    conv9 = Conv2D(128,name = 'conv9_1')(upsp4)
    conv9 = Conv2D(128,name = 'conv9_2')(conv9)

    upsp5 = UpSampling2D(size = (2,name = 'upsp5')(conv9)
    conv10 = Conv2D(64,name = 'conv10_1')(upsp5)
    conv10 = Conv2D(64,name = 'conv10_2')(conv10)

    conv11 = Conv2D(1,name = 'conv11')(conv10)

    model = Model(inputs = inputs,outputs = conv11,name = 'vgg-16_encoder_decoder')

    return model

我训练网络,然后训练它。如何获得编码器零件?换句话说,获得一个模型,该模型仅包含从conv1pool5的原始图层。

我认为可能是这样的:

model_new = Model(input=model_old.layers[0].input,output=model_old.layers[12].output)

解决方法

from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Input,Flatten
from tensorflow.keras import Model

input_shape = (W,H,C)
def encoder(input_shape):
     model = VGG16(include_top=False,input_shape=input_shape)
     F1 = Flatten()(model.get_layer(index=1).output)
     F2 = Flatten()(model.get_layer(index=2).output)
     F3 = Flatten()(model.get_layer(index=3).output)
     F4 = Flatten()(model.get_layer(index=4).output)
     F5 = Flatten()(model.get_layer(index=5).output)   
     M = Model(model.inputs,[F1,F2,F3,F4,F5])  
     return M 

其中 W,H 图像大小和 C 通道数应等于3。

,

要从我的预训练网络中获取伴奏者,我创建了以下功能:

def get_encoder(old_model: Model) -> Model:
  # Get encoder
  encoder_input: Model = Model(inputs=old_model.layers[0].input,outputs=old_model.layers[14].output)

  # Create Global Average Pooling.
  encoder_output = GlobalAveragePooling2D()(encoder_input.layers[-1].output)

  # Create the encoder adding the GAP layer as output.
  encoder: Model = Model(encoder_input.input,encoder_output,name='encoder')

  return encoder

重要的是数字14。这是enconder在原始网络中结束的层。顺便说一句,我终于用U-Net代替了VGG-16,所以这个数字仅适用于U-NET

,

通过省略代码的最后20层,我建议以下代码。

model_new = Model(model_old.input,model_old.layers[-20].output) model_new.summary()

如果我错过了解码器最后20层的计数,您可能需要将其稍微调整为-19或-21才能找到最后一个池5。