Keras的“ plot_model”显示了错误的嵌套模型图自动编码器

问题描述

当我创建具有多个输入和输出自动编码器体系结构时,plot_model图不会按预期显示(问题以红色突出显示)。

我认为发生第一个问题是因为我对自动编码器使用了encoder.inputs。但是,为自动编码器创建新的输入层会导致我出错(图形断开)。

问题可能在我这儿,而不是Keras中的错误,所以希望有更多经验的人可以指导我正确的方向。

ps。我不能只使用一个自动编码器模型,这样可以避免此问题,因为在GAN设置中还使用了相同的编码器和解码器模型。 (该实现目前仅适用于单个输入和切片,但我真的想切换到多个输入,因为它感觉更干净)

下面的代码和体系结构图像:

enter image description here

from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from tensorflow.keras.utils import plot_model

# Config encoder
state_inputs = Input(shape=7,name="encoder_state_inputs")
action_inputs = Input(shape=4,name="encoder_action_inputs")
encoder_layer = Dense(11,activation=LeakyReLU(alpha=0.2),name=f"encoder_layer")
classification_layer = Dense(4,activation="softmax",name=f"classification")
latent_layer = Dense(5,name="latent_space")

# Build encoder
x = concatenate([state_inputs,action_inputs])
x = encoder_layer(x)
classifier = classification_layer(x)
latent = latent_layer(x)
encoder = Model([state_inputs,action_inputs],[classifier,latent],name="encoder")
plot_model(encoder,"encoder.png")


# Config decoder
classifier_inputs = Input(shape=4,name="classifier_inputs")
latent_inputs = Input(shape=5,name="latent_inputs")
decoder_layer = Dense(11,name=f"decoder_layer")
state_reconstruction_layer = Dense(7,name=f"state_reconstruction")
action_reconstruction_layer = Dense(4,name=f"action_reconstruction")

# Build decoder
x = concatenate([classifier_inputs,latent_inputs])
x = decoder_layer(x)
state_reconstruction = state_reconstruction_layer(x)
action_reconstruction = action_reconstruction_layer(x)
decoder = Model([classifier_inputs,latent_inputs],[state_reconstruction,action_reconstruction],name="decoder")
plot_model(decoder,"decoder.png")

# Build autoencoder
encoded = encoder(encoder.inputs)
decoded = decoder(encoded)
autoencoder = Model(encoder.inputs,outputs=[classifier,decoded],name="autoencoder")

plot_model(autoencoder,"autoencoder.png",show_shapes=True,expand_nested=True)

解决方法

我可以通过在自动编码器中添加两个输入层来防止出现第一个问题。

第二个问题(多个输出仅连接到一个输入)似乎是一个已知的错误:https://github.com/tensorflow/tensorflow/issues/42101