如何在我的 Pytorch 中添加我自己的自定义激活函数

问题描述

如何创建如下所示的自定义激活函数

    def forward(self,x):
        x = self.l1(x)
        x = activationFunction(x)
        x = self.l2(x)
        x = activationFunction(x)
        return x

def activationFunction(x):
    if x <= 0:
        return 1.359140915 * math.exp(x - 1)
    elif x > 15:
        return 1 - 1/(109.0858178 * x - 1403.359435)
    else:
        return 0.03 * math.log(1000000 * x + 1) + 0.5

这就是我现在得到的:

RuntimeError: 具有多个值的 Tensor 的布尔值不明确

编辑:

def forward(self,x):
    x = self.l1(x)
    x = torch.where(x <= 0,1.359140915 * (x-1).exp(),torch.where(x > 15,1 - 1/(109.0858178 * x - 1403.359435),0.03 * (1000000 * x + 1).log() + 0.5))
    x = self.l2(x)
    x = torch.where(x <= 0,0.03 * (1000000 * x + 1).log() + 0.5))
    return x

这相对来说效果很好。

解决方法

由于 tensor.apply_ 不合适,你必须用 x0 = 1.359140915 * (x - 1).exp() 之类的张量表达式重写函数,条件变成 torch.where( x>0,x0,x1 ) 之类的东西。