PyTorch:除顶部 k 之外的所有向量元素都归零?

问题描述

我正在尝试创建一个新的激活层,我们称之为 topk,其工作方式如下。它将把一个大小为 n 的向量 x 作为输入(将前一层输出乘以权重矩阵并加上偏置的结果)和一个正整数 k 并输出一个大小为 n 的向量 topk(x),其元素是:

              x_i (if x_i is one of the top k elements of x) 
topk(x)_i = 
              0 (otherwise)

在计算topk(x)的梯度时,x的前k个元素的梯度应该是1,其他的都是0。

我应该如何实现这一点?任何帮助将不胜感激。

解决方法

您可以使用 torch.topk

k = 2
output = torch.randn(5)
vals,idx = output.topk(k)

topk = torch.zeros_like(output)
topk[idx] = vals
>>> topk
tensor([1.0557,0.0000,1.4562,0.0000])

请注意,虽然 'values'topk() 是可微的,但 'indices' are not(类似于 argmax 不是可微函数的方式)。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...