如何在keras模型中为权重矩阵初始化可变张量?

问题描述

我正在尝试使用张量变量在keras层中用作权重。

我知道我可以改用numpy数组,但是我想输入张量的原因是我希望体重矩阵的类型为SparseTensor。

这是我到目前为止编写的一个小示例:

def model_keras(seed,new_hidden_size_list=None):

    number_of_layers = 1
    hidden_size = 512
    hidden_size_list = [hidden_size] * number_of_layers
    input_size = 784
    output_size = 10

    if new_hidden_size_list is not None:
        hidden_size_list = new_hidden_size_list

    weight_input = tf.Variable(tf.random.normal([784,512],mean=0.0,stddev=1.0))
    bias_input = tf.Variable(tf.random.normal([512],stddev=1.0))
    weight_output = tf.Variable(tf.random.normal([512,10],stddev=1.0))

    # This gives me an error when trying to use in kernel_initializer and       bias_initializer in the keras model
    weight_initializer_input = tf.initializers.variables([weight_input])
    bias_initializer_input = tf.initializers.variables([bias_input])
    weight_initializer_output = tf.initializers.variables([weight_output])

    # This works fine
    #weight_initializer_input = tf.initializers.lecun_uniform(seed=None)
    #bias_initializer_input = tf.initializers.lecun_uniform(seed=None)
    #weight_initializer_output = tf.initializers.lecun_uniform(seed=None)

    print(weight_initializer_input,bias_initializer_input,weight_initializer_output)

    model = keras.models.Sequential()
    for index in range(number_of_layers):
        if index == 0:
            # input layer
            model.add(keras.layers.Dense(hidden_size_list[index],activation=nn.selu,use_bias=True,kernel_initializer=weight_initializer_input,bias_initializer=bias_initializer_input,input_shape=(input_size,)))
        else:
             model.add(keras.layers.Dense(hidden_size_list[index],kernel_initializer=weight_initializer_hidden,bias_initializer=bias_initializer_hidden))

# output layer
    model.add(keras.layers.Dense(output_size,use_bias=False,kernel_initializer=weight_initializer_output))
    model.add(keras.layers.Activation(nn.softmax))

return model

我正在使用tensorflow 1.15。

任何想法都可以如何使用自定义用户定义的)张量变量作为初始值设定项,而不是预设方案(例如Glorot,Truncated normal等)。我可以采用的另一种方法是显式定义计算,而不是使用keras.Layer。

非常感谢

解决方法

启用急切执行后,您的代码即可工作。

print('Hello world')

将此添加到文件顶部。

有关有效代码,请参见this