这是 3D 解决方案的推广。我认为还有其他(更好的?)组织递归的方法,但这已经足够好了。它在约 10 毫秒内完成了一个 6D 示例(各维度大小的乘积为 9x10^6)。
下面是示例运行结果。注意有时两种方法返回的索引不匹配,这是因为最优索引并不总是唯一的:有时不同的索引组合会产生相同的最小值。另请注意,最后我们还单独运行了一次巨大的 6D 示例(各维度大小的乘积为 9x10^12):此时蛮力方法不再可行,而聪明的方法大约需要 10 秒。
trial 1
results identical True
results compatible True
brute force 276.8830654968042 ms
branch cut 9.971900499658659 ms
trial 2
results identical True
results compatible True
brute force 273.444719001418 ms
branch cut 9.236706099909497 ms
trial 3
results identical True
results compatible True
brute force 274.2998780013295 ms
branch cut 7.31226220013923 ms
trial 4
results identical True
results compatible True
brute force 273.0268925006385 ms
branch cut 6.956217200058745 ms
HUGE (100, 150, 200, 100, 150, 200) 9000000000000
branch cut 10246.754082996631 ms
代码:
import numpy as np
import itertools as it
import functools as ft
def bf(dims, pairs):
    """Brute-force reference: maximise the minimum over all pairwise tables.

    For an index tuple ``(i_0, ..., i_{n-1})`` the objective is the minimum
    of ``pairs[k][i_a, i_b]`` over all axis pairs ``a < b``.  The whole
    ``n``-dimensional objective array is materialised by broadcasting, so
    memory and time grow with the product of ``dims``.

    Args:
        dims: sequence of the ``n`` axis lengths.
        pairs: sequence of 2-D arrays, one per axis pair ``(a, b)`` with
            ``a < b``, in ``itertools.combinations(range(n), 2)`` order;
            table ``k`` is indexed as ``[i_a, i_b]``.

    Returns:
        ``(value, idx)`` where ``idx`` is the tuple returned by
        ``np.unravel_index`` for the first maximiser and ``value`` is the
        objective at ``idx``.
    """
    dims = np.array(dims)
    # Keep the tables in a plain list.  Packing them into one object array
    # (``np.array(pairs, object)``) fails with a broadcasting ValueError
    # whenever all tables share a leading dimension but differ in the
    # trailing one, e.g. dims == (2, 2, 3).
    tables = [np.asarray(p) for p in pairs]
    n, m = len(dims), len(tables)

    # Row k of ``selectors`` is an n-tuple index for table k: full slices at
    # the table's two axes and None (np.newaxis) everywhere else, so every
    # indexed table broadcasts against the full ``dims`` shape.
    selectors = np.empty((m, n), object)
    first_axis, second_axis = np.triu_indices(n, 1)
    selectors[np.arange(m), first_axis] = slice(None)
    selectors[np.arange(m), second_axis] = slice(None)

    objective = ft.reduce(
        np.minimum,
        (table[(*sel,)] for table, sel in zip(tables, selectors)),
    )
    idx = np.unravel_index(objective.argmax(), dims)

    # Re-evaluate the objective at the winning index from the raw tables.
    value = ft.reduce(
        np.minimum,
        (table[ab] for table, ab in zip(tables, it.combinations(idx, 2))),
    )
    return value, idx
def cut(dims, pairs, offs=None):
    """Branch-and-cut search for the index tuple maximising the pairwise min.

    The objective at ``(i_0, ..., i_{n-1})`` is the minimum of
    ``pairs[k][i_a, i_b]`` over all axis pairs ``a < b`` and, when ``offs``
    is given, of ``offs[a][i_a]`` over all axes.  The first two axes are
    enumerated in order of decreasing bound; the remaining axes are solved
    recursively, with the already-fixed axes folded into ``offs``.

    Args:
        dims: sliceable sequence (e.g. tuple) of the ``n`` axis lengths.
        pairs: sliceable sequence of 2-D arrays, one per axis pair in
            ``itertools.combinations(range(n), 2)`` order.
        offs: optional sequence of ``n`` 1-D arrays, one per axis, giving a
            per-index cap on the objective.  Used by the recursion; it is
            required when ``n == 1`` because no pair table exists then.

    Returns:
        ``(value, idx)``: the optimum and an index tuple attaining it.
    """
    n = len(dims)
    if n < 2:
        # Only the per-index caps are left to maximise.
        best = offs[0].argmax()
        return offs[0][best], (best,)

    # Objective restricted to the first two axes: an upper bound on what
    # any completion of (y, x) can reach.
    bound = pairs[0]
    if offs is not None:
        bound = np.minimum(bound, np.minimum.outer(offs[0], offs[1]))

    if n == 2:
        idx = np.unravel_index(bound.argmax(), dims)
        return bound[idx], idx

    # Seed the incumbent with the value actually attained at index
    # (0, ..., 0), caps included.  Seeding with the smallest pair entry
    # alone ignores ``offs`` and can exceed every attainable value, which
    # made recursive calls return an unattained value with a bogus index.
    gidx = n * (0,)
    gmx = min(p[0, 0] for p in pairs)
    if offs is not None:
        gmx = min(gmx, min(o[0] for o in offs))

    # Visit (y, x) from the largest bound downwards.
    Y, X = np.unravel_index(bound.argsort(axis=None)[::-1], dims[:2])
    for y, x in zip(Y, X):
        if bound[y, x] <= gmx:
            # No remaining (y, x) can beat the incumbent.
            return gmx, gidx
        # Caps on axes 2..n-1 implied by fixing axis 0 at y and axis 1 at x:
        # pairs[1:n-1] are the (0, k) tables, pairs[n-1:] starts with the
        # (1, k) tables (zip stops after the n-2 needed ones).
        coffs = [np.minimum(p1[y], p2[x])
                 for p1, p2 in zip(pairs[1:n - 1], pairs[n - 1:])]
        if offs is not None:
            coffs = [*map(np.minimum, coffs, offs[2:])]
        # pairs[2*n-3:] are the tables among axes 2..n-1 only.
        cmx, cidx = cut(dims[2:], pairs[2 * n - 3:], coffs)
        if cmx >= bound[y, x]:
            # The bound is attained, and later candidates are no larger.
            return bound[y, x], (y, x, *cidx)
        if gmx < cmx:
            gmx = min(bound[y, x], cmx)
            gidx = (y, x, *cidx)
    return gmx, gidx
from timeit import timeit
# Benchmark 1: a 6-D problem small enough for the brute-force reference.
IDX = 10, 15, 20, 10, 15, 20

for rep in range(4):
    print("trial", rep + 1)
    pairs = [np.random.rand(i, j) for i, j in it.combinations(IDX, 2)]
    fast, slow = cut(IDX, pairs), bf(IDX, pairs)
    # "identical": same optimum value AND same index tuple.
    print("results identical", fast == slow)
    # "compatible": same optimum value only; the maximising index need not
    # be unique, so the two methods may legitimately report different ones.
    print("results compatible", fast[0] == slow[0])
    # timeit returns total seconds for `number` runs; scale to ms per run.
    print("brute force", timeit(lambda: bf(IDX, pairs), number=2) * 500, "ms")
    print("branch cut", timeit(lambda: cut(IDX, pairs), number=10) * 100, "ms")

# Benchmark 2: a 6-D problem far too large for brute force; time a single
# branch-and-cut run only.
IDX = 100, 150, 200, 100, 150, 200
pairs = [np.random.rand(i, j) for i, j in it.combinations(IDX, 2)]
print("HUGE", IDX, np.prod(IDX))
print("branch cut", timeit(lambda: cut(IDX, pairs), number=1) * 1000, "ms")