是否可以使用评估指标如 NDCG作为损失函数?

问题描述

我正在研究一个名为 DPR 的信息检索模型,它基本上是一个神经网络(2 个 BERT),可以根据查询对文档进行排名。目前,该模型以二进制方式(文档是否相关)进行训练,并使用负对数似然(NLL)损失。我想改变这种二元行为并创建一个可以处理分级相关性的模型(例如 3 个等级:相关、以某种方式相关、不相关)。我必须更改损失函数,因为目前,我只能为每个查询分配 1 个正目标(DPR 使用 pytorch NLLLoss),这不是我需要的。

我想知道是否可以使用 NDCG(标准化折扣累积增益)等评估指标来计算损失。我的意思是,损失函数的全部意义在于说明我们的预测有多差,而 NDCG 也在做同样的事情。

那么,我可以使用这些指标来代替损失函数并进行一些修改吗?在 NDCG 的情况下,我认为从 1 (1 - NDCG_score) 中减去结果可能是一个很好的损失函数。是真的吗?

最好的问候,阿里。

解决方法

是的,这是可能的。您可能希望应用 listwise 学习排序方法,而不是更标准的成对损失函数。

在成对损失中,网络提供了示例对(rel、非 rel),并且真实标签是二元的(如果对中的第一个相关,则为 1,否则为 0)。

>

然而,在列表学习方法中,在训练期间,您将提供一个列表而不是一对,并且真实值(仍然是二进制)将指示此排列是否确实是最佳排列,例如最大化 nDCG 的那个。在列表方法中,排名目标因此被转换为排列的分类

有关详细信息,请参阅此 paper

显然,网络可能会采用 BERT 查询向量和列表中的文档,而不是将特征作为输入,类似于 ColBERT。不像 ColBERT,你从 2 个文档中输入向量(成对训练),对于列表训练,你需要从 5 个文档中输入向量。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...