【问题标题】:numpy.linalg.solve with right-hand side of more than three dimensionsnumpy.linalg.solve 右侧超过三个维度
【发布时间】:2018-07-01 10:03:59
【问题描述】:

我正在尝试求解具有 3x3 矩阵 a 和右侧 b 任意形状 (3, ...) 的方程组。如果b 有一个或两个维度,numpy.linalg.solve 就可以了。但它分解为更多维度:

import numpy

a = numpy.random.rand(3, 3)

b = numpy.random.rand(3)
numpy.linalg.solve(a, b)  # okay

b = numpy.random.rand(3, 4)
numpy.linalg.solve(a, b)  # okay

b = numpy.random.rand(3, 4, 5)
numpy.linalg.solve(a, b)  # ERR
ValueError: solve: Input operand 1 has a mismatch in its core 
dimension 0, with gufunc signature (m,m),(m,n)->(m,n) (size 5 is 
different from 3)

我希望输出数组sol 的形状为(3, 4, 5),对应于右侧b[:, i, j] 的解是sol[:, i, j]

关于如何最好地解决这个问题的任何提示?

【问题讨论】:

  • 这意味着什么? numpy.linalg.solve 求解矩阵方程,而numpy.random.rand(3, 4, 5) 的维度太多,不能成为矩阵。
  • 我自然希望输出一个形状为(3, 4, 5) 的数组;右手边b(:, i, j) 的解决方案是sol(:, i, j),就像其他情况一样。
  • 所以您认为二维b 行为是一种对一维行为的广播,并期望它扩展到更多维度?这不是这里的设计;一方面,NumPy 广播在左侧而不是右侧进行广播。这里的设计是一个矩阵方程求解器,具有一种特殊情况,用于将一维b 视为列矩阵。您必须像 Warren Weckesser 建议的那样,将您的 b 重新塑造成一个矩阵并重新塑造解决方案。

标签: python numpy linear-algebra


【解决方案1】:

暂时将b整形为(3, 20),求解线性系统,然后将得到的数组整形为b的原始形状(3、4、5):

In [34]: a = numpy.random.rand(3, 3)
In [35]: b = numpy.random.rand(3, 4, 5)

In [36]: x = numpy.linalg.solve(a, b.reshape(b.shape[0], -1)).reshape(b.shape)

使用np.swapaxesb的第一个轴与第二个轴交换,求解线性系统,然后恢复轴:

In [58]: x = np.swapaxes(np.linalg.solve(a, np.swapaxes(b, 0, 1)), 0, 1)

完整性检查:

In [38]: np.einsum('ij,jkl', a, x)
Out[38]: 
array([[[ 0.44859955,  0.22967928,  0.74336067,  0.47440575,  0.53798895],
        [ 0.80045696,  0.54138958,  0.89870834,  0.56862419,  0.28217437],
        [ 0.02093982,  0.78534718,  0.77208236,  0.41568151,  0.95100661],
        [ 0.03820421,  0.47067312,  0.71928294,  0.30852615,  0.64454321]],

       [[ 0.31757072,  0.30527186,  0.36768759,  0.95869289,  0.86601996],
        [ 0.60616508,  0.69927063,  0.53470332,  0.88906606,  0.76066344],
        [ 0.95411847,  0.51116677,  0.29338398,  0.04418815,  0.96210206],
        [ 0.23449429,  0.64159963,  0.7732404 ,  0.4314741 ,  0.81279619]],

       [[ 0.6399571 ,  0.57640652,  0.0186913 ,  0.66304489,  0.83372239],
        [ 0.28426522,  0.62367363,  0.37163699,  0.78217433,  0.90573787],
        [ 0.91066088,  0.06699638,  0.43079394,  0.00263537,  0.399102  ],
        [ 0.17711441,  0.48724858,  0.05526752,  0.34251648,  0.94059739]]])

In [39]: b
Out[39]: 
array([[[ 0.44859955,  0.22967928,  0.74336067,  0.47440575,  0.53798895],
        [ 0.80045696,  0.54138958,  0.89870834,  0.56862419,  0.28217437],
        [ 0.02093982,  0.78534718,  0.77208236,  0.41568151,  0.95100661],
        [ 0.03820421,  0.47067312,  0.71928294,  0.30852615,  0.64454321]],

       [[ 0.31757072,  0.30527186,  0.36768759,  0.95869289,  0.86601996],
        [ 0.60616508,  0.69927063,  0.53470332,  0.88906606,  0.76066344],
        [ 0.95411847,  0.51116677,  0.29338398,  0.04418815,  0.96210206],
        [ 0.23449429,  0.64159963,  0.7732404 ,  0.4314741 ,  0.81279619]],

       [[ 0.6399571 ,  0.57640652,  0.0186913 ,  0.66304489,  0.83372239],
        [ 0.28426522,  0.62367363,  0.37163699,  0.78217433,  0.90573787],
        [ 0.91066088,  0.06699638,  0.43079394,  0.00263537,  0.399102  ],
        [ 0.17711441,  0.48724858,  0.05526752,  0.34251648,  0.94059739]]])

使用np.allclose(),这样您就不必手动检查数字,尤其是对于大型数组:

In [32]: b_ = np.einsum('ij,jkl', a, x)

In [33]: np.allclose(b, b_)
Out[33]: True

【讨论】:

  • @kmario23,谢谢,我正在考虑回来做同样的事情!
  • np.einsum() 的美丽 ;) np.matmul 在看到 2D 和 3D 组合时失败了!我不得不求助于您使用的相同技巧:np.matmul(a, x.reshape(x.shape[0], -1)).reshape(x.shape)
  • @kmario23 你好,你似乎对此很了解。你能帮我解决我的问题吗? stackoverflow.com/questions/58152952/…
  • @WarrenWeckesser 你能看看我的问题吗?我面临着类似的问题stackoverflow.com/questions/58152952/…
【解决方案2】:

我想补充一点,manual 明确指出:

a : (..., M, M) array_like

系数矩阵。

b : {(..., M,), (..., M, K)}, array_like

坐标或“因变量”值。

所以最后一个维度必须与a的最后两个维度相同(M)。除此之外,它的行为与您预期的一样 - 更多维度是可能的,返回与B 具有相同维度的结果。这样Ax=B 的解就自然而然地计算出来了,并且尺寸自动转换——只需求解许多具有尺寸 (M,K) 的解的方程组,并将它们嵌入到外部尺寸中。在您的情况下,3 在开头而不是在中间会混淆算法。 3维示例;

>>> a=np.random.rand(9).reshape(3,3)
>>> b=np.random.rand(12).reshape(2,3,2)
>>> np.linalg.solve(a,b)
array([[[-0.63673083,  0.57508091],
        [ 0.87653408,  0.46092677],
        [ 0.61128222, -0.19641607]],

       [[-0.91645601,  1.30939652],
        [ 0.83591936, -0.17006344],
        [ 0.19086912,  0.29082206]]])

【讨论】:

    猜你喜欢
    • 2019-12-12
    • 2021-12-26
    • 2014-06-19
    • 1970-01-01
    • 2021-11-10
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多