【问题标题】:reducing loops with numpy用 numpy 减少循环
【发布时间】:2021-04-25 10:59:45
【问题描述】:

我们正在尝试实现给定的 Modified Gram Schmidt 算法:

我们首先尝试通过以下方式实现第 5-7 行:

for j in range(i+1, N):
    R[i, j] = np.matmul(Q[:, i].transpose(), U[:, j])
    u = U[:, j] - R[i, j] * Q[:, i]
    U[:, j] = u

为了减少运行时间,我们尝试用这样的矩阵运算替换循环:

# we changed the inner loop to matrix operations in order to improve running time
R[i, i + 1:] = np.matmul(Q[:, i] , U[:, i + 1:])
U[:, i + 1:] = U[:, i + 1:] - R[i, i + 1:] * np.transpose(np.tile(Q[:, i], (N - i - 1, 1)))

结果不一样,但非常相似。我们的二审有问题吗?

谢谢!

编辑: 完整的功能是:

def gram_schmidt2(A):
    """
    decomposes a matrix A ∈ R into a product A = QR of an
    orthogonal matrix Q (i.e. QTQ = I) and an upper triangular matrix R (i.e. entries below
    the main diagonal are zero)

    :return: Q,R
    """
    N = np.shape(A)[0]
    U = A.copy()
    Q = np.zeros((N, N), dtype=np.float64)
    R = np.zeros((N, N), dtype=np.float64)
    for i in range(N):
        R[i, i] = np.linalg.norm(U[:, i])
        # Handling devision by zero by exiting the program as was advised in the forum
        if R[i, i] == 0:
            zero_devision_error(gram_schmidt._name_)
        Q[:, i] = np.divide(U[:, i], R[i, i])
        # we changed the inner loop to matrix operatins in oreder to improve running time
        for j in range(i+1, N):
            R[i, j] = np.matmul(Q[:, i].transpose(), U[:, j])
            u = U[:, j] - R[i, j] * Q[:, i]
            U[:, j] = u
    return Q, R

和:

def gram_schmidt1(A):
    """
    decomposes a matrix A ∈ R into a product A = QR of an
    orthogonal matrix Q (i.e. QTQ = I) and an upper triangular matrix R (i.e. entries below
    the main diagonal are zero)

    :return: Q,R
    """
    N = np.shape(A)[0]
    U = A.copy()
    Q = np.zeros((N, N), dtype=np.float64)
    R = np.zeros((N, N), dtype=np.float64)
    for i in range(N):
        R[i, i] = np.linalg.norm(U[:, i])
        # Handling devision by zero by exiting the program as was advised in the forum
        if R[i, i] == 0:
            zero_devision_error(gram_schmidt._name_)
        Q[:, i] = np.divide(U[:, i], R[i, i])
        # we changed the inner loop to matrix operatins in oreder to improve running time
        R[i, i + 1:] = np.matmul(Q[:, i] , U[:, i + 1:])
        U[:, i + 1:] = U[:, i + 1:] - R[i, i + 1:] * np.transpose(np.tile(Q[:, i], (N - i - 1, 1)))
    return Q, R

当我们在矩阵上运行函数时:

[[ 1.00000000e+00 -1.98592571e-02 -1.00365698e-04 -1.45204974e-03
  -9.95711793e-01 -1.77405377e-04 -7.68526195e-03]
 [-1.98592571e-02  1.00000000e+00 -1.77809186e-02 -1.55937174e-01
  -9.80881385e-03 -2.05317715e-02 -2.01456899e-01]
 [-1.00365698e-04 -1.77809186e-02  1.00000000e+00 -1.87979660e-01
  -5.12368040e-05 -8.35323206e-01 -4.59007949e-05]
 [-1.45204974e-03 -1.55937174e-01 -1.87979660e-01  1.00000000e+00
  -8.69848133e-04 -3.64095785e-01 -5.55408776e-04]
 [-9.95711793e-01 -9.80881385e-03 -5.12368040e-05 -8.69848133e-04
   1.00000000e+00 -9.54867422e-05 -5.92716161e-03]
 [-1.77405377e-04 -2.05317715e-02 -8.35323206e-01 -3.64095785e-01
  -9.54867422e-05  1.00000000e+00 -5.55505343e-05]
 [-7.68526195e-03 -2.01456899e-01 -4.59007949e-05 -5.55408776e-04
  -5.92716161e-03 -5.55505343e-05  1.00000000e+00]]

我们得到不同的这些输出:

对于 gram shmidt 1:

问:

[[ 7.34036501e-01 -8.55006295e-04 -8.15634583e-03 -9.24967764e-02
  -4.91879501e-02 -4.90769704e-01  1.58268518e-01]
 [-2.78569770e-04  7.14001661e-01 -2.70586659e-03 -2.70735367e-02
   5.78840577e-01  2.37376069e-01  1.97835647e-02]
 [-2.48309244e-03 -2.34709092e-03  7.38351181e-01  2.63187853e-01
  -3.35473487e-01  3.38823696e-01  3.36320600e-01]
 [-4.27658449e-03 -2.12584453e-03 -6.70730760e-01  3.82666405e-01
  -3.44451231e-01  3.46085878e-01 -7.71559024e-01]
 [-6.53970073e-04 -7.00117873e-01 -2.68125144e-03 -2.31536583e-02
   5.94568750e-01  2.38329853e-01 -2.76969906e-01]
 [-9.26674350e-02 -5.07961588e-03 -6.97972068e-02 -8.79879575e-01
  -2.78679804e-01  2.78781202e-01  0.00000000e+00]
 [-6.72739327e-01  1.73894101e-04  2.25707383e-03  1.69052581e-02
  -1.26723666e-02 -5.77668322e-01 -4.35238424e-01]]

R:

[[ 1.36233007e+00  1.11436069e-03  1.04418015e-02  1.27072186e-02
   1.10993692e-03 -7.82681536e-02 -1.33081669e+00]
 [ 0.00000000e+00  1.40055740e+00  5.29057231e-04  1.44628716e-03
  -1.40014587e+00  3.57535802e-04  2.25417515e-03]
 [ 0.00000000e+00  0.00000000e+00  1.35440586e+00 -1.33059602e+00
   6.67148806e-04 -3.51561140e-02  2.23809829e-02]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  2.81147599e-01
   1.33951520e-02 -9.55057795e-01  2.36910667e-01]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   3.37143743e-02 -1.97436093e-01  7.90539705e-02]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   0.00000000e+00  3.40545951e-01 -1.75971454e-01]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   0.00000000e+00  0.00000000e+00  3.50740324e-16]]

对于 gram shmidt 2:

问:

    [[ 7.34036501e-01 -8.55006295e-04 -8.15634583e-03 -9.24967764e-02
  -4.91879501e-02 -4.90769704e-01  4.55677949e-01]
 [-2.78569770e-04  7.14001661e-01 -2.70586659e-03 -2.70735367e-02
   5.78840577e-01  2.37376069e-01 -1.89865812e-01]
 [-2.48309244e-03 -2.34709092e-03  7.38351181e-01  2.63187853e-01
  -3.35473487e-01  3.38823696e-01  9.49329061e-02]
 [-4.27658449e-03 -2.12584453e-03 -6.70730760e-01  3.82666405e-01
  -3.44451231e-01  3.46085878e-01 -4.36691368e-01]
 [-6.53970073e-04 -7.00117873e-01 -2.68125144e-03 -2.31536583e-02
   5.94568750e-01  2.38329853e-01 -1.13919487e-01]
 [-9.26674350e-02 -5.07961588e-03 -6.97972068e-02 -8.79879575e-01
  -2.78679804e-01  2.78781202e-01 -1.51892650e-01]
 [-6.72739327e-01  1.73894101e-04  2.25707383e-03  1.69052581e-02
  -1.26723666e-02 -5.77668322e-01 -7.21490087e-01]]

R:

[[ 1.36233007e+00  1.11436069e-03  1.04418015e-02  1.27072186e-02
   1.10993692e-03 -7.82681536e-02 -1.33081669e+00]
 [ 0.00000000e+00  1.40055740e+00  5.29057231e-04  1.44628716e-03
  -1.40014587e+00  3.57535802e-04  2.25417515e-03]
 [ 0.00000000e+00  0.00000000e+00  1.35440586e+00 -1.33059602e+00
   6.67148806e-04 -3.51561140e-02  2.23809829e-02]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  2.81147599e-01
   1.33951520e-02 -9.55057795e-01  2.36910667e-01]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   3.37143743e-02 -1.97436093e-01  7.90539705e-02]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   0.00000000e+00  3.40545951e-01 -1.75971454e-01]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   0.00000000e+00  0.00000000e+00  3.65463051e-16]]

【问题讨论】:

  • 如果不是太大,能否提供完整代码?还有一些真实的数据示例。我们需要一些可重现的代码和数据来尝试我们的解决方案并为您提供帮助。
  • 我们添加了,谢谢!
  • 您确定您的第一个代码有效吗?您的函数 gram_schmidt2() 在您的测试数据上引发异常。
  • 它表示U[:, j] 超出范围,因为j 的范围高达NN 是输入数组的第一维,而不是第二维。 U 的维度与A 相同,这意味着j 应该在第一个索引中,如U[j, :],而不是U[:, j]
  • 可能A 应该是方阵,以便您的gram_schmidt2() 工作。

标签: python numpy linear-algebra array-broadcasting


【解决方案1】:

以下代码以更有效的方式执行您想要的操作:

        Q_i = Q[:, i].reshape(1,-1)
        R[i,i+1:] = np.matmul(Q_i , U[:,i+1:])
        U[:,i+1:] -=  np.multiply(R[i,i+1:] , Q_i.T)

第一行只是为了方便,让代码更具可读性。

除最后一行外,一切都与您的原始提案相同。最后一行执行逐元素乘法,这最终是您在内部循环的最后一行中所做的。

关于结果的差异:

你的代码没问题,两者都是一样的。当你处理浮点数时,你不应该测试为A == B。相反,我建议您检查两个数组的不同之处。

特别是跑步

Q1,R1 = gram_schmidt2(A)
Q2,R2 = gram_schmidt1(A)

(Q1 - Q2).mean()
(R1 - R2).mean()

分别给出:

-5.4997372770547595e-09 and -5.2465803662044656e-18

已经非常接近于 0。 1e-18 低于 dtype np.float64 的错误,所以你很好。

如果你运行差异3*0.1 - 0.3(大约1e-17),你可以检查这个

矩阵 Q 的误差较大,因为它来自浮点数之间的除法,如果矩阵元素的大小很小(这里有时会出现这种情况),这会增加误差。

关于运行时间:

在运行您的两个版本的代码时,我得到相似的运行时间:(243 µs ± 25.5 µs 使用循环,241 µs ± 6.82 µs 使用您的第二个版本);而这里提供的代码实现152 µs ± 1.49 µs

【讨论】:

  • “矩阵 Q 的误差较大,因为它来自浮点数之间的除法,如果矩阵元素的数量级很小(这里有时会出现这种情况),则会增加误差。”你是什​​么意思?我们在两者中划分相同的次数
  • 为什么是:" Q_i = Q[:, i].reshape(1,-1) neccerary?
  • 呃,对不起,没有意识到这一点,所以这个论点不再适用。但是,(平均)差异仍然很小。您可以通过乘以 Q*R 并与 A 进行比较来更好地检查您的答案。这应该让您更好地了解您得到的答案是否正确。
  • Q_i 的定义只是为了方便,因为它将在接下来的两行中使用。 .reshape(1,-1) 是必要的,因为我们需要将这个切片 (Q[:,i]) 转换为矩阵形式,这样它就可以转置,并通过 numpy 以我们想要的方式广播。
【解决方案2】:

我可以建议你使用Numba,它是一个很棒的速度优化器,它可以通过将许多 Python 程序 JIT 编译成 C++ 和机器代码来将其提升 50-200 倍。

要安装 numba,只需执行一次 python -m pip install numba

下面是采用你的算法实现 numba 的代码,大部分只是在第一行函数之前的 @numba.njit 装饰器。

在 numba 代码中,您只需编写常规 Python 循环和任何数学计算,即使不使用 Numpy,您的最终代码也会非常快,大多数时候甚至比任何 Numpy 代码都快。

我以您的 gram_schmidt2() 函数为基础,仅将 np.multiply() 替换为 np.dot(),因为 Numba 似乎只实现了 np.dot() 功能。

Try it online!

import numpy as np, numba

@numba.njit(cache = True, fastmath = True, parallel = True)
def gram_schmidt2(A):
    """
    decomposes a matrix A ∈ R into a product A = QR of an
    orthogonal matrix Q (i.e. QTQ = I) and an upper triangular matrix R (i.e. entries below
    the main diagonal are zero)

    :return: Q,R
    """
    N = np.shape(A)[0]
    U = A.copy()
    Q = np.zeros((N, N), dtype=np.float64)
    R = np.zeros((N, N), dtype=np.float64)
    for i in range(N):
        R[i, i] = np.linalg.norm(U[:, i])
        # Handling devision by zero by exiting the program as was advised in the forum
        if R[i, i] == 0:
            assert False #zero_devision_error(gram_schmidt._name_)
        Q[:, i] = np.divide(U[:, i], R[i, i])
        # we changed the inner loop to matrix operatins in oreder to improve running time
        for j in range(i+1, N):
            R[i, j] = np.dot(Q[:, i].transpose(), U[:, j])
            u = U[:, j] - R[i, j] * Q[:, i]
            U[:, j] = u
    return Q, R
    
a = np.array(
    [[ 1.00000000e+00, -1.98592571e-02, -1.00365698e-04, -1.45204974e-03,
      -9.95711793e-01, -1.77405377e-04, -7.68526195e-03],
     [-1.98592571e-02,  1.00000000e+00, -1.77809186e-02, -1.55937174e-01,
      -9.80881385e-03, -2.05317715e-02, -2.01456899e-01],
     [-1.00365698e-04, -1.77809186e-02,  1.00000000e+00, -1.87979660e-01,
      -5.12368040e-05, -8.35323206e-01, -4.59007949e-05],
     [-1.45204974e-03, -1.55937174e-01, -1.87979660e-01,  1.00000000e+00,
      -8.69848133e-04, -3.64095785e-01, -5.55408776e-04],
     [-9.95711793e-01, -9.80881385e-03, -5.12368040e-05, -8.69848133e-04,
       1.00000000e+00, -9.54867422e-05, -5.92716161e-03],
     [-1.77405377e-04, -2.05317715e-02, -8.35323206e-01, -3.64095785e-01,
      -9.54867422e-05,  1.00000000e+00, -5.55505343e-05],
     [-7.68526195e-03, -2.01456899e-01, -4.59007949e-05, -5.55408776e-04,
      -5.92716161e-03, -5.55505343e-05,  1.00000000e+00]]
, dtype = np.float64)

print(gram_schmidt2(a))

输出:

(array([[ 7.08543467e-01, -5.53704898e-03, -2.70026740e-04,
        -3.47742384e-03,  1.84840892e-01, -5.24814365e-01,
        -4.33966083e-01],
       [-1.40711469e-02,  9.68398634e-01, -2.12833250e-02,
         1.19174521e-01, -1.98433167e-01, -3.04695775e-02,
        -8.39439437e-02],
       [-7.11134597e-05, -1.72252300e-02,  7.59699130e-01,
        -1.47406821e-01, -1.01157914e-01,  3.77137817e-01,
        -4.98362473e-01],
       [-1.02884036e-03, -1.51071666e-01, -1.41567550e-01,
         9.02766638e-01, -8.55711320e-02,  2.12039165e-01,
        -2.99775521e-01],
       [-7.05505086e-01, -2.31427937e-02,  3.84334272e-04,
        -6.68149305e-03,  1.96907249e-01, -5.24473268e-01,
        -4.33402818e-01],
       [-1.25699421e-04, -1.98909561e-02, -6.34318769e-01,
        -3.82156774e-01, -9.76029595e-02,  4.04531367e-01,
        -5.27283410e-01],
       [-5.44534215e-03, -1.95250685e-01,  1.53606576e-03,
        -5.45941927e-02, -9.27687435e-01, -3.12618155e-01,
        -2.30333938e-02]]),
array([[ 1.41134602e+00, -1.99608442e-02,  4.42769473e-04,
         8.12375351e-04, -1.41083897e+00,  5.39174765e-04,
        -3.87373035e-03],
       [ 0.00000000e+00,  1.03234256e+00,  1.05802339e-02,
        -2.91464191e-01, -2.58368570e-02,  2.96333339e-02,
        -3.90075744e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.31655051e+00,
        -5.01046784e-02,  9.97649491e-04, -1.21693202e+00,
         5.90252943e-03],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.05107524e+00, -4.80557952e-03, -5.90160540e-01,
        -7.90098043e-02],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  2.03928769e-02,  2.21268065e-02,
        -8.90241765e-01],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  1.30829767e-02,
        -2.99495426e-01],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         9.31764881e-10]]))

【讨论】:

    猜你喜欢
    • 2021-04-23
    • 1970-01-01
    • 1970-01-01
    • 2011-06-13
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2012-07-16
    • 2011-12-22
    相关资源
    最近更新 更多