用 numpy 减少循环 关于结果的差异:关于运行时:

问题描述

我们正在尝试实现给定的 Modified Gram Schmidt 算法:

instructions here

我们首先尝试以下面的方式实现第 5-7 行:

for j in range(i+1,N):
    R[i,j] = np.matmul(Q[:,i].transpose(),U[:,j])
    u = U[:,j] - R[i,j] * Q[:,i]
    U[:,j] = u

为了减少运行时间,我们尝试用这样的矩阵运算替换循环:

# we changed the inner loop to matrix operations in order to improve running time
R[i,i + 1:] = np.matmul(Q[:,i],i + 1:])
U[:,i + 1:] = U[:,i + 1:] - R[i,i + 1:] * np.transpose(np.tile(Q[:,(N - i - 1,1)))

结果不一样,但非常相似。我们的二审有问题吗?

谢谢!

编辑: 完整的功能是:

def gram_schmidt2(A):
    """
    decomposes a matrix A ∈ R into a product A = QR of an
    orthogonal matrix Q (i.e. QTQ = I) and an upper triangular matrix R (i.e. entries below
    the main diagonal are zero)

    :return: Q,R
    """
    N = np.shape(A)[0]
    U = A.copy()
    Q = np.zeros((N,N),dtype=np.float64)
    R = np.zeros((N,dtype=np.float64)
    for i in range(N):
        R[i,i] = np.linalg.norm(U[:,i])
        # Handling devision by zero by exiting the program as was advised in the forum
        if R[i,i] == 0:
            zero_devision_error(gram_schmidt._name_)
        Q[:,i] = np.divide(U[:,R[i,i])
        # we changed the inner loop to matrix operatins in oreder to improve running time
        for j in range(i+1,N):
            R[i,j])
            u = U[:,i]
            U[:,j] = u
    return Q,R

和:

def gram_schmidt1(A):
    """
    decomposes a matrix A ∈ R into a product A = QR of an
    orthogonal matrix Q (i.e. QTQ = I) and an upper triangular matrix R (i.e. entries below
    the main diagonal are zero)

    :return: Q,i])
        # we changed the inner loop to matrix operatins in oreder to improve running time
        R[i,i + 1:])
        U[:,1)))
    return Q,R

当我们在矩阵上运行函数时:

[[ 1.00000000e+00 -1.98592571e-02 -1.00365698e-04 -1.45204974e-03
  -9.95711793e-01 -1.77405377e-04 -7.68526195e-03]
 [-1.98592571e-02  1.00000000e+00 -1.77809186e-02 -1.55937174e-01
  -9.80881385e-03 -2.05317715e-02 -2.01456899e-01]
 [-1.00365698e-04 -1.77809186e-02  1.00000000e+00 -1.87979660e-01
  -5.12368040e-05 -8.35323206e-01 -4.59007949e-05]
 [-1.45204974e-03 -1.55937174e-01 -1.87979660e-01  1.00000000e+00
  -8.69848133e-04 -3.64095785e-01 -5.55408776e-04]
 [-9.95711793e-01 -9.80881385e-03 -5.12368040e-05 -8.69848133e-04
   1.00000000e+00 -9.54867422e-05 -5.92716161e-03]
 [-1.77405377e-04 -2.05317715e-02 -8.35323206e-01 -3.64095785e-01
  -9.54867422e-05  1.00000000e+00 -5.55505343e-05]
 [-7.68526195e-03 -2.01456899e-01 -4.59007949e-05 -5.55408776e-04
  -5.92716161e-03 -5.55505343e-05  1.00000000e+00]]

我们得到不同的输出

对于克 shmidt 1:

问:

[[ 7.34036501e-01 -8.55006295e-04 -8.15634583e-03 -9.24967764e-02
  -4.91879501e-02 -4.90769704e-01  1.58268518e-01]
 [-2.78569770e-04  7.14001661e-01 -2.70586659e-03 -2.70735367e-02
   5.78840577e-01  2.37376069e-01  1.97835647e-02]
 [-2.48309244e-03 -2.34709092e-03  7.38351181e-01  2.63187853e-01
  -3.35473487e-01  3.38823696e-01  3.36320600e-01]
 [-4.27658449e-03 -2.12584453e-03 -6.70730760e-01  3.82666405e-01
  -3.44451231e-01  3.46085878e-01 -7.71559024e-01]
 [-6.53970073e-04 -7.00117873e-01 -2.68125144e-03 -2.31536583e-02
   5.94568750e-01  2.38329853e-01 -2.76969906e-01]
 [-9.26674350e-02 -5.07961588e-03 -6.97972068e-02 -8.79879575e-01
  -2.78679804e-01  2.78781202e-01  0.00000000e+00]
 [-6.72739327e-01  1.73894101e-04  2.25707383e-03  1.69052581e-02
  -1.26723666e-02 -5.77668322e-01 -4.35238424e-01]]

R:

[[ 1.36233007e+00  1.11436069e-03  1.04418015e-02  1.27072186e-02
   1.10993692e-03 -7.82681536e-02 -1.33081669e+00]
 [ 0.00000000e+00  1.40055740e+00  5.29057231e-04  1.44628716e-03
  -1.40014587e+00  3.57535802e-04  2.25417515e-03]
 [ 0.00000000e+00  0.00000000e+00  1.35440586e+00 -1.33059602e+00
   6.67148806e-04 -3.51561140e-02  2.23809829e-02]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  2.81147599e-01
   1.33951520e-02 -9.55057795e-01  2.36910667e-01]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   3.37143743e-02 -1.97436093e-01  7.90539705e-02]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   0.00000000e+00  3.40545951e-01 -1.75971454e-01]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   0.00000000e+00  0.00000000e+00  3.50740324e-16]]

对于克 shmidt 2:

问:

    [[ 7.34036501e-01 -8.55006295e-04 -8.15634583e-03 -9.24967764e-02
  -4.91879501e-02 -4.90769704e-01  4.55677949e-01]
 [-2.78569770e-04  7.14001661e-01 -2.70586659e-03 -2.70735367e-02
   5.78840577e-01  2.37376069e-01 -1.89865812e-01]
 [-2.48309244e-03 -2.34709092e-03  7.38351181e-01  2.63187853e-01
  -3.35473487e-01  3.38823696e-01  9.49329061e-02]
 [-4.27658449e-03 -2.12584453e-03 -6.70730760e-01  3.82666405e-01
  -3.44451231e-01  3.46085878e-01 -4.36691368e-01]
 [-6.53970073e-04 -7.00117873e-01 -2.68125144e-03 -2.31536583e-02
   5.94568750e-01  2.38329853e-01 -1.13919487e-01]
 [-9.26674350e-02 -5.07961588e-03 -6.97972068e-02 -8.79879575e-01
  -2.78679804e-01  2.78781202e-01 -1.51892650e-01]
 [-6.72739327e-01  1.73894101e-04  2.25707383e-03  1.69052581e-02
  -1.26723666e-02 -5.77668322e-01 -7.21490087e-01]]

R:

[[ 1.36233007e+00  1.11436069e-03  1.04418015e-02  1.27072186e-02
   1.10993692e-03 -7.82681536e-02 -1.33081669e+00]
 [ 0.00000000e+00  1.40055740e+00  5.29057231e-04  1.44628716e-03
  -1.40014587e+00  3.57535802e-04  2.25417515e-03]
 [ 0.00000000e+00  0.00000000e+00  1.35440586e+00 -1.33059602e+00
   6.67148806e-04 -3.51561140e-02  2.23809829e-02]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  2.81147599e-01
   1.33951520e-02 -9.55057795e-01  2.36910667e-01]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   3.37143743e-02 -1.97436093e-01  7.90539705e-02]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   0.00000000e+00  3.40545951e-01 -1.75971454e-01]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   0.00000000e+00  0.00000000e+00  3.65463051e-16]]

解决方法

以下代码以更有效的方式执行您想要的操作:

        Q_i = Q[:,i].reshape(1,-1)
        R[i,i+1:] = np.matmul(Q_i,U[:,i+1:])
        U[:,i+1:] -=  np.multiply(R[i,i+1:],Q_i.T)

第一行只是为了方便,使代码更具可读性。

除了最后一行之外,一切都与您的原始提案相同。最后一行执行逐元素乘法,这最终是您在内循环的最后一行中所做的。

关于结果的差异:

你的代码没问题,两者都是一样的。在处理浮点数时,不应测试为 A == B。相反,我建议您检查两个数组的不同之处。

特别是跑步

Q1,R1 = gram_schmidt2(A)
Q2,R2 = gram_schmidt1(A)

(Q1 - Q2).mean()
(R1 - R2).mean()

分别给出:

-5.4997372770547595e-09 and -5.2465803662044656e-18

已经非常接近于 0。 1e-18 低于 dtype np.float64 的错误,所以你很好。

如果您运行差异 3*0.1 - 0.3(约 1e-17),您可以检查这一点

矩阵 Q 的误差较大,因为它来自浮点数之间的除法,如果矩阵元素的量级较小(这里有时就是这种情况),则会增加误差。

关于运行时:

在运行您的代码的两个版本时,我得到了相似的运行时间:(243 µs ± 25.5 µs 使用循环,241 µs ± 6.82 µs 使用您的第二个版本);而此处提供的代码实现了 152 µs ± 1.49 µs

,

我建议您使用 Numba,它是一个出色的速度优化器,通过将许多 Python 程序 JIT 编译为 C++ 和机器代码,它可以将许多 Python 程序提升 50-200 倍。

要安装 numba,只需执行一次 python -m pip install numba

以下是将您的算法应用于 numba 的代码,主要是在第一行函数之前只是一个 @numba.njit 装饰器。

在 numba 代码中,您可以只编写常规 Python 循环和任何数学计算,即使不使用 Numpy,您的最终代码也会非常快,大多数情况下甚至比任何 Numpy 代码都快。

我使用您的 gram_schmidt2() 函数作为基础,仅将 np.multiply() 替换为 np.dot(),因为 Numba 似乎仅实现了 np.dot() 功能。

Try it online!

import numpy as np,numba

@numba.njit(cache = True,fastmath = True,parallel = True)
def gram_schmidt2(A):
    """
    decomposes a matrix A ∈ R into a product A = QR of an
    orthogonal matrix Q (i.e. QTQ = I) and an upper triangular matrix R (i.e. entries below
    the main diagonal are zero)

    :return: Q,R
    """
    N = np.shape(A)[0]
    U = A.copy()
    Q = np.zeros((N,N),dtype=np.float64)
    R = np.zeros((N,dtype=np.float64)
    for i in range(N):
        R[i,i] = np.linalg.norm(U[:,i])
        # Handling devision by zero by exiting the program as was advised in the forum
        if R[i,i] == 0:
            assert False #zero_devision_error(gram_schmidt._name_)
        Q[:,i] = np.divide(U[:,i],R[i,i])
        # we changed the inner loop to matrix operatins in oreder to improve running time
        for j in range(i+1,N):
            R[i,j] = np.dot(Q[:,i].transpose(),j])
            u = U[:,j] - R[i,j] * Q[:,i]
            U[:,j] = u
    return Q,R
    
a = np.array(
    [[ 1.00000000e+00,-1.98592571e-02,-1.00365698e-04,-1.45204974e-03,-9.95711793e-01,-1.77405377e-04,-7.68526195e-03],[-1.98592571e-02,1.00000000e+00,-1.77809186e-02,-1.55937174e-01,-9.80881385e-03,-2.05317715e-02,-2.01456899e-01],[-1.00365698e-04,-1.87979660e-01,-5.12368040e-05,-8.35323206e-01,-4.59007949e-05],[-1.45204974e-03,-8.69848133e-04,-3.64095785e-01,-5.55408776e-04],[-9.95711793e-01,-9.54867422e-05,-5.92716161e-03],[-1.77405377e-04,-5.55505343e-05],[-7.68526195e-03,-2.01456899e-01,-4.59007949e-05,-5.55408776e-04,-5.92716161e-03,-5.55505343e-05,1.00000000e+00]],dtype = np.float64)

print(gram_schmidt2(a))

输出:

(array([[ 7.08543467e-01,-5.53704898e-03,-2.70026740e-04,-3.47742384e-03,1.84840892e-01,-5.24814365e-01,-4.33966083e-01],[-1.40711469e-02,9.68398634e-01,-2.12833250e-02,1.19174521e-01,-1.98433167e-01,-3.04695775e-02,-8.39439437e-02],[-7.11134597e-05,-1.72252300e-02,7.59699130e-01,-1.47406821e-01,-1.01157914e-01,3.77137817e-01,-4.98362473e-01],[-1.02884036e-03,-1.51071666e-01,-1.41567550e-01,9.02766638e-01,-8.55711320e-02,2.12039165e-01,-2.99775521e-01],[-7.05505086e-01,-2.31427937e-02,3.84334272e-04,-6.68149305e-03,1.96907249e-01,-5.24473268e-01,-4.33402818e-01],[-1.25699421e-04,-1.98909561e-02,-6.34318769e-01,-3.82156774e-01,-9.76029595e-02,4.04531367e-01,-5.27283410e-01],[-5.44534215e-03,-1.95250685e-01,1.53606576e-03,-5.45941927e-02,-9.27687435e-01,-3.12618155e-01,-2.30333938e-02]]),array([[ 1.41134602e+00,-1.99608442e-02,4.42769473e-04,8.12375351e-04,-1.41083897e+00,5.39174765e-04,-3.87373035e-03],[ 0.00000000e+00,1.03234256e+00,1.05802339e-02,-2.91464191e-01,-2.58368570e-02,2.96333339e-02,-3.90075744e-01],0.00000000e+00,1.31655051e+00,-5.01046784e-02,9.97649491e-04,-1.21693202e+00,5.90252943e-03],1.05107524e+00,-4.80557952e-03,-5.90160540e-01,-7.90098043e-02],2.03928769e-02,2.21268065e-02,-8.90241765e-01],1.30829767e-02,-2.99495426e-01],9.31764881e-10]]))

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...