How to use model subclassing in Keras? Summary and plot

Problem description

I have the following model written with the Sequential API:

import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Input, LSTM, BatchNormalization, Dropout, Dense
from tensorflow.keras.optimizers import Adam

config = {
    'learning_rate': 0.001,
    'lstm_neurons': 32,
    'lstm_activation': 'tanh',
    'dropout_rate': 0.08,
    'batch_size': 128,
    'dense_layers': [
        {'neurons': 32, 'activation': 'relu'},
        {'neurons': 32, 'activation': 'relu'},
    ],
}

def get_model(num_features, output_size):
    opt = Adam(learning_rate=config['learning_rate'])
    model = Sequential()
    # Variable-length sequences of num_features features, fed as a ragged tensor
    model.add(Input(shape=[None, num_features], dtype=tf.float32, ragged=True))
    model.add(LSTM(config['lstm_neurons'], activation=config['lstm_activation']))
    model.add(BatchNormalization())
    if 'dropout_rate' in config:
        model.add(Dropout(config['dropout_rate']))

    for layer in config['dense_layers']:
        model.add(Dense(layer['neurons'], activation=layer['activation']))
        model.add(BatchNormalization())
        # Dropout is only added if this layer's own dict has a 'dropout_rate'
        if 'dropout_rate' in layer:
            model.add(Dropout(layer['dropout_rate']))

    model.add(Dense(output_size, activation='sigmoid'))
    model.compile(loss='mse', optimizer=opt, metrics=['mse'])
    model.summary()  # summary() prints directly and returns None
    return model

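I need to convert this to the model-subclassing syntax to work with a distributed training framework. I've looked at the docs, but I can't figure out how to do it.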

Solution

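Here is a subclassed implementation of the same architecture, though I haven't tested it. One difference from get_model: this version applies dropout after every dense block whenever config contains 'dropout_rate', whereas the Sequential version only adds dropout after a dense layer whose own dict has a 'dropout_rate' key (none do in this config).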

import tensorflow as tf

# Your config (from the question)
config = {
    'learning_rate': 0.001,
    'lstm_neurons': 32,
    'lstm_activation': 'tanh',
    'dropout_rate': 0.08,
    'batch_size': 128,
    'dense_layers': [
        {'neurons': 32, 'activation': 'relu'},
        {'neurons': 32, 'activation': 'relu'},
    ],
}

# Subclassed API model
class MySubClassed(tf.keras.Model):
    def __init__(self, output_size):
        super().__init__()
        self.lstm = tf.keras.layers.LSTM(config['lstm_neurons'], activation=config['lstm_activation'])
        self.bn = tf.keras.layers.BatchNormalization()

        if 'dropout_rate' in config:
            self.dp1 = tf.keras.layers.Dropout(config['dropout_rate'])
            self.dp2 = tf.keras.layers.Dropout(config['dropout_rate'])
            self.dp3 = tf.keras.layers.Dropout(config['dropout_rate'])

        # NOTE: every iteration reassigns the same four attributes, so only the
        # layers created in the last iteration are kept (hence dense_2/dense_3 and
        # batch_normalization_3/_4 in the summary below). The model always has
        # exactly two dense blocks, whatever the length of config['dense_layers'].
        for layer in config['dense_layers']:
            self.dense1 = tf.keras.layers.Dense(layer['neurons'], activation=layer['activation'])
            self.bn1 = tf.keras.layers.BatchNormalization()
            self.dense2 = tf.keras.layers.Dense(layer['neurons'], activation=layer['activation'])
            self.bn2 = tf.keras.layers.BatchNormalization()

        self.out = tf.keras.layers.Dense(output_size, activation='sigmoid')

    def call(self, inputs, training=None, **kwargs):
        # Forward `training` so BatchNormalization and Dropout switch correctly
        # between training and inference behaviour (a default of True would keep
        # them in training mode on plain model(x) calls).
        x = self.lstm(inputs)
        x = self.bn(x, training=training)
        if 'dropout_rate' in config:
            x = self.dp1(x, training=training)

        x = self.dense1(x)
        x = self.bn1(x, training=training)
        if 'dropout_rate' in config:
            x = self.dp2(x, training=training)

        x = self.dense2(x)
        x = self.bn2(x, training=training)
        if 'dropout_rate' in config:
            x = self.dp3(x, training=training)

        return self.out(x)

    # A convenient way to get a model summary
    # and plot with the subclassed API
    def build_graph(self, raw_shape):
        x = tf.keras.layers.Input(shape=(None, raw_shape), ragged=True)
        return tf.keras.Model(inputs=[x], outputs=self.call(x))
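The loop in MySubClassed reassigns dense1/bn1/dense2/bn2 on every iteration, so it always builds exactly two dense blocks regardless of how many entries config['dense_layers'] has, and it adds dropout after every block. A minimal sketch of a variant that builds one block per config entry and uses the same dropout rule as get_model() could look like this. It assumes TensorFlow 2.x tf.keras; the class name ConfigDrivenModel and the attributes top_dropout and block_layers are hypothetical, and the second dense layer's 'relu' activation is an assumption carried over from the fixed config.

```python
import tensorflow as tf

# Hyperparameters from the question; the second dense layer's 'relu' is assumed.
config = {
    'learning_rate': 0.001, 'lstm_neurons': 32, 'lstm_activation': 'tanh',
    'dropout_rate': 0.08, 'batch_size': 128,
    'dense_layers': [
        {'neurons': 32, 'activation': 'relu'},
        {'neurons': 32, 'activation': 'relu'},
    ],
}


class ConfigDrivenModel(tf.keras.Model):
    """Hypothetical subclassed model mirroring get_model() layer for layer."""

    def __init__(self, config, output_size, **kwargs):
        super().__init__(**kwargs)
        self.lstm = tf.keras.layers.LSTM(
            config['lstm_neurons'], activation=config['lstm_activation'])
        self.bn = tf.keras.layers.BatchNormalization()
        # Dropout after the LSTM block only if the top-level config asks for it.
        self.top_dropout = (tf.keras.layers.Dropout(config['dropout_rate'])
                            if 'dropout_rate' in config else None)

        # One Dense -> BatchNormalization (-> Dropout) group per config entry.
        # Keras tracks layers stored in a list attribute, so their weights register.
        self.block_layers = []
        for spec in config['dense_layers']:
            self.block_layers.append(
                tf.keras.layers.Dense(spec['neurons'], activation=spec['activation']))
            self.block_layers.append(tf.keras.layers.BatchNormalization())
            # Same rule as get_model(): dropout only if this entry has 'dropout_rate'.
            if 'dropout_rate' in spec:
                self.block_layers.append(tf.keras.layers.Dropout(spec['dropout_rate']))

        self.out = tf.keras.layers.Dense(output_size, activation='sigmoid')

    def call(self, inputs, training=None):
        x = self.lstm(inputs)
        x = self.bn(x, training=training)
        if self.top_dropout is not None:
            x = self.top_dropout(x, training=training)
        for layer in self.block_layers:
            # Only BatchNormalization and Dropout behave differently in training.
            if isinstance(layer, (tf.keras.layers.BatchNormalization,
                                  tf.keras.layers.Dropout)):
                x = layer(x, training=training)
            else:
                x = layer(x)
        return self.out(x)


model = ConfigDrivenModel(config, output_size=1)
model.compile(loss='mse', metrics=['mse'],
              optimizer=tf.keras.optimizers.Adam(learning_rate=config['learning_rate']))
y = model(tf.ones((4, 10, 16)))  # (batch, timesteps, features)
print(y.shape)
print(len(model.trainable_weights))
```

Keeping the block layers in a list instead of numbered attributes lets the depth follow the config, and because no entry in this config has its own 'dropout_rate', the blocks contain no Dropout layers, exactly like get_model().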

Build and compile the model

s = MySubClassed(output_size=1)
s.compile(
    loss='mse', metrics=['mse'], optimizer=tf.keras.optimizers.Adam(learning_rate=0.001))

Pass a tensor through the model to create its weights (as a check).

raw_input = (16, 16, 16)  # (batch, timesteps, features)
y = s(tf.ones(shape=raw_input))

print("weights:",len(s.weights))
print("trainable weights:",len(s.trainable_weights))

weights: 21
trainable weights: 15
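Subclassed models create their weights lazily: each layer builds its variables the first time it is called, because that is when the input shape becomes known, which is why a tensor is passed through s before counting. The counts above break down as 3 weights for the LSTM (kernel, recurrent kernel, bias), 4 for each of the three BatchNormalization layers (gamma and beta, which are trainable, plus moving mean and moving variance, which are not) and 2 for each of the three Dense layers: 3 + 12 + 6 = 21 in total, of which 21 - 6 = 15 are trainable. A minimal sketch of the lazy build, using a hypothetical one-layer model named Tiny and assuming TensorFlow 2.x:

```python
import tensorflow as tf


class Tiny(tf.keras.Model):
    """Hypothetical one-layer subclassed model."""

    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)

    def call(self, inputs):
        return self.dense(inputs)


m = Tiny()
weights_before = len(m.weights)   # Dense has not been built yet
m(tf.ones((2, 3)))                # first call builds Dense for 3 input features
weights_after = len(m.weights)    # kernel and bias now exist
print(weights_before, weights_after, m.dense.kernel.shape)
```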

Summary and plot

Summarize and visualize the model graph.

s.build_graph(16).summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, None, 16)]        0
_________________________________________________________________
lstm (LSTM)                  (None, 32)                6272
_________________________________________________________________
batch_normalization (BatchNo (None, 32)                128
_________________________________________________________________
dropout (Dropout)            (None, 32)                0
_________________________________________________________________
dense_2 (Dense)              (None, 32)                1056
_________________________________________________________________
batch_normalization_3 (Batch (None, 32)                128
_________________________________________________________________
dropout_1 (Dropout)          (None, 32)                0
_________________________________________________________________
dense_3 (Dense)              (None, 32)                1056
_________________________________________________________________
batch_normalization_4 (Batch (None, 32)                128
_________________________________________________________________
dropout_2 (Dropout)          (None, 32)                0
_________________________________________________________________
dense_4 (Dense)              (None, 1)                 33
=================================================================
Total params: 8,801
Trainable params: 8,609
Non-trainable params: 192
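The Param # column follows from the standard formulas: an LSTM has four gates, each with an input kernel, a recurrent kernel and a bias; a Dense layer has a kernel plus a bias; BatchNormalization keeps four vectors per feature, two of which (moving mean and variance) are non-trainable. A small plain-Python check of the totals, needing no TensorFlow (the helper function names are hypothetical):

```python
num_features = 16
lstm_units = 32
dense_units = 32
output_size = 1


def lstm_params(units, input_dim):
    # 4 gates, each with an input kernel, a recurrent kernel and a bias
    return 4 * (units * input_dim + units * units + units)


def dense_params(units, input_dim):
    # kernel plus bias
    return units * input_dim + units


def batchnorm_params(units):
    # gamma, beta (trainable) + moving mean, moving variance (non-trainable)
    return 4 * units


lstm = lstm_params(lstm_units, num_features)       # LSTM on 16 input features
bn = batchnorm_params(lstm_units)                  # each of the 3 BN layers
dense1 = dense_params(dense_units, lstm_units)     # first hidden Dense
dense2 = dense_params(dense_units, dense_units)    # second hidden Dense
out = dense_params(output_size, dense_units)       # sigmoid output layer

total = lstm + 3 * bn + dense1 + dense2 + out
non_trainable = 3 * 2 * lstm_units                 # moving mean + variance per BN layer
trainable = total - non_trainable
print(lstm, bn, dense1, out)
print(total, trainable, non_trainable)
```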

Plot the graph (plot_model needs pydot and Graphviz installed):

tf.keras.utils.plot_model(
    s.build_graph(16), show_shapes=True, show_dtype=True,
    show_layer_names=True, rankdir="TB")

