numpy einsum计算沿轴的外积

问题描述

我有两个包含兼容矩阵的numpy数组,并且想计算使用numpy.einsum的元素明智的外部乘积。数组的形状为:

A1 = (i,j,k)
A2 = (i,k,j) 

因此,数组分别包含i形状为(k,j)(j,k)的矩阵。

因此,假设A1将包含矩阵A,B,C,而A2将包含矩阵D,E,F,结果将是:

A3 = (A(x)D,B(x)E,C(x)F)

(x)是外部乘积运算符。

根据我的理解,这将基于this answer以下形状的数组A3

A3 = (i,j*k,j*k)

到目前为止,我已经尝试过:

np.einsum("ijk,ilm -> ijklm",A1,A2)

但是生成的形状不能正确拟合。

作为健全性检查,我正在对此进行测试:

A = np.asarray(([1,2],[3,4]))
B = np.asarray(([5,6],[7,8]))

AB_outer = np.outer(A,B)

A_vec = np.asarray((A,A))
B_vec = np.asarray((B,B))

# this line is not correct
AB_vec = np.einsum("ijk,A_vec,B_vec)

np.testing.assert_array_equal(AB_outer,AB_vec[0])

当前这会引发一个断言错误,因为我的einsum表示法不正确。我也乐于接受任何可以解决这个问题的建议,这些建议比若虫einsum更快或同样快。

解决方法

我们可以延长暗淡的时间,让enter image description here为我们完成工作-

(A1[:,:,None,None]*A2[:,:]).swapaxes(2,3)

样品运行-

In [46]: A1 = np.random.rand(3,4,4)
    ...: A2 = np.random.rand(3,4)

In [47]: out = (A1[:,3)

In [48]: np.allclose(np.multiply.outer(A1[0],A2[0]),out[0])
Out[48]: True

In [49]: np.allclose(np.multiply.outer(A1[1],A2[1]),out[1])
Out[49]: True

In [50]: np.allclose(np.multiply.outer(A1[2],A2[2]),out[2])
Out[50]: True

broadcasting等效的是-

np.einsum('ijk,ilm->ijklm',A1,A2)
,

您可以计算运行结果:

result = np.einsum('ijk,ikl->ijl',A2)

我在以下测试数据上检查了上面的代码:

A = np.arange(1,13).reshape(3,-1)
B = np.arange(2,14).reshape(3,-1)
C = np.arange(3,15).reshape(3,-1)
D = np.arange(1,13).reshape(4,-1)
E = np.arange(2,14).reshape(4,-1)
F = np.arange(3,15).reshape(4,-1)
A1 = np.array([A,B,C])
A2 = np.array([D,E,F])

结果是:

array([[[ 70,80,90],[158,184,210],[246,288,330]],[[106,120,134],[210,240,270],[314,360,406]],[[150,168,186],[270,304,338],[390,440,490]]])

现在计算3个“部分结果”:

res_1 = A @ D
res_2 = B @ E
res_3 = C @ F

并检查它们是否与结果的连续部分相同。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...