焦点损失实现

问题描述

在介绍焦点损失的 paper 中,他们声明损失函数的公式如下:

enter image description here

哪里

enter image description here

我在另一个作者的 Github 页面上找到了它的实现,他在他们的 paper 中使用了它。我在我拥有的分割问题数据集上尝试了该功能,它似乎工作得很好。

以下是实现:

def binary_focal_loss(pred,truth,gamma=2.,alpha=.25):
    eps = 1e-8
    pred = nn.softmax(1)(pred)
    truth = F.one_hot(truth,num_classes = pred.shape[1]).permute(0,3,1,2).contiguous()

    pt_1 = torch.where(truth == 1,pred,torch.ones_like(pred))
    pt_0 = torch.where(truth == 0,torch.zeros_like(pred))

    pt_1 = torch.clamp(pt_1,eps,1. - eps)
    pt_0 = torch.clamp(pt_0,1. - eps)

    out1 = -torch.mean(alpha * torch.pow(1. - pt_1,gamma) * torch.log(pt_1)) 
    out0 = -torch.mean((1 - alpha) * torch.pow(pt_0,gamma) * torch.log(1. - pt_0))

    return out1 + out0

我不明白的部分是pt_0和pt_1的计算。我为自己创建了一个小例子来尝试弄清楚,但它仍然让我感到困惑。

# one hot encoded prediction tensor
pred = torch.tensor([
                     [
                      [.2,.7,.8],# probability
                      [.3,.5,.7],# of
                      [.2,.6,.5]  # background class
                     ],[
                      [.8,.3,.2],# probability
                      [.7,.3],# of
                      [.8,.4,.5]  # class 1
                     ]
                    ])

# one-hot encoded ground truth labels
truth = torch.tensor([
                      [1,0],[1,0]
                     ])
truth = F.one_hot(truth,num_classes = 2).permute(2,1).contiguous()

print(truth)
# gives me:
# tensor([
#         [
#          [0,1],#          [0,1]
#         ],#         [
#          [1,#          [1,0]
#         ]
#       ])

pt_0 = torch.where(truth == 0,torch.zeros_like(pred))
pt_1 = torch.where(truth == 1,torch.ones_like(pred))

print(pt_0)
# gives me:
# tensor([[
#         [0.2000,0.0000,0.0000],#         [0.3000,0.5000,#         [0.2000,0.0000]
#         ],#        [
#         [0.0000,0.3000,0.2000],#         [0.0000,0.3000],0.4000,0.5000]
#        ]
#      ])

print(pt_1)
# gives me:
# tensor([[
#          [1.0000,0.7000,0.8000],#          [1.0000,1.0000,0.7000],0.6000,0.5000]
#         ],#         [
#          [0.8000,1.0000],#          [0.7000,#          [0.8000,1.0000]
#         ]
#       ])

我不明白为什么在 pt_0 中我们在 torch.where 语句为假的地方放置零,而在 pt_1 中放置零。从我对论文的理解来看,我会认为不是放置零或一,而是放置 1-p。

谁能帮我解释一下?

解决方法

因此,您尝试理解的部分是人们通常在想要将不需要的额外计算归零时执行的程序。

再看看pt的公式:

enter image description here

下面的代码正是通过分离这两个条件来做到这一点的:

# if y=1
pt_1 = torch.where(truth == 1,pred,torch.ones_like(pred))
# otherwise
pt_0 = torch.where(truth == 0,torch.zeros_like(pred)) 

pt_0 中设置为零,在 pt_1 中设置为 1 将导致输出为零,因此对贡献损失值没有影响,即:

# Because pow(0.,gamma) == 0. and log(1.) == 0.
# out1 == 0. if pt_1 == 1.
out1 = -torch.mean(alpha * torch.pow(1. - pt_1,gamma) * torch.log(pt_1))
# out0 == 0. if pt_0 == 0.
out0 = -torch.mean((1 - alpha) * torch.pow(pt_0,gamma) * torch.log(1. - pt_0))

pt_0 使用 p 而不是 1-p 的值的原因与您上一个问题的原因相同,即:

1 - (1 - p) == 1 - 1 + p == p

所以它以后可以通过以下方式计算FL(pt)

# -a * pow(1 - (1 - p),gamma )* log(1 - p) == -a * pow(p,gamma )* log(1 - p)
out0 = -torch.mean((1 - alpha) * torch.pow(pt_0,gamma) * torch.log(1. - pt_0))