【发布时间】:2016-06-02 19:12:49
【问题描述】:
我有一个执行大量矩阵乘法的程序。我想我可以通过减少代码中的循环数来加快它的速度,看看它会快多少(我稍后会尝试一个矩阵数学库)。事实证明,它一点也不快。我已经能够用一些示例代码复制这个问题。我的猜测是 testOne() 会比 testTwo() 快,因为它不会创建任何新数组,而且它有三分之一的循环。在我的机器上,它的运行时间是原来的两倍:
具有 5000 个 epoch 的 testOne 的持续时间:657,loopCount:64000000
5000 个 epoch 的 testTwo 的持续时间:365,loopCount:192000000
我的猜测是multOne() 比multTwo() 慢,因为在multOne() 中,CPU 不会像在multTwo() 中那样写入顺序内存地址。听起来对吗?任何解释将不胜感激。
import java.util.Random;
public class ArrayTest {
double[] arrayOne;
double[] arrayTwo;
double[] arrayThree;
double[][] matrix;
double[] input;
int loopCount;
int rows;
int columns;
public ArrayTest(int rows, int columns) {
this.rows = rows;
this.columns = columns;
this.loopCount = 0;
arrayOne = new double[rows];
arrayTwo = new double[rows];
arrayThree = new double[rows];
matrix = new double[rows][columns];
Random random = new Random();
for (int i = 0; i < rows; i++) {
for (int j = 0; j < columns; j++) {
matrix[i][j] = random.nextDouble();
}
}
}
public void testOne(double[] input, int epochs) {
this.input = input;
this.loopCount = 0;
long start = System.currentTimeMillis();
long duration;
for (int i = 0; i < epochs; i++) {
multOne();
}
duration = System.currentTimeMillis() - start;
System.out.println("Duration for testOne with " + epochs + " epochs: " + duration + ", loopCount: " + loopCount);
}
public void multOne() {
for (int i = 0; i < rows; i++) {
for (int j = 0; j < columns; j++) {
arrayOne[i] += matrix[i][j] * arrayOne[i] * input[j];
arrayTwo[i] += matrix[i][j] * arrayTwo[i] * input[j];
arrayThree[i] += matrix[i][j] * arrayThree[i] * input[j];
loopCount++;
}
}
}
public void testTwo(double[] input, int epochs) {
this.loopCount = 0;
long start = System.currentTimeMillis();
long duration;
for (int i = 0; i < epochs; i++) {
arrayOne = multTwo(matrix, arrayOne, input);
arrayTwo = multTwo(matrix, arrayTwo, input);
arrayThree = multTwo(matrix, arrayThree, input);
}
duration = System.currentTimeMillis() - start;
System.out.println("Duration for testTwo with " + epochs + " epochs: " + duration + ", loopCount: " + loopCount);
}
public double[] multTwo(double[][] matrix, double[] array, double[] input) {
double[] newArray = new double[rows];
for (int i = 0; i < rows; i++) {
for (int j = 0; j < columns; j++) {
newArray[i] += matrix[i][j] * array[i] * input[j];
loopCount++;
}
}
return newArray;
}
public static void main(String[] args) {
int rows = 100;
int columns = 128;
ArrayTest arrayTest = new ArrayTest(rows, columns);
Random random = new Random();
double[] input = new double[columns];
for (int i = 0; i < columns; i++) {
input[i] = random.nextDouble();
}
arrayTest.testOne(input, 5000);
arrayTest.testTwo(input, 5000);
}
}
【问题讨论】:
-
Java 中的基准测试真的、真的、很难。我不相信你的测量——编写告诉你事实相反的基准真的很容易。如果您需要可靠的数据,请使用 JMH 等。
-
这里有一个提示,可以帮助您开始自己回答问题(尽管很高兴看到一个制作精良的问题):更改 mult1 以使用新数组而不是
arrayOne +=、arrayTwo +=和 @ 987654328@,然后在方法末尾分配arrayOne、arrayTwo和arrayThree。 -
这可能是真的。
-
跟踪花费的时间而不是执行的操作会很有帮助。您可以
System.nanotime()获取当前时间,然后在您的操作结束后减去差值。 -
是的,以纳秒为单位计算时间。
标签: java performance matrix-multiplication