【问题标题】:Fastest way to find the maximum minimum value of 'connected' matrices找到“连接”矩阵的最大最小值的最快方法
【发布时间】:2021-11-27 06:31:36
【问题描述】:

question 给出了三个矩阵的答案,但我不确定如何将此逻辑应用于任意数量的成对连接矩阵:

f(i, j, k, l, ...) = min(A(i, j), B(i,k), C(i,l), D(j,k), E(j,l), F(k,l), ...)

其中A,B,... 是矩阵,i,j,... 是范围高达矩阵各自维度的索引。如果我们考虑n 索引,则有n(n-1)/2 对和矩阵。我想找到(i,j,k,...) 使得f(i,j,k,l,...) 最大化。我目前这样做如下:

import numpy as np
import itertools

#             i  j  k  l  ...
dimensions = [50,50,50,50]
n_dims = len(dimensions)

pairs = list(itertools.combinations(range(n_dims), 2))

# Construct the matrices A(i,j), B(i,k), ...
matrices = [];
for pair in pairs:
    matrices.append(np.random.rand(dimensions[pair[0]], dimensions[pair[1]]))


# All the different i,j,k,l... combinations
combinations = itertools.product(*list(map(np.arange,dimensions)))
combinations = np.asarray(list(combinations))

# Find the maximum minimum
vals = []

for i in range(len(pairs)):
    pair = pairs[i]
    matrix = matrices[i]
    vals.append(matrix[combinations[:,pair[0]], combinations[:,pair[1]]])


f = np.min(vals,axis=0)

best_indices = combinations[np.argmax(f)]

print(best_indices, np.max(f))

[5 17 17 18] 0.932985854758534

这比遍历所有 (i, j, k, l, ...) 更快,但花费大量时间构建组合和 vals 矩阵。是否有另一种方法可以做到这一点,其中(1)可以保留 numpy 矩阵计算的速度,(2)我不必构造内存密集型 vals 矩阵?

【问题讨论】:

    标签: python numpy matrix optimization memory


    【解决方案1】:

    这是 3D 解决方案的概括。我认为还有其他(更好的?)组织递归的方法,但这已经足够好了。它在 内完成了一个 6D 示例(dims 9x10^6 的乘积)

    示例运行,注意有时两种方法返回的索引不匹配。这是因为它们并不总是唯一的,有时不同的索引组合会产生相同的最小值。另请注意,最后我们只运行了一个巨大的 6D 9x10^12 示例。蛮力不再可行,聪明的方法大约需要 10 秒。

    trial 1
    results identical True
    results compatible True
    brute force 276.8830654968042 ms
    branch cut 9.971900499658659 ms
    trial 2
    results identical True
    results compatible True
    brute force 273.444719001418 ms
    branch cut 9.236706099909497 ms
    trial 3
    results identical True
    results compatible True
    brute force 274.2998780013295 ms
    branch cut 7.31226220013923 ms
    trial 4
    results identical True
    results compatible True
    brute force 273.0268925006385 ms
    branch cut 6.956217200058745 ms
    HUGE (100, 150, 200, 100, 150, 200) 9000000000000
    branch cut 10246.754082996631 ms
    

    代码:

    import numpy as np
    import itertools as it
    import functools as ft
    
    def bf(dims,pairs):
        dims,pairs = np.array(dims),np.array(pairs,object)
        n,m = len(dims),len(pairs)
        IDX = np.empty((m,n),object)
        Y,X = np.triu_indices(n,1)
        IDX[np.arange(m),Y] = slice(None)
        IDX[np.arange(m),X] = slice(None)
        idx = np.unravel_index(
            ft.reduce(np.minimum,(p[(*i,)] for p,i in zip(pairs,IDX))).argmax(),dims)
        return ft.reduce(np.minimum,(
            p[I] for p,I in zip(pairs,it.combinations(idx,2)))),idx
    
    def cut(dims,pairs,offs=None):
        n = len(dims)
        if n<3:
            if n==2:
                A = pairs[0] if offs is None else np.minimum(
                    pairs[0],np.minimum.outer(offs[0],offs[1]))
                idx = np.unravel_index(A.argmax(),dims)
                return A[idx],idx
            else:
                idx = offs[0].argmax()
                return offs[0][idx],(idx,)
        gmx = min(map(np.min,pairs))
        gidx = n * (0,)
        A = pairs[0] if offs is None else np.minimum(
            pairs[0],np.minimum.outer(offs[0],offs[1]))
        Y,X = np.unravel_index(A.argsort(axis=None)[::-1],dims[:2])
        for y,x in zip(Y,X):
            if A[y,x] <= gmx:
                return gmx,gidx
            coffs = [np.minimum(p1[y],p2[x])
                     for p1,p2 in zip(pairs[1:n-1],pairs[n-1:])]
            if not offs is None:
                coffs = [*map(np.minimum,coffs,offs[2:])]
            cmx,cidx = cut(dims[2:],pairs[2*n-3:],coffs)
            if cmx >= A[y,x]:
                return A[y,x],(y,x,*cidx)
            if gmx < cmx:
                gmx = min(A[y,x],cmx)
                gidx = y,x,*cidx
        return gmx,gidx
    
    from timeit import timeit
    
    IDX = 10,15,20,10,15,20
    
    for rep in range(4):
        print("trial",rep+1)
        pairs = [np.random.rand(i,j) for i,j in it.combinations(IDX,2)]
    
        print("results identical",cut(IDX,pairs)==bf(IDX,pairs))
        print("results compatible",cut(IDX,pairs)[1]==bf(IDX,pairs)[1])
        print("brute force",timeit(lambda:bf(IDX,pairs),number=2)*500,"ms")
        print("branch cut",timeit(lambda:cut(IDX,pairs),number=10)*100,"ms")
    
    IDX = 100,150,200,100,150,200
    pairs = [np.random.rand(i,j) for i,j in it.combinations(IDX,2)]
    print("HUGE",IDX,np.prod(IDX))
    print("branch cut",timeit(lambda:cut(IDX,pairs),number=1)*1000,"ms")
    

    【讨论】:

    • 这令人印象深刻!我试图概括它,但不能超越4 索引。请问offs这个变量是干什么用的?
    • 有没有一种简单的方法可以从cut 函数返回具有相同最大最小值的 所有 索引组合?
    • offs 用于“半消耗”矩阵(递归一次取出两个维度,例如 i,j,例如,具有索引 i,k 的矩阵必须沿 i 减少在 k 中留下一个向量)。获取所有索引应该是可能的,但可能会使代码更复杂。我能看到的最直接但可能不是最快的方法是将函数修改为生成器。
    猜你喜欢
    • 2021-11-25
    • 2021-11-26
    • 1970-01-01
    • 1970-01-01
    • 2014-06-17
    • 2018-10-21
    • 1970-01-01
    • 1970-01-01
    • 2020-10-11
    相关资源
    最近更新 更多