Pytorch 成对连接张量

问题描述

我想以批处理方式计算特定维度上的成对串联。

例如

x = torch.tensor([[[0],[1],[2]],[[3],[4],[5]]])
x.shape = torch.Size([2,3,1])

我想得到 y 使得 y一个维度上所有向量对的串联,即:

y = torch.tensor([[[[0,0],[0,1],2]],[[1,[1,[[2,[2,2]]],[[[3,3],[3,4],5]],[[4,[4,[[5,[5,5]]]])

y.shape = torch.Size([2,2])

因此,本质上,对于每个 x[i,:],您生成所有向量对,并将它们连接到最后一个维度。 有没有直接的方法来做到这一点?

解决方法

一种可能的方法是:

    all_ordered_idx_pairs = torch.cartesian_prod(torch.tensor(range(x.shape[1])),torch.tensor(range(x.shape[1])))
    y = torch.stack([x[i][all_ordered_idx_pairs] for i in range(x.shape[0])])

对张量进行整形后:

y = y.view(x.shape[0],x.shape[1],-1)

你得到:

y = torch.tensor([[[[0,0],[0,1],2]],[[1,[1,[[2,[2,2]]],[[[3,3],[3,4],5]],[[4,[4,[[5,[5,5]]]])