如何在不干扰 TF>2.0 模型图的情况下使用张量评估实现自定义 Keras 序数损失函数?

问题描述

我正在尝试使用 Keras 后端在 Tensorflow 2.4 中实现自定义损失函数

损失函数为ranking loss;我发现以下论文有一些对数似然损失:Chen et al. Single-Image Depth Perception in the Wild.

同样,我想从图像中采样一些(在本例中为 50 个)点,以使用 NYU-Depth 数据集比较地面实况和预测深度图之间的相对顺序。作为 Numpy 的粉丝,我开始使用它,但遇到了以下异常:

ValueError: No gradients provided for any variable: [...]

我了解到这是由于调用损失函数时没有填充参数引起的,而是编译了一个 C 函数,然后稍后使用。因此,虽然我知道张量的维度(4、480、640、1),但我无法根据需要处理数据,并且必须在顶部使用 keras.backend 函数,以便最终(如果我理解正确的话),来自 TF 图的输入张量和输出张量之间应该有一条路径,该路径必须提供梯度。

所以我现在的问题是:这是 keras 中可行的损失函数吗? 我已经用我的原始代码的不同变体尝试了一些想法和不同的方法,例如:

def ranking_loss_function(y_true,y_pred):
    # Chen et al. loss

    y_true_np = K.eval(y_true)
    y_pred_np = K.eval(y_pred)
    
    if y_true_np.shape[0] != None:
        num_sample_points = 50
        total_samples = num_sample_points ** 2
        
        err_list = [0 for x in range(y_true_np.shape[0])]

        for i in range(y_true_np.shape[0]):
            sample_points = create_random_samples(y_true,y_pred,num_sample_points)
            for x1,y1 in sample_points:
                for x2,y2 in sample_points:
                    if y_true[i][x1][y1] > y_true[i][x2][y2]:
                        #image_relation_true = 1
                        err_list[i] += np.log(1 + np.exp(-1 * y_pred[i][x1][y1] + y_pred[i][x2][y2]))
                    elif y_true[i][x1][y1] < y_true[i][x2][y2]:
                        #image_relation_true = -1
                        err_list[i] += np.log(1 + np.exp(y_pred[i][x1][y1] - y_pred[i][x2][y2]))
                    else:
                        #image_relation_true = 0
                        err_list[i] += np.square(y_pred[i][x1][y1] - y_pred[i][x2][y2])
        
        err_list = np.divide(err_list,total_samples)
        
        return K.constant(err_list)

如您所知,主要思想是首先创建样本点,然后根据 y_true/y_pred 中它们之间的现有关系继续引用论文中的相应计算。

谁能帮助我并提供一些有关如何使用 keras.backend 函数正确实现此损失的有用信息或提示?与标准回归损失相比,试图包含序数关系信息确实让我感到困惑。

编辑:以防万一这会引起混淆:create_random_samples() 仅基于 shape[1] 和 {{1} 创建 50 个随机样本点 (x,y) 坐标对} of shape[2]图片宽高)

EDIT(2):在 GitHub 上找到 this 变体后,我尝试了一个仅使用 TF 函数从张量中检索数据并计算输出的变体。调整后的可能更正确的版本仍然抛出相同的异常:

y_true

EDIT(3): 这是 create_random_samples() 的代码: (仅供参考:因为在这种情况下从 y_true 获取形状很奇怪,所以我首先在这里对其进行硬编码,因为我知道它用于我目前使用的数据集。)

def ranking_loss_function(y_true,y_pred):
#In the Wild ranking loss
y_true_np = K.eval(y_true)
y_pred_np = K.eval(y_pred)

if y_true_np.shape[0] != None:
    num_sample_points = 50
    total_samples = num_sample_points ** 2
    
    bs = y_true_np.shape[0]
    w = y_true_np.shape[1]
    h = y_true_np.shape[2]
    
    total_samples = total_samples * bs
    num_pairs = tf.constant([total_samples],dtype=tf.float32)
    
    output = tf.Variable(0.0)

    for i in range(bs):
        sample_points = create_random_samples(y_true,num_sample_points)
        for x1,y1 in sample_points:
            for x2,y2 in sample_points:
            
                y_true_sq = tf.squeeze(y_true)
                y_pred_sq = tf.squeeze(y_pred)
            
                d1_t = tf.slice(y_true_sq,[i,x1,y1],[1,1,1])
                d2_t = tf.slice(y_true_sq,x2,y2],1])
                d1_p = tf.slice(y_pred_sq,1])
                d2_p = tf.slice(y_pred_sq,1])

                d1_t_sq = tf.squeeze(d1_t)
                d2_t_sq = tf.squeeze(d2_t)
                d1_p_sq = tf.squeeze(d1_p)
                d2_p_sq = tf.squeeze(d2_p)
                
                if d1_t_sq > d2_t_sq:
                    # --> Image relation = 1
                    output.assign_add(tf.math.log(1 + tf.math.exp(-1 * d1_p_sq + d2_p_sq)))
                elif d1_t_sq < d2_t_sq:
                    # --> Image relation = -1
                    output.assign_add(tf.math.log(1 + tf.math.exp(d1_p_sq - d2_p_sq)))
                else:
                    output.assign_add(tf.math.square(d1_p_sq - d2_p_sq))
    
    return output/num_pairs

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...