使用Tensorflow异物检测API V2的多通道输入

问题描述

我想使用5通道图像在Tensorflow V2对象检测API中构建网络。但是,我仍然坚持如何使用Tensorflow 2.2框架修改第一卷积层的权重。

我已经从V2模型动物园下载了经过预训练的RetinaNet。然后,我尝试了以下操作来修改检查点第一层中的权重并将其保存回去:

tf_path = tf.train.latest_checkpoint('./RetinaNet/checkpoint/')
init_vars = tf.train.list_variables(tf_path)
tf_vars = {}
for name,shape in init_vars:

    array = tf.train.load_variable(tf_path,name)
    try:
        if shape[2]==3:#look for a layer who's 3rd input dimension is 3 i.e. the 1st convolutional layer
            array=np.concatenate((array,array[:,:,:2,:]),axis=2)
            array=array.astype('float32')
            tf_vars[name]=tf.Variable(array)
            
        else:
            tf_vars[name]=tf.Variable(array)
            
    except:
        tf_vars[name]=tf.Variable(array)
        
        
saver = tf.compat.v1.train.Saver(var_list=tf_vars)
sess = tf.compat.v1.Session()
saver.save(sess,'./RetinaNet/checkpoint/ckpt-0')

我重新加载了模型,以确保第一卷积层已更改-一切正常。

但是当我训练模型时,出现以下错误: 使用输入Tensor(“ input_1:0”,shape =(None,None,None,3),dtype = float32)的形状(None,None,None,3)构造模型,但是在不兼容的输入上调用了该模型形状(64、128、128、5)

这使我相信我修改权重的方法毕竟不太“ OK”。有人会教您如何正确修改这些权重的提示吗?

谢谢

解决方法

现在可以使用,但是解决方案非常笨拙……这还意味着不使用模型动物园中的预训练权重进行训练-因此您需要注释所有与配置文件中的fine_tune_checkpoint有关的内容。 然后,转到。\ Lib \ site-packages \ official \ vision \ image_classification \ efficiencynet并更改efficiencynet_model.py和efficiencynet_config.py中的输入通道数和类数。