在 Keras 中合并多个模型tensorflow

问题描述

在这里做了很多努力之后,我的问题是,

我有两个模型,两个模型都可以检测 2-2 个类。众所周知,我们可以使用 FunctionalAPI 合并两个模型。我试过了,但没有得到想要的结果。

我的目标:我想合并这些模型,更新后的模型应该有(1个输入,4个输出)。

This app can't run on your PC. 
To find a version for your PC,check with software publisher
inputs = tf.keras.Input(shape=(50,50,1))
y_1 = f1_Model(inputs)
y_2 = f2(inputs)
outputs = tf.concat([y_1,y_2],axis=0)
new_model = keras.Model(inputs,outputs)
new_model.summary()

当我在其中传递图像时,它给出了错误的结果。我不知道我哪里出错了。

Model: "functional_5"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_2 (InputLayer)            [(None,1)]  0                                            
__________________________________________________________________________________________________
sequential (Sequential)         (None,2)            203874      input_2[0][0]                    
__________________________________________________________________________________________________
sequential_1 (Sequential)       (None,2)            203874      input_2[0][0]                    
__________________________________________________________________________________________________
tf_op_layer_concat (TensorFlowO [(None,2)]          0           sequential[1][0]                 
                                                                 sequential_1[1][0]               
==================================================================================================
Total params: 407,748
Trainable params: 407,748
Non-trainable params: 0
__________________________________________________________________________________________________

解决方法

据我了解,您想对 4 个类别进行分类,为此,您有 2 个模型,每个模型对 2 个类别进行分类。
截至目前,您的 f1 和 f2 模型输出 softmax activation 的结果,因此首先,您必须将其删除并仅输出 logits 或仅输出 relu activation。之后,正如@dmg2 所提到的,您现在必须在 axis=1 中设置 tf.concat,最后您必须通过新的 softmax 激活传递输出。在那之后,我希望你能训练你的模型。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...