将张量切片乘以矩阵行

问题描述

我试图找到一种方法来有效地计算张量 (shape: (n,n,m)) 的每个深度二维切片与矩阵 (shape: (n,m) )。我正在尝试做的事情看起来像这样没有矢量化:

import numpy as np

np.random.seed(42)

np.random.seed(1)

a = np.arange(16).reshape(4,4)
b = np.random.randn(4,4,4)
c = np.zeros((4,4))

for i in range(4):
    c[i] = b[...,i] @ a[i]

产生的结果:

<<< print(c)
>>> [[  0.53623421  -0.10257152  -1.34855819  -1.72774519]
     [-18.13932187   1.82230599 -11.99348739  15.0787884 ]
     [ 38.5704751   -0.38514407   4.19673794   9.01941574]
     [-68.11165212  -5.52586601  64.69279036  11.3196871 ]]

我最接近的是:

<<< print(np.einsum("ij,ijk->ki",a,b))
>>> [[  0.53623421,-2.66288958,-16.91264496,-12.98103047],[ -3.95244251,1.82230599,-20.5351456,34.69343339],[  8.07033597,-0.90215803,4.19673794,12.57858867],[ -8.18116212,-3.54815874,46.60443317,11.3196871 ]]

至少左上和右下元素匹配的地方。

解决方法

你们很亲近。

以下应该产生与您的 for 循环相同的结果

print(np.einsum('ikj,jk->ji',b,a))

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...