NumPy Tensordot 轴 = 2

问题描述

我知道有很多关于 tensordot 的问题,我已经浏览了一些我确信人们花了数小时制作的 15 页迷你书答案中的一些,但我还没有找到解释{{1}

这让我想到 axes=2,但作为一个数组:

np.tensordot(b,c,axes=2) == np.sum(b * c)

但是这失败了:

b = np.array([[1,10],[100,1000]])
c = np.array([[2,3],[5,7]])
np.tensordot(b,axes=2)
Out: array(7532)

如果有人能提供对 a = np.arange(30).reshape((2,3,5)) np.tensordot(a,a,axes=2) 的简短说明,而且只有 np.tensordot(x,y,axes=2),那么我很乐意接受。

解决方法

In [70]: a = np.arange(24).reshape(2,3,4)
In [71]: np.tensordot(a,a,axes=2)
Traceback (most recent call last):
  File "<ipython-input-71-dbe04e46db70>",line 1,in <module>
    np.tensordot(a,axes=2)
  File "<__array_function__ internals>",line 5,in tensordot
  File "/usr/local/lib/python3.8/dist-packages/numpy/core/numeric.py",line 1116,in tensordot
    raise ValueError("shape-mismatch for sum")
ValueError: shape-mismatch for sum

在我之前的帖子中,我推断 axis=2 翻译为 axes=([-2,-1],[0,1])

How does numpy.tensordot function works step-by-step?

In [72]: np.tensordot(a,axes=([-2,1]))
Traceback (most recent call last):
  File "<ipython-input-72-efdbfe6ff0d3>",1]))
  File "<__array_function__ internals>",in tensordot
    raise ValueError("shape-mismatch for sum")
ValueError: shape-mismatch for sum

所以这是尝试对第一个 a 的最后两个维度和第二个 a 的前两个维度进行双轴缩减。有了这个 a,这是尺寸不匹配。显然这个 axes 是为 2d 数组设计的,没有考虑 3d 数组。这不是 3 轴缩小。

某些开发人员认为这些单位数轴值很方便,但这并不意味着它们经过严格的考虑或测试。

元组轴为您提供更多控制:

In [74]: np.tensordot(a,axes=[(0,1,2),(0,2)])
Out[74]: array(4324)
In [75]: np.tensordot(a,1),1)])
Out[75]: 
array([[ 880,940,1000,1060],[ 940,1006,1072,1138],[1000,1144,1216],[1060,1138,1216,1294]])