问题描述
我想以批处理方式计算特定维度上的成对串联。
例如
x = torch.tensor([[[0],[1],[2]],[[3],[4],[5]]])
x.shape = torch.Size([2,3,1])
我想得到 y
使得 y
是一个维度上所有向量对的串联,即:
y = torch.tensor([[[[0,0],[0,1],2]],[[1,[1,[[2,[2,2]]],[[[3,3],[3,4],5]],[[4,[4,[[5,[5,5]]]])
y.shape = torch.Size([2,2])
因此,本质上,对于每个 x[i,:]
,您生成所有向量对,并将它们连接到最后一个维度。
有没有直接的方法来做到这一点?
解决方法
一种可能的方法是:
all_ordered_idx_pairs = torch.cartesian_prod(torch.tensor(range(x.shape[1])),torch.tensor(range(x.shape[1])))
y = torch.stack([x[i][all_ordered_idx_pairs] for i in range(x.shape[0])])
对张量进行整形后:
y = y.view(x.shape[0],x.shape[1],-1)
你得到:
y = torch.tensor([[[[0,0],[0,1],2]],[[1,[1,[[2,[2,2]]],[[[3,3],[3,4],5]],[[4,[4,[[5,[5,5]]]])