【问题标题】:Scala dot product very slow compared to Java与 Java 相比,Scala 点积非常慢
【发布时间】:2016-10-17 22:26:43
【问题描述】:

我是 Scala 的新手,我想以相同的性能水平翻译我的 Java 代码。

给定 n 个浮点向量和一个附加向量,我必须计算所有 n 个点积并得到最大值。

使用 Java 对我来说非常简单

public static void main(String[] args) {

    int N = 5000000;
    int R = 200;
    float[][] t = new float[N][R];
    float[] u = new float[R];

    Random r = new Random();

    for (int i = 0;i<N;i++) {
        for (int j = 0;j<R;j++) {
            if (i == 0) {
                u[j] = r.nextFloat();
            }
            t[i][j] = r.nextFloat();
        }
    }

    long ts = System.currentTimeMillis();
    float maxScore = -1.0f;

    for (int i = 0;i < N;i++) {
        float score = 0.0f;
        for (int j = 0; i < R;i++) {
            score += u[j] * t[i][j];
        }
        if (score > maxScore) {
            maxScore = score;
        }

    }

    System.out.println(System.currentTimeMillis() - ts);
    System.out.println(maxScore);

}

在我的机器上计算时间是 6 毫秒。

现在我必须用 Scala 来做

val t = Array.ofDim[Float](N,R)
val u = Array.ofDim[Float](R)

// Filling with random floats like in Java

val ts = System.currentTimeMillis()
var maxScore: Float = -1.0f

for ( i <- 0 until N) {
  var score = 0.0f
  for (j <- 0 until R) {
    score += u(j) * t(i)(j)
  }
  if (score > maxScore) {
    maxScore = score
  }

}

println(System.currentTimeMillis() - ts)
println(maxScore);

上面的代码在我的机器上花费的时间超过了秒。 我的想法是Scala没有Java中的float[]之类的原始数组结构,而是由集合代替。索引 i 处的访问似乎比 Java 中原始数组的访问要慢。

下面的代码更慢:

val maxScore = t.map( r => r zip u map Function.tupled(_*_) reduceLeft (_+_)).max

需要 26 秒

我应该如何有效地迭代我的 2 个数组来计算这个?

非常感谢

【问题讨论】:

  • 您可以在 Scala 中使用数组...在 Scala 示例中,tu 是在哪里/如何定义的?
  • 糟糕,错过了!刚刚更新了帖子
  • @ogen 与您的问题无关,但作为旁注,0 until N0 to (N-1) 更惯用

标签: java arrays scala math


【解决方案1】:

好吧,很抱歉,这里奇怪的是你的 Java 实现有多快,而不是你的 Scala 有多慢 - 遍历 100 亿(!)个单元格的 6 毫秒听起来好得令人难以置信 - 确实 - 你在 Java 实现中有一个错字,这使得这段代码做得更少:

你有for (int j = 0; i &lt; R;i++)而不是for (int j = 0; j &lt; R;j++) - 这使得内部循环只运行200次而不是100亿次...

如果你解决了这个问题 - Scala 和 Java 的性能是相当的。

顺便说一句,这实际上是 Scala 的一个优势 - 更难让 for (j &lt;- 0 until R) 出错 :)

【讨论】:

  • 神圣!在您在您的答案中解释之后,我花了几秒钟才注意到错字。不错的收获!
  • LOL :D 我开始阅读生成的 scala 字节码,因为它对我来说太奇怪了。这只是java代码中的一个错字:D:D:D
  • 我完全同意这是 Scala 的优势。这也是为什么 i 和 j 实际上不应在任何语言中一起用作 for 循环索引的原因。 :) x 和 y 通常更容易看到这类东西。
【解决方案2】:

真正的问题只是一个错字(就像 Tzach Zohar 提到的那样),但如果你想提高性能,那么你可以用更直接的方式来做:

var i = 0
while (i < N) {
  var j = 0
  var score = 0.0f
  val t1: Array[Float] = t(i)
  while (j < R) {
    score += u(j) * t1(j)
    j += 1
  }
  if (score > maxScore) {
    maxScore = score
  }

  i += 1
}

这段代码 sn-p 的运行速度比 for-comprehension 版本快 10-20%。

或者!您可以使用“par”使第一个数组并行并在 map 中使用 while 循环:

val maxScore = t.par.map({
  arr =>
    var score = 0.0f
    var j = 0
    while (j < R) {
      score += u(j) * arr(j)
      j += 1
    }
    score
}).max

这段代码在我的机器上运行速度比 java 版本快 2-3 倍! 自己试试吧! :) 祝你好运

【讨论】:

    猜你喜欢
    • 2020-04-27
    • 2019-07-06
    • 2021-02-13
    • 2017-05-07
    • 1970-01-01
    • 2015-01-27
    • 2014-09-06
    • 1970-01-01
    • 2019-07-26
    相关资源
    最近更新 更多