问题描述
我正在使用以下代码通过pytorch查找topk匹配项。
def find_top(self,x,y,n_neighbors,unit_vectors=False,cuda=False):
if not unit_vectors:
x = __to_unit_torch__(x,cuda=cuda)
y = __to_unit_torch__(y,cuda=cuda)
with torch.no_grad():
d = 1. - torch.matmul(x,y.transpose(0,1))
values,indices = torch.topk(d,dim=1,largest=False,sorted=True)
return indices.cpu().numpy()
不幸的是,它引发了以下错误
values,sorted=True)
RuntimeError: invalid argument 5: k not in range for dimension at /pytorch/aten/src/THC/generic/THCTensorTopK.cu:23
d的大小为(1793,1)
。我想念什么?
解决方法
This error 当您使用大于类总数的 torch.topk
调用 k
时发生。减少你的争论,它应该运行良好。