从没有重复的张量中选择顶部 K 值

问题描述

torch.Tensor.topk 提供了一种有效的方法来沿一个维度提取张量中的前 k 个值。是否可以将前 k 个值限制为不重复

例如

input = torch.tensor([0.2,0.2,0.1])
k = 2
dim = 0


output[0] = torch.tensor([0.2,0.1])
output[1] = torch.longtensor([0,2])

解决方法

您可以在输入张量上应用 torch.unique

>>> input.unique().topk(k=2).values
tensor([0.2000,0.1000])

请注意,此时您将丢失索引。


编辑:实际上 torch.unique 有一个对结果进行排序的选项(默认情况下该选项处于启用状态)。

>>> input
tensor([0.0000,0.3000,0.2000,0.1000])

>>> input.unique(return_inverse=True)[1].unique(sorted=False)
tensor([1,2,3,0])