没有为任何变量提供梯度:...但我的损失似乎是可微的?

问题描述

我参加了一个具有奇怪损失指标的比赛,在我看来它可以在 tf 中实现,所以我这样做了:

def my_loss(y_true,y_pred):
        y_pred = tf.cast(tf.expand_dims(tf.argmax(y_pred,axis = 1),tf.float32)
        print(y_pred.shape,y_true.shape)
        dim = 1024 # NONE VALUES NOT SUPPORTED tf.cast(y_true.shape[0],tf.int32)
        delay_bins = y_true[:,0]
        delay_bins_all = tf.dtypes.cast(tf.tile(tf.expand_dims(tf.range(8),axis = 0),[dim,1]),tf.int64)
        delay_bins_squares = tf.dtypes.cast(tf.tile(tf.expand_dims(delay_bins,[1,8]),tf.int64) # TensorShape([10,10])
        to_mul =  tf.dtypes.cast(tf.equal(delay_bins_squares,delay_bins_all),tf.int32) # TensorShape([10,10])
        n_train = tf.expand_dims(tf.constant([480251,206240,70795,31476,17528,10965,7159,20782]),axis = 1)
       
        n_samples = tf.cast(tf.linalg.matmul(to_mul,n_train),tf.float32)[:,0])
    
        res = tf.reduce_sum((y_true - y_pred) ** 2 / n_samples) / 8.0
        return res

简而言之,这是多类预测问题,但不是使用分类交叉熵或 MSE 评估问题(类实际上具有序数意义),我坚持使用这个指标,它以某种方式计算由样本数量考虑的 MSE每个班级。

这个损失函数做了我想要的,但是当我将它与模型一起使用时,我得到这个错误

ValueError: 没有为任何变量提供梯度:['batch_normalization_54/gamma:0','batch_normalization_54/beta:0','dense_258/kernel:0','dense_258/bias:0','dense_259/kernel: 0'、'dense_259/bias:0'、'dense_260/kernel:0'、'dense_260/bias:0'、'dense_261/kernel:0'、'dense_261/bias:0'、'layer_normalization_52/gamma:0','layer_normalization_52/beta:0','dense_262/kernel:0','dense_262/bias:0']。

这是我使用的模型:

n_cols = x_train.shape[1]
inp_layer = tfl.Input((n_cols))
inp_layer_norm = tfl.Batchnormalization()(inp_layer)
dense = tfl.Dense(512,activation = 'relu') (inp_layer_norm)
dense = tfl.Dense(256,activation = 'relu') (dense)
dense = tfl.Dense(128,activation = 'relu') (dense)
dense = tfl.Dense(64,activation = 'relu') (dense)
dense = tfl.Layernormalization()(dense)
out = tfl.Dense(8,activation = 'sigmoid') (dense)
model = tf.keras.Model(inputs = inp_layer,outputs = out)
model.summary()
# 'kullback_leibler_divergence' works quite well.
model.compile(optimizer = 'adam',loss = my_loss)

你认为我做错了什么?在我看来,我没有使用任何非差分操作。

感谢您的帮助。

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...