问题描述
我阅读了一些使用“自举交叉熵损失”训练其分割网络的论文。这个想法是只关注最困难的k%(例如15%)像素,以提高学习性能,尤其是在简单像素占主导的情况下。
当前,我正在使用标准交叉熵:
loss = F.binary_cross_entropy(mask,gt)
如何在PyTorch中将其有效地转换为自举版本?
解决方法
通常,我们还会为损失增加一个“热身”时间,以便网络可以先学会适应易发区域,然后过渡到较难的区域。
此实现从k=100
开始,持续进行20000次迭代,然后将其线性衰减到k=15
,再进行50000次迭代。
class BootstrappedCE(nn.Module):
def __init__(self,start_warm=20000,end_warm=70000,top_p=0.15):
super().__init__()
self.start_warm = start_warm
self.end_warm = end_warm
self.top_p = top_p
def forward(self,input,target,it):
if it < self.start_warm:
return F.cross_entropy(input,target),1.0
raw_loss = F.cross_entropy(input,reduction='none').view(-1)
num_pixels = raw_loss.numel()
if it > self.end_warm:
this_p = self.top_p
else:
this_p = self.top_p + (1-self.top_p)*((self.end_warm-it)/(self.end_warm-self.start_warm))
loss,_ = torch.topk(raw_loss,int(num_pixels * this_p),sorted=False)
return loss.mean(),this_p
,
添加@hkchengrex的自我回答(用于将来与PyTorch进行自我和API奇偶校验);
可以像这样首先实现functional
版本(在original torch.nn.functional.cross_entropy
中提供一些附加参数)(我也更喜欢reduction
为callable
而不是预定义的字符串):
import typing
import torch
def bootstrapped_cross_entropy(
inputs,targets,iteration,p: float,warmup: typing.Union[typing.Callable[[float,int],float],int] = -1,weight=None,ignore_index=-100,reduction: typing.Callable[[torch.Tensor],torch.Tensor] = torch.mean,):
if not 0 < p < 1:
raise ValueError("p should be in [0,1] range,got: {}".format(p))
if isinstance(warmup,int):
this_p = 1.0 if iteration < warmup else p
elif callable(warmup):
this_p = warmup(p,iteration)
else:
raise ValueError(
"warmup should be int or callable,got {}".format(type(warmup))
)
# Shortcut
if this_p == 1.0:
return torch.nn.functional.cross_entropy(
inputs,weight,ignore_index=ignore_index,reduction=reduction
)
raw_loss = torch.nn.functional.cross_entropy(
inputs,weight=weight,reduction="none"
).view(-1)
num_pixels = raw_loss.numel()
loss,sorted=False)
return reduction(loss)
也可以将warmup
指定为callable
(采用p
和当前的iteration
)或int
来进行灵活或轻松的调度。
并在每次调用期间使基于_WeightedLoss
和iteration
的类自动增加(因此只需传递inputs
和targets
):
class BoostrappedCrossEntropy(torch.nn.modules.loss._WeightedLoss):
def __init__(
self,):
self.p = p
self.warmup = warmup
self.ignore_index = ignore_index
self._current_iteration = -1
super().__init__(weight,size_average=None,reduce=None,reduction=reduction)
def forward(self,inputs,targets):
self._current_iteration += 1
return bootstrapped_cross_entropy(
inputs,self._current_iteration,self.p,self.warmup,self.weight,self.ignore_index,self.reduction,)