【问题标题】:Why does optimized prime-factor counting algorithm run slower为什么优化的素数计数算法运行速度较慢
【发布时间】:2019-11-13 04:20:39
【问题描述】:

您好,我在网上看到了一个计算数字的不同素因数的答案,它看起来不是最优的。所以我尝试改进它,但在一个简单的基准测试中,我的变体比原来的要慢得多。

该算法计算一个数字的不同质因数。原始使用 HashSet 来收集因子,然后使用 size 来获取它们的数量。我的“改进”版本使用 int 计数器,并将 while 循环分解为 if/while 以避免不必要的调用。

更新:tl/dr(有关详细信息,请参阅已接受的答案)

原始代码有一个性能错误调用 Math.sqrt 编译器修复了不必要的:

int n = ...;
// sqrt does not need to be recomputed if n does not change
for (int i = 3; i <= Math.sqrt(n); i += 2) {
    while (n % i == 0) {
        n /= i;
    }
}

编译器优化了 sqrt 调用,使其仅在 n 更改时发生。但是通过使循环内容稍微复杂一些(虽然没有功能上的改变),编译器停止了优化,并且在每次迭代时都会调用 sqrt。

原始问题

public class PrimeFactors {

    // fast version, takes 10s for input 8
    static int countPrimeFactorsSet(int n) {
        Set<Integer> primeFactorSet = new HashSet<>();
        while (n % 2 == 0) {
            primeFactorSet.add(2);
            n /= 2;
        }
        for (int i = 3; i <= Math.sqrt(n); i += 2) {
            while (n % i == 0) {
                primeFactorSet.add(i);
                n /= i;
            }
        }
        if (n > 2) {
            primeFactorSet.add(n);
        }
        return primeFactorSet.size();
    }

    // slow version, takes 19s for input 8
    static int countPrimeFactorsCounter(int n) {
        int count = 0; // using simple int
        if (n % 2 == 0) {
            count ++; // only add on first division
            n /= 2;
            while (n % 2 == 0) {
                n /= 2;
            }
        }
        for (int i = 3; i <= Math.sqrt(n); i += 2) {
            if (n % i == 0) {
                count++; // only add on first division
                n /= i;
                while (n % i == 0) {
                    n /= i;
                }
            }
        }
        if (n > 2) {
            count++;
        }
        return count;
    }

    static int findNumberWithNPrimeFactors(final int n) {
        for (int i = 3; ; i++) {
            // switch implementations
            if (countPrimeFactorsCounter(i) == n) {
            // if (countPrimeFactorsSet(i) == n) {
                return i;
            }
        }
    }

    public static void main(String[] args) {
        findNumberWithNPrimeFactors(8); // benchmark warmup
        findNumberWithNPrimeFactors(8);
        long start = System.currentTimeMillis();
        int result = findNumberWithNPrimeFactors(n);
        long duration = System.currentTimeMillis() - start;

        System.out.println("took ms " + duration + " to find " + result);
    }
}

原始版本的输出始终在 10 秒左右(在 java8 上),而“优化”版本更接近 20 秒(两者都打印相同的结果)。实际上,只需将单个 while 循环更改为包含 while 循环的 if 块,就已经将原始方法的速度减慢了一半。

使用-Xint以解释模式运行JVM,优化后的版本运行速度提高了3倍。使用-Xcomp 使两个实现以相似的速度运行。 因此,似乎 JIT 可以优化带有单个 while 循环和 HashSet 的版本,而不是带有简单 int 计数器的版本。

适当的微基准 (How do I write a correct micro-benchmark in Java?) 会告诉我其他信息吗? 有没有我忽略的性能优化原则(例如Java performance tips)?

【问题讨论】:

  • 您是否分析了代码?
  • 不,否则我会添加结果。
  • 那么也许你应该这样做。
  • @JamesKPolk 我一开始也是这么想的,但我意识到他正在调用一个迭代每个数字(从3 开始)的方法来搜索具有8 素因子的数字。
  • 你在while循环之前做了一个额外的除法;将其转入 do-while 循环。

标签: java performance compiler-optimization hoisting loop-invariant


【解决方案1】:

我将您的示例转换为JMH benchmark 以进行公平测量,实际上set 变体的出现速度是counter 的两倍:

Benchmark              Mode  Cnt     Score    Error   Units
PrimeFactors.counter  thrpt    5   717,976 ±  7,232  ops/ms
PrimeFactors.set      thrpt    5  1410,705 ± 15,894  ops/ms

为了找出原因,我使用内置的-prof xperfasm 分析器重新运行了基准测试。碰巧counter 方法花费了超过 60% 的时间执行 vsqrtsd 指令 - 显然,Math.sqrt(n) 的编译对应物。

  0,02%   │  │  │     │  0x0000000002ab8f3e: vsqrtsd %xmm0,%xmm0,%xmm0    <-- Math.sqrt
 61,27%   │  │  │     │  0x0000000002ab8f42: vcvtsi2sd %r10d,%xmm1,%xmm1

同时set方法最热的指令是idiv,是n % i编译的结果。

             │  │ ││  0x0000000002ecb9e7: idiv   %ebp               ;*irem
 55,81%      │  ↘ ↘│  0x0000000002ecb9e9: test   %edx,%edx

Math.sqrt 是一个缓慢的操作并不奇怪。但是为什么在第一种情况下执行得更频繁呢?

线索是您在优化期间所做的代码的转换。您将一个简单的 while 循环包装到一个额外的 if 块中。这使得控制流更加复杂,因此 JIT 无法将 Math.sqrt 计算提升到循环之外,并且必须在每次迭代时重新计算。

我们需要帮助 JIT 编译器来恢复性能。让我们手动将 Math.sqrt 计算提升到循环之外。

    static int countPrimeFactorsSet(int n) {
        Set<Integer> primeFactorSet = new HashSet<>();
        while (n % 2 == 0) {
            primeFactorSet.add(2);
            n /= 2;
        }
        double sn = Math.sqrt(n);  // compute Math.sqrt out of the loop
        for (int i = 3; i <= sn; i += 2) {
            while (n % i == 0) {
                primeFactorSet.add(i);
                n /= i;
            }
            sn = Math.sqrt(n);     // recompute after n changes
        }
        if (n > 2) {
            primeFactorSet.add(n);
        }
        return primeFactorSet.size();
    }

    static int countPrimeFactorsCounter(int n) {
        int count = 0; // using simple int
        if (n % 2 == 0) {
            count ++; // only add on first division
            n /= 2;
            while (n % 2 == 0) {
                n /= 2;
            }
        }
        double sn = Math.sqrt(n);  // compute Math.sqrt out of the loop
        for (int i = 3; i <= sn; i += 2) {
            if (n % i == 0) {
                count++; // only add on first division
                n /= i;
                while (n % i == 0) {
                    n /= i;
                }
                sn = Math.sqrt(n);     // recompute after n changes
            }
        }
        if (n > 2) {
            count++;
        }
        return count;
    }

现在counter 方法变快了!甚至比set 快一点(这是意料之中的,因为它执行相同数量的计算,不包括 Set 开销)。

Benchmark              Mode  Cnt     Score    Error   Units
PrimeFactors.counter  thrpt    5  1513,228 ± 13,046  ops/ms
PrimeFactors.set      thrpt    5  1411,573 ± 10,004  ops/ms

请注意,set 的性能没有改变,因为 JIT 能够自己进行相同的优化,这要归功于更简单的控制流图。

结论: Java 性能是一件非常复杂的事情,尤其是在谈到微优化时。 JIT 优化很脆弱,如果没有像 JMH 和分析器这样的专门工具,很难理解 JVM 的想法。

【讨论】:

  • 终于有人解决了这个问题。而且,更好的是,回答了它!
  • 所以原始代码的主要问题是它在理论上计算 sqrt 过于频繁,但编译器可以修复这个问题,只要其余代码足够简单,可以查看 n 是否可以更改或不是。所以原始代码有一个性能,但幸运的是在这种情况下自动修复。
  • @tkruse 完全正确
【解决方案2】:

首先,测试中有两组操作:测试因子,并记录这些因子。在切换实现时,使用 Set 与使用 ArrayList(在我的重写中,如下),与简单地计算因素会有所不同。

其次,我发现时间变化很大。这是从 Eclipse 运行的。我不清楚是什么导致了大的变化。

我的“经验教训”是要注意测量的确切内容。是否打算测量分解算法本身(while 循环的成本加上算术运算)?是否应包括时间记录因素?

一个次要的技术点:在这个实现中敏锐地感觉到缺少 multiple-value-setq(在 lisp 中可用)。人们更愿意将余数和整数除法作为单个操作执行,而不是将它们写成两个不同的步骤。从语言和算法研究的角度来看,这值得一看。

以下是分解实现的三种变体的时序结果。第一个来自最初的(未优化的)实现,但改为使用简单的 List 而不是更难的时间 Set 来存储因素。第二个是您的优化,但仍然使用列表进行跟踪。第三是你的优化,但包括改变计数因素。

  18 -  3790 1450 2410 (average of 10 iterations)
  64 -  1630 1220  260 (average of 10 iterations)
1091 - 16170 2850 1180 (average of 10 iterations)
1092 -  2720 1370  380 (average of 10 iterations)

4096210 - 28830 5430 9120 (average of  10 iterations, trial 1)
4096210 - 18380 6190 5920 (average of  10 iterations, trial 2)
4096210 - 10072 5816 4836 (average of 100 iterations, trial 1)
4096210 -  7202 5036 3682 (average of 100 iterations, trial 1)

---

Test value [ 18 ]
Warm-up count [ 2 ]
Test count [ 10 ]
Times [non-optimized]
Start [ 1621713914872600 (ns) ]
End   [ 1621713914910500 (ns) ]
Delta [ 37900 (ns) ]
Avg   [ 3790 (ns) ]
Factors: [2, 3, 3]
Times [optimized]
Start [ 1621713915343500 (ns) ]
End   [ 1621713915358000 (ns) ]
Delta [ 14500 (ns) ]
Avg   [ 1450 (ns) ]
Factors: [2, 3, 3]
Times [counting]
Start [ 1621713915550400 (ns) ]
End   [ 1621713915574500 (ns) ]
Delta [ 24100 (ns) ]
Avg   [ 2410 (ns) ]
Factors: 3
---
Test value [ 64 ]
Warm-up count [ 2 ]
Test count [ 10 ]
Times [non-optimized]
Start [ 1621747046013900 (ns) ]
End   [ 1621747046030200 (ns) ]
Delta [ 16300 (ns) ]
Avg   [ 1630 (ns) ]
Factors: [2, 2, 2, 2, 2, 2]
Times [optimized]
Start [ 1621747046337800 (ns) ]
End   [ 1621747046350000 (ns) ]
Delta [ 12200 (ns) ]
Avg   [ 1220 (ns) ]
Factors: [2, 2, 2, 2, 2, 2]
Times [counting]
Start [ 1621747046507900 (ns) ]
End   [ 1621747046510500 (ns) ]
Delta [ 2600 (ns) ]
Avg   [ 260 (ns) ]
Factors: 6
---
Test value [ 1091 ]
Warm-up count [ 2 ]
Test count [ 10 ]
Times [non-optimized]
Start [ 1621687024226500 (ns) ]
End   [ 1621687024388200 (ns) ]
Delta [ 161700 (ns) ]
Avg   [ 16170 (ns) ]
Factors: [1091]
Times [optimized]
Start [ 1621687024773200 (ns) ]
End   [ 1621687024801700 (ns) ]
Delta [ 28500 (ns) ]
Avg   [ 2850 (ns) ]
Factors: [1091]
Times [counting]
Start [ 1621687024954900 (ns) ]
End   [ 1621687024966700 (ns) ]
Delta [ 11800 (ns) ]
Avg   [ 1180 (ns) ]
Factors: 1
---
Test value [ 1092 ]
Warm-up count [ 2 ]
Test count [ 10 ]
Times [non-optimized]
Start [ 1621619636267500 (ns) ]
End   [ 1621619636294700 (ns) ]
Delta [ 27200 (ns) ]
Avg   [ 2720 (ns) ]
Factors: [2, 2, 3, 7, 13]
Times [optimized]
Start [ 1621619636657100 (ns) ]
End   [ 1621619636670800 (ns) ]
Delta [ 13700 (ns) ]
Avg   [ 1370 (ns) ]
Factors: [2, 2, 3, 7, 13]
Times [counting]
Start [ 1621619636895300 (ns) ]
End   [ 1621619636899100 (ns) ]
Delta [ 3800 (ns) ]
Avg   [ 380 (ns) ]
Factors: 5
---
Test value [ 4096210 ]
Warm-up count [ 2 ]
Test count [ 10 ]
Times [non-optimized]
Start [ 1621652753519800 (ns) ]
End   [ 1621652753808100 (ns) ]
Delta [ 288300 (ns) ]
Avg   [ 28830 (ns) ]
Factors: [2, 5, 19, 21559]
Times [optimized]
Start [ 1621652754116300 (ns) ]
End   [ 1621652754170600 (ns) ]
Delta [ 54300 (ns) ]
Avg   [ 5430 (ns) ]
Factors: [2, 5, 19, 21559]
Times [counting]
Start [ 1621652754323500 (ns) ]
End   [ 1621652754414700 (ns) ]
Delta [ 91200 (ns) ]
Avg   [ 9120 (ns) ]
Factors: 4

这是我对测试代码的重写。最受关注的是findFactorsfindFactorsOptfindFactorsCount

package my.tests;

import java.util.ArrayList;
import java.util.List;

public class PrimeFactorsTest {

    public static void main(String[] args) {
        if ( args.length < 2 ) {
            System.out.println("Usage: " + PrimeFactorsTest.class.getName() + " testValue warmupIterations testIterations");
            return;
        }

        int testValue = Integer.valueOf(args[0]);
        int warmCount = Integer.valueOf(args[1]);
        int testCount = Integer.valueOf(args[2]);

        if ( testValue <= 2 ) {
            System.out.println("Test value [ " + testValue + " ] must be at least 2.");
            return;
        } else {
            System.out.println("Test value [ " + testValue + " ]");
        }
        if ( warmCount <= 0 ) {
            System.out.println("Warm-up count [ " + testCount + " ] must be at least 1.");
        } else {
            System.out.println("Warm-up count [ " + warmCount + " ]");
        }
        if ( testCount <= 1 ) {
            System.out.println("Test count [ " + testCount + " ] must be at least 1.");
        } else {
            System.out.println("Test count [ " + testCount + " ]");
        }

        timedFactors(testValue, warmCount, testCount);
        timedFactorsOpt(testValue, warmCount, testCount);
        timedFactorsCount(testValue, warmCount, testCount);
    }

    public static void timedFactors(int testValue, int warmCount, int testCount) {
        List<Integer> factors = new ArrayList<Integer>();

        for ( int warmNo = 0; warmNo < warmCount; warmNo++ ) {
            factors.clear();
            findFactors(testValue, factors);
        }

        long startTime = System.nanoTime();
        for ( int testNo = 0; testNo < testCount; testNo++ ) {
            factors.clear();
            findFactors(testValue, factors);
        }
        long endTime = System.nanoTime();

        System.out.println("Times [non-optimized]");
        System.out.println("Start [ " + startTime + " (ns) ]");
        System.out.println("End   [ " + endTime + " (ns) ]");
        System.out.println("Delta [ " + (endTime - startTime) + " (ns) ]");
        System.out.println("Avg   [ " + (endTime - startTime) / testCount + " (ns) ]");
        System.out.println("Factors: " + factors);
    }

    public static void findFactors(int n, List<Integer> factors) {
        while ( n % 2 == 0 ) {
            n /= 2;
            factors.add( Integer.valueOf(2) );
        }

        for ( int factor = 3; factor <= Math.sqrt(n); factor += 2 ) {
            while ( n % factor == 0 ) {
                n /= factor;
                factors.add( Integer.valueOf(factor) );
            }
        }

        if ( n > 2 ) {
            factors.add( Integer.valueOf(n) );
        }
    }

    public static void timedFactorsOpt(int testValue, int warmCount, int testCount) {
        List<Integer> factors = new ArrayList<Integer>();
        for ( int warmNo = 0; warmNo < warmCount; warmNo++ ) {
            factors.clear();
            findFactorsOpt(testValue, factors);
        }

        long startTime = System.nanoTime();
        for ( int testNo = 0; testNo < testCount; testNo++ ) {
            factors.clear();
            findFactorsOpt(testValue, factors);
        }
        long endTime = System.nanoTime();

        System.out.println("Times [optimized]");
        System.out.println("Start [ " + startTime + " (ns) ]");
        System.out.println("End   [ " + endTime + " (ns) ]");
        System.out.println("Delta [ " + (endTime - startTime) + " (ns) ]");
        System.out.println("Avg   [ " + (endTime - startTime) / testCount + " (ns) ]");
        System.out.println("Factors: " + factors);
    }

    public static void findFactorsOpt(int n, List<Integer> factors) {
        if ( n % 2 == 0 ) {
            n /= 2;

            Integer factor = Integer.valueOf(2); 
            factors.add(factor);

            while (n % 2 == 0) {
                n /= 2;

                factors.add(factor);
            }
        }

        for ( int factorValue = 3; factorValue <= Math.sqrt(n); factorValue += 2) {
            if ( n % factorValue == 0 ) {
                n /= factorValue;

                Integer factor = Integer.valueOf(factorValue); 
                factors.add(factor);

                while ( n % factorValue == 0 ) {
                    n /= factorValue;
                    factors.add(factor);
                }
            }
        }

        if (n > 2) {
            factors.add( Integer.valueOf(n) );
        }
    }

    public static void timedFactorsCount(int testValue, int warmCount, int testCount) {
        int numFactors = 0;

        for ( int warmNo = 0; warmNo < warmCount; warmNo++ ) {
            numFactors = findFactorsCount(testValue);
        }

        long startTime = System.nanoTime();
        for ( int testNo = 0; testNo < testCount; testNo++ ) {
            numFactors = findFactorsCount(testValue);
        }
        long endTime = System.nanoTime();

        System.out.println("Times [counting]");
        System.out.println("Start [ " + startTime + " (ns) ]");
        System.out.println("End   [ " + endTime + " (ns) ]");
        System.out.println("Delta [ " + (endTime - startTime) + " (ns) ]");
        System.out.println("Avg   [ " + (endTime - startTime) / testCount + " (ns) ]");
        System.out.println("Factors: " + numFactors);
    }

    public static int findFactorsCount(int n) {
        int numFactors = 0;

        if ( n % 2 == 0 ) {
            n /= 2;
            numFactors++;

            while (n % 2 == 0) {
                n /= 2;
                numFactors++;
            }
        }

        for ( int factorValue = 3; factorValue <= Math.sqrt(n); factorValue += 2) {
            if ( n % factorValue == 0 ) {
                n /= factorValue;
                numFactors++;

                while ( n % factorValue == 0 ) {
                    n /= factorValue;
                    numFactors++;
                }
            }
        }

        if (n > 2) {
            numFactors++;
        }

        return numFactors;
    }
}

【讨论】:

  • 提出的问题很简单:为什么countPrimeFactorsCounter() 似乎花费的时间几乎是countPrimeFactorsSet() 的两倍。它从来都不是最好或更好的算法。我不明白这是如何回答这个问题的。
  • 谢谢,我可以尝试使用此框架重新运行以查看其他性能模式。在这里选择特定数字作为输入有点令人困惑,因为 2 的倍数的大数字可能需要比相同范围内的素数少得多的代码步骤,因此 1091 和 1092 可能有如此大的差异。可能可以选择不同的算法来测量计算与输入一致的效果。
  • 嗨。是的,事实证明,时间很大程度上取决于被分解的数字。我正在考虑尝试绘制一系列输入的时间,只是为了初步了解变化范围,以及时间如何取决于特定因素。需要注意的是,对于非常小的值(如最初的“8”),常数因素可能会主导结果。这将包括设置开销,这(我在想)将强烈影响时代。
【解决方案3】:

如果在这里,请先阻止: for (int i = 3; i &lt;= Math.sqrt(n); i += 2) { if (n % i == 0) {...

应该在循环之外,

其次,您可以使用不同的方法执行此代码,例如:

while (n % 2 == 0) { Current++; n /= 2; }

你可以改变它: if(n % 2 ==0) { current++; n=n%2; }

基本上,由于您的方法,您应该避免循环内的条件或指令:

(findNumberWithNPrimeFactors)

算法的复杂度是每个循环的复杂度 (findNumberWithNPrimeFactors) X (迭代次数)

如果你在你的循环中添加一个测试或做作,你会得到一个 + 1 ( 复杂度 (findNumberWithNPrimeFactors) X ( 迭代次数 ) )

【讨论】:

    【解决方案4】:

    以下通过将 n 相除,使 Math.sqrt 变得多余。 不断与较小的平方根进行比较甚至可能是最慢的操作。

    那么do-while会是更好的风格。

    static int countPrimeFactorsCounter2(int n) {
        int count = 0; // using simple int
        if (n % 2 == 0) {
            ++count; // only add on first division
            do {
                n /= 2;
            } while (n % 2 == 0);
        }
        for (int i = 3; i <= n; i += 2) {
            if (n % i == 0) {
                count++; // only add on first division
                do {
                    n /= i;
                } while (n % i == 0);
            }
        }
        //if (n > 2) {
        //    ++count;
        //}
        return count;
    }
    

    使用平方根的逻辑谬误是基于∀ a, b: a.b = n 你只需要尝试a &lt; √n。然而,在 n-dividing 循环中,您只需保存一个步骤。请注意,sqrt 是在每个奇数 i 处计算的。

    【讨论】:

    • Math.sqrt 不是多余的(例如,当n 是素数并且内部循环中没有除法时)。尝试运行您的“优化”版本 - 这将需要很长时间才能完成。此外,do-while 在性能方面没有任何区别。但最重要的是,您的答案与原始问题无关,这根本不是关于优化算法,而是关于两个给定方法之间的性能差异
    • @apangin i &lt;= n 也应该考虑 n 是素数。 do-while 是化妆品。 sqrt + 循环是主要区别;但我承认我没有给出比较。你的回答非常好。感谢您仔细检查我的答案。我没有尝试我的方法,但似乎有 % 成本的额外循环步骤。感谢您的洞察力。
    猜你喜欢
    • 2020-06-04
    • 1970-01-01
    • 1970-01-01
    • 2019-06-18
    • 2019-01-11
    • 1970-01-01
    • 1970-01-01
    • 2019-07-21
    • 1970-01-01
    相关资源
    最近更新 更多