如何在不破坏结构的情况下重命名Keras模型的各层?

问题描述

对于some library functionality,我正在尝试重命名给定模型的图层(包括输入图层)。

以下最小示例显示了我使用当前方法(使用TensorFlow 2.3)时遇到的错误

from tensorflow.keras.models import load_model

model = load_model("model.h5")
for layer in model.layers:
    layer._name = layer.name + "_renamed"

model.to_json()
ValueError: The target structure is of type `<class 'tensorflow.python.framework.ops.Tensor'>`
  Tensor("input_1:0",shape=(None,4),dtype=float32)
However the input structure is a sequence (<class 'list'>) of length 0.

model.h5文件可能是这样创建的,例如:

from tensorflow.keras.layers import Input,Dense
from tensorflow.keras.models import Model

inputs = Input(shape=(4,))
x = Dense(5,activation='relu',name='a')(inputs)
x = Dense(3,activation='softmax',name='b')(x)
model = Model(inputs=inputs,outputs=x)
model.compile(loss='categorical_crossentropy',optimizer='nadam')
model.save("model.h5")

关于如何解决此问题的任何想法?

解决方法

问题:Keras通过遍历layer._inbound_nodescomparing against model._network_nodes来序列化网络;设置layer._name时,后者会保留原始名称。


解决方案:相应地重命名_network_nodes。工作功能位于底部,示例如下:

from tensorflow.keras.models import load_model
from tensorflow.keras.layers import Input,Dense
from tensorflow.keras.models import Model

ipt = Input((16,))
out = Dense(16)(ipt)
model = Model(ipt,out)
model.compile('sgd','mse')

rename(model,model.layers[1],'new_name')
model.save('model.h5')
loaded = load_model('model.h5')

注意layer.name是没有.setter的{​​{3}},这意味着(显然)它不是要设置的。此外,@property被覆盖,除了设置属性外还执行其他步骤-可能是必要的,但不能确切确定它可能还具有什么其他效果。我提供了一个绕过这些的替代方法。最好将其视为临时解决方案;我建议在Github上打开一个Issue,因为API方面的更改已到。


功能

并非万无一失-_get_node_suffix的命名逻辑需要工作(例如dense_1可能与dense_11混淆)。

def rename(model,layer,new_name):
    def _get_node_suffix(name):
        for old_name in old_nodes:
            if old_name.startswith(name):
                return old_name[len(name):]

    old_name = layer.name
    old_nodes = list(model._network_nodes)
    new_nodes = []

    for l in model.layers:
        if l.name == old_name:
            l._name = new_name
            # vars(l).__setitem__('_name',new)  # bypasses .__setattr__
            new_nodes.append(new_name + _get_node_suffix(old_name))
        else:
            new_nodes.append(l.name + _get_node_suffix(l.name))
    model._network_nodes = set(new_nodes)