问题描述
我有一些矩阵,其中行属于某个标签,无序。我想对每个标签的所有行求和。
以下是通过循环完成的方法:
labels = torch.tensor([0,1,0])
x = torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
torch.stack([torch.sum(x[labels == i],dim=0) for i in torch.unique(labels)])
所需的输出:
tensor([[ 8,10,12],[ 4,6]])
编辑:为了清楚起见,我有标签张量,我知道哪些标签重复,我有兴趣在不使用循环的情况下计算最后一行。我以为scatter_add_
或gather
可能会有所帮助。
解决方法
- 1:我试图找到重复的标签
def get_repeated_labels(label_list):
"""
Args:
label_list (ndarray): target list
Return:
(list): Repeated labels
"""
records_array = label_list
values,inverse,count = np.unique(records_array,return_inverse=True,return_counts=True)
repeated = np.where(count > 1)[0]
repeated = values[repeated]
rows,cols = np.where(inverse == repeated[:,np.newaxis])
_,inverse_rows = np.unique(rows,return_index=True)
res = np.split(cols,inverse_rows[1:])
return res
if __name__ == '__main__':
labels = torch.tensor([0,1,0])
r = get_repeated_labels(labels.numpy())
输出:
[array([0,2])]
这意味着第0个索引和第2个索引正在重复。我们需要将第0和第2个索引数组求和。
torch.sum(x[r[i]],dim=0)
但是len(r[i])
是1维的,我们有两个标签。因此,我使用了if-else条件。
最终:
print(torch.stack([torch.sum(x[r[i]],dim=0) if len(r) >= i + 1 else x[i] for i,_ in enumerate(torch.unique(labels))]))
输出:
tensor([[ 8,10,12],[ 4,5,6]])