【问题标题】:Strassen Matrix Multiplication -- close, but still with bugsStrassen 矩阵乘法——接近,但仍有错误
【发布时间】:2012-10-12 21:04:47
【问题描述】:

我正在尝试在 Python 中实现 Strassen 矩阵乘法。我已经让它工作了一些。这是我的代码:

a = [[1,1,1,1],[2,2,2,2],[3,3,3,3],[4,4,4,4]]
b = [[5,5,5,5],[6,6,6,6],[7,7,7,7],[8,8,8,8]]

def new_m(p, q): # create a matrix filled with 0s
    matrix = [[0 for row in range(p)] for col in range(q)]
    return matrix

def straight(a, b): # multiply the two matrices
    if len(a[0]) != len(b): # if # of col != # of rows:
        return "Matrices are not m*n and n*p"
    else:
        p_matrix = new_m(len(a), len(b[0]))
        for i in range(len(a)):
            for j in range(len(b[0])):
                for k in range(len(b)):
                    p_matrix[i][j] += a[i][k]*b[k][j]
    return p_matrix

def split(matrix): # split matrix into quarters 
    a = matrix
    b = matrix
    c = matrix
    d = matrix
    while(len(a) > len(matrix)/2):
        a = a[:len(a)//2]
        b = b[:len(b)//2]
        c = c[len(c)//2:]
        d = d[len(d)//2:]
    while(len(a[0]) > len(matrix[0])/2):
        for i in range(len(a[0])//2):
            a[i] = a[i][:len(a[i])//2]
            b[i] = b[i][len(b[i])//2:]
            c[i] = c[i][:len(c[i])//2]
            d[i] = d[i][len(d[i])//2:]
    return a,b,c,d

def add_m(a, b):
    if type(a) == int:
        d = a + b
    else:
        d = []
        for i in range(len(a)):
            c = []
            for j in range(len(a[0])):
                c.append(a[i][j] + b[i][j])
            d.append(c)
    return d

def sub_m(a, b):
    if type(a) == int:
        d = a - b
    else:
        d = []
        for i in range(len(a)):
            c = []
            for j in range(len(a[0])):
                c.append(a[i][j] - b[i][j])
            d.append(c)
    return d


def strassen(a, b, q):
    # base case: 1x1 matrix
    if q == 1:
        d = [[0]]
        d[0][0] = a[0][0] * b[0][0]
        return d
    else:
        #split matrices into quarters
        a11, a12, a21, a22 = split(a)
        b11, b12, b21, b22 = split(b)

        # p1 = (a11+a22) * (b11+b22)
        p1 = strassen(add_m(a11,a22), add_m(b11,b22), q/2)

        # p2 = (a21+a22) * b11
        p2 = strassen(add_m(a21,a22), b11, q/2)

        # p3 = a11 * (b12-b22)
        p3 = strassen(a11, sub_m(b12,b22), q/2)

        # p4 = a22 * (b12-b11)
        p4 = strassen(a22, sub_m(b12,b11), q/2)

        # p5 = (a11+a12) * b22
        p5 = strassen(add_m(a11,a12), b22, q/2)

        # p6 = (a21-a11) * (b11+b12)
        p6 = strassen(sub_m(a21,a11), add_m(b11,b12), q/2)

        # p7 = (a12-a22) * (b21+b22)
        p7 = strassen(sub_m(a12,a22), add_m(b21,b22), q/2)


        # c11 = p1 + p4 - p5 + p7
        c11 = add_m(sub_m(add_m(p1, p4), p5), p7)

        # c12 = p3 + p5
        c12 = add_m(p3, p5)

        # c21 = p2 + p4
        c21 = add_m(p2, p4)

        # c22 = p1 + p3 - p2 + p6
        c22 = add_m(sub_m(add_m(p1, p3), p2), p6)

        c = new_m(len(c11)*2,len(c11)*2)
        for i in range(len(c11)):
            for j in range(len(c11)):
                c[i][j]                   = c11[i][j]
                c[i][j+len(c11)]          = c12[i][j]
                c[i+len(c11)][j]          = c21[i][j]
                c[i+len(c11)][j+len(c11)] = c22[i][j]

        return c

print "Strassen Outputs:"
print strassen(a, b, 4)
print "Should be:"
print straight(a, b)

我包含了直接矩阵乘法以参考正确的所需输出。基本上会发生这种情况:

施特拉森输出:

[[10, 14, 22, 26], [32, 36, 48, 52], [58, 66, 70, 78], [80, 88, 96, 104]]

应该是:

[[26, 26, 26, 26], [52, 52, 52, 52], [78, 78, 78, 78], [104, 104, 104, 104]]

我不确定问题的根源是什么,这意味着我无法解决它!

【问题讨论】:

    标签: python matrix implementation multiplication strassen


    【解决方案1】:

    不应该这样吗:

    # p4 = a22 * (b12-b11)
    p4 = strassen(a22, sub_m(b12,b11), q/2)
    

    是:

    # p4 = a22 * (b21-b11)
    p4 = strassen(a22, sub_m(b21,b11), q/2)
    

    改为?

    ~/coding$ python -i strass.py
    Strassen Outputs:
    [[26, 26, 26, 26], [52, 52, 52, 52], [78, 78, 78, 78], [104, 104, 104, 104]]
    Should be:
    [[26, 26, 26, 26], [52, 52, 52, 52], [78, 78, 78, 78], [104, 104, 104, 104]]
    >>> import numpy
    >>> def check():
    ...     for i in range(100):
    ...         a = numpy.random.randint(0, 10,size=(4,4)).tolist()
    ...         b = numpy.random.randint(0, 10,size=(4,4)).tolist()
    ...         assert strassen(a,b,4) == straight(a,b)
    ...         assert (numpy.array(strassen(a,b,4)) == numpy.dot(a,b)).all()
    ...     print 'hooray!'
    ... 
    >>> check()
    hooray!
    

    【讨论】:

    • 对于那些获得TypeError: slice indices must be integers or None or have an __index__ method 的人,在Python 3.X 中,在split 方法中使用地板除法// 而不是/
    【解决方案2】:

    我用 numpy 写了另一个版本来简化 add() 和 sub()...

    import numpy as np
    def straight(a, b): 
        if len(a[0]) != len(b): return "Matrices are not m*n and n*p" 
        p_matrix = np.zeros((len(a), len(b[0])))
        p_matrix += [[np.sum([a[i][k] * b[k][j] for k in range(len(b))]) for j in range(len(b[0]))] for i in range(len(a))]
        return p_matrix
    def split(matrix):  # split matrix into quarters
        row, col = matrix.shape
        return matrix[:row//2, :col//2], matrix[:row//2, col//2:], matrix[row//2:, :col//2], matrix[row//2:, col//2:]
    def strassen(a, b):
        q = len(a)
        if q == 1:  # base case: 1x1 matrix
            return a * b
        a11, a12, a21, a22 = split(a)
        b11, b12, b21, b22 = split(b)
        p1 = strassen(a11 + a22, b11 + b22)  # p1 = (a11 + a22) * (b11 + b22)
        p2 = strassen(a21 + a22, b11)        # p2 = (a21 + a22) * b11
        p3 = strassen(a11, b12 - b22)        # p3 = a11 * (b12 - b22)
        p4 = strassen(a22, b21 - b11)        # p4 = a22 * (b21 - b11)
        p5 = strassen(a11 + a12, b22)        # p5 = (a11 + a12) * b22
        p6 = strassen(a21 - a11, b11 + b12)  # p6 = (a21 - a11) * (b11 + b12)
        p7 = strassen(a12 - a22, b21 + b22)  # p7 = (a12 - a22) * (b21 + b22)
        c11 = p1 + p4 - p5 + p7  # c11 = p1 + p4 - p5 + p7
        c12 = p3 + p5            # c12 = p3 + p5
        c21 = p2 + p4            # c21 = p2 + p4
        c22 = p1 + p3 - p2 + p6  # c22 = p1 + p3 - p2 + p6
        c = np.vstack((np.hstack((c11, c12)), np.hstack((c21, c22)))) 
        return c
    def check():
        a = np.random.randint(0, 10, size=(16, 16))
        b = np.random.randint(0, 10, size=(16, 16))
        assert (strassen(a, b) == straight(a, b)).all()
        assert (np.array(strassen(a, b)) == np.dot(a, b)).all()
        print('Hooray!')
    check()
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2010-12-27
      • 1970-01-01
      • 2012-07-14
      • 1970-01-01
      • 2012-11-13
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多