【问题标题】:Efficient cartesian product excluding items不包括项目的高效笛卡尔积
【发布时间】:2020-05-02 01:28:30
【问题描述】:

我正在尝试将 11 个值的所有可能组合重复 80 次,但过滤掉总和大于 1 的情况。下面的代码实现了我想要做的,但需要几天才能运行:

import numpy as np
import itertools

unique_values = np.linspace(0.0, 1.0, 11)

lst = []
for p in itertools.product(unique_values , repeat=80):
    if sum(p)<=1:
        lst.append(p)

上述解决方案可行,但需要太多时间。此外,在这种情况下,我必须定期将“lst”保存到磁盘中并释放内存以避免任何内存错误。后一部分很好,但代码需要几天(或几周)才能完成。

还有其他选择吗?

【问题讨论】:

  • unique_values = np.linspace(0.0, 1.0, 11) 是真实的还是一个例子?
  • 在numpy中围绕itertools.product的stackoverflow有几种实现,可能会更快,否则,我会尝试用C或其他快速语言来实现。另外,为什么np.linspace(0.0, 1.0, 11) 而不是range(12)
  • np.linspace(0.0, 1.0, 11) 是真实的。上面的例子正是我想要得到的。如果我使用range(12),我将无法检查总和是否低于 1
  • @Stergios 那么你只需要 10 个的 42 个分区,缩小 10 倍:wolframalpha.com/input/?i=partitions+of+10
  • 这最终是为了什么?存储数千兆字节的大多数为零的列表真的是您的最佳选择吗?

标签: python itertools cartesian-product


【解决方案1】:

好的,这样会更有效率,你可以像这样使用生成器,并根据需要取值:

def get_solution(uniques, length, constraint):
    if length == 1:
        for u in uniques[uniques <= constraint + 1e-8]:
            yield u
    else:
        for u in uniques[uniques <= constraint + 1e-8]:
            for s in get_solution(uniques, length - 1, constraint - u):
                yield np.hstack((u, s))
g = get_solution(unique_values, 4, 1)
for _ in range(5):
    print(next(g))

打印

[0. 0. 0. 0.]
[0.  0.  0.  0.1]
[0.  0.  0.  0.2]
[0.  0.  0.  0.3]
[0.  0.  0.  0.4]

与你的功能比较:

def get_solution_product(uniques, length, constraint):
    return np.array([p for p in product(uniques, repeat=length) if np.sum(p) <= constraint + 1e-8])
%timeit np.vstack(list(get_solution(unique_values, 5, 1)))
346 ms ± 29.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit get_solution_product(unique_values, 5, 1)
2.94 s ± 256 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

【讨论】:

  • @Viachelsav 您的解决方案没有返回我需要的所有值。例如,它不返回有效的解决方案 [0.8 0.2 0.0 0.0 0.0] 或 [0.8 0.0 0.0 0.2 0.0]。事实上,它恰好返回了一半可能的解决方案。
  • 我接受你的回答。我需要添加的只是运行循环两次。一次用于unique_values,一次用于unique_values[::-1]
  • @Stergios 这很奇怪,因为它实际上产生了所有正确的答案。检查list(filter(lambda x: x[0] == 0.8, get_solution(unique_values, 5, 1)))
【解决方案2】:

OP 只需要 10 个分区,但这是我在此期间编写的一些通用代码。

def find_combinations(values, max_total, repeat):
    if not (repeat and max_total > 0):
        yield ()
        return
    for v in values:
        if v <= max_total:
            for sub_comb in find_combinations(values, max_total - v, repeat - 1):
                yield (v,) + sub_comb


def main():
    all_combinations = find_combinations(range(1, 11), 10, 80)
    unique_combinations = {
        tuple(sorted(t))
        for t in all_combinations
    }
    for comb in sorted(unique_combinations):
        print(comb)

main()

【讨论】:

  • 我并不是在寻找 10 的分区。例如,在我的情况下,向量 [0.1, 0, 0, 0....] 被接受(只要总和是
  • @Stergios 我明白了,我没有回到这个问题,因为这个问题很难解决。
  • 没关系。我刚刚在您的解决方案下添加了我之前的评论,以供其他人参考。
猜你喜欢
  • 2013-10-22
  • 2020-09-26
  • 2020-12-19
  • 2012-09-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多