问题描述
我正在尝试创建一个新的激活层,我们称之为 topk,其工作方式如下。它将把一个大小为 n 的向量 x 作为输入(将前一层输出乘以权重矩阵并加上偏置的结果)和一个正整数 k 并输出一个大小为 n 的向量 topk(x),其元素是:
x_i (if x_i is one of the top k elements of x)
topk(x)_i =
0 (otherwise)
在计算topk(x)的梯度时,x的前k个元素的梯度应该是1,其他的都是0。
我应该如何实现这一点?任何帮助将不胜感激。
解决方法
您可以使用 torch.topk
:
k = 2
output = torch.randn(5)
vals,idx = output.topk(k)
topk = torch.zeros_like(output)
topk[idx] = vals
>>> topk
tensor([1.0557,0.0000,1.4562,0.0000])
请注意,虽然 'values'
的 topk()
是可微的,但 'indices'
are not(类似于 argmax 不是可微函数的方式)。