Numpy/PyTorch 有趣的张量积

问题描述

我有一个这样定义的 4 维火炬张量参数:

nn.parameter.Parameter(data=torch.Tensor((13,13,13)),requires_grad=True)

和四个带暗淡的张量 (batch_size,13)(或一个带暗淡的张量 (batch_size,4,13))。 我想得到一个带有暗淡(batch_size)等于这张图片末尾的公式的张量: [编辑:我在第一张图片中犯了一个错误,我已经更正了]

enter image description here

我在 Torch 文档中看到了 tensordot 函数,但我无法让它自己工作。

解决方法

每当你有一个有趣的张量积 torch.einsum(或 numpy.einsum)是你的朋友:

batch_size = 5
A = torch.rand(13,13,13)
a = torch.rand(batch_size,13)
b = torch.rand(batch_size,13)
c = torch.rand(batch_size,13)
d = torch.rand(batch_size,13)
B = torch.einsum('ijkl,bi,bj,bk,bl->b',A,a,b,c,d)