【问题标题】:Program with threads for matrix multiplication用于矩阵乘法的线程程序
【发布时间】:2015-09-03 09:19:56
【问题描述】:

我正在尝试创建一个带有用于矩阵乘法的线程的 Java 程序。这是源代码:

import java.util.Random;

public class MatrixTest {
    //Creating the matrix
    static int[][] mat = new int[3][3];
    static int[][] mat2 = new int[3][3];
    static int[][] result = new int[3][3];

    public static void main(String[] args) {
        //Creating the object of random class
        Random rand = new Random();

        //Filling first matrix with random values
        for (int i = 0; i < mat.length; i++) {
            for (int j = 0; j < mat[i].length; j++) {
                mat[i][j] = rand.nextInt(10);
            }
        }

        //Filling second matrix with random values
        for (int i = 0; i < mat2.length; i++) {
            for (int j = 0; j < mat2[i].length; j++) {
                mat2[i][j] = rand.nextInt(10);
            }
        }

        try {
            //Object of multiply Class
            Multiply multiply = new Multiply(3, 3);

            //Threads
            MatrixMultiplier thread1 = new MatrixMultiplier(multiply);
            MatrixMultiplier thread2 = new MatrixMultiplier(multiply);
            MatrixMultiplier thread3 = new MatrixMultiplier(multiply);

            //Implementing threads
            Thread th1 = new Thread(thread1);
            Thread th2 = new Thread(thread2);
            Thread th3 = new Thread(thread3);

            //Starting threads
            th1.start();
            th2.start();
            th3.start();

            th1.join();
            th2.join();
            th3.join();
        } catch (Exception e) {
            e.printStackTrace();
        }

        //Printing the result
        System.out.println("\n\nResult:");
        for (int i = 0; i < result.length; i++) {
            for (int j = 0; j < result[i].length; j++) {
                System.out.print(result[i][j] + " ");
            }
            System.out.println();
        }
    }//End main
}//End Class

//Multiply Class
class Multiply extends MatrixTest {
    private int i;
    private int j;
    private int chance;

    public Multiply(int i, int j) {
        this.i = i;
        this.j = j;
        chance = 0;
    }

    //Matrix Multiplication Function
    public synchronized void multiplyMatrix() {
        int sum = 0;
        int a = 0;
        for (a = 0; a < i; a++) {
            sum = 0;
            for (int b = 0; b < j; b++) {
                sum = sum + mat[chance][b] * mat2[b][a];
            }
            result[chance][a] = sum;
        }

        if (chance >= i)
            return;
        chance++;
    }
}//End multiply class

//Thread Class
class MatrixMultiplier implements Runnable {
    private final Multiply mul;

    public MatrixMultiplier(Multiply mul) {
        this.mul = mul;
    }

    @Override
    public void run() {
        mul.multiplyMatrix();
    }
}

我刚刚在 Eclipse 上尝试过,它可以工作,但现在我想创建该程序的另一个版本,在该版本中,我为结果矩阵上的每个单元格使用一个线程。例如,我有两个 3x3 矩阵。所以结果矩阵将是 3x3。然后,我想用 9 个线程来计算结果矩阵的 9 个单元中的每一个。

谁能帮帮我?

【问题讨论】:

  • 我相信有人可以帮助你!他们在帮助您实现什么目标?
  • 我需要一些关于线程同步的帮助,因为在这种情况下我只使用 3 个线程。在另一个程序中,我想为每个单元使用 9 个或更多线程(当我的矩阵大于 3x3 时)一个。我认为我需要创建一个单元类,但目前我没有其他想法
  • 那么您是否正在寻找一种方法来创建n 线程,其中n 是单元格的数量?
  • 那么你想创建一个线程来计算矩阵的单个单元格吗?分解您的程序,使其具有可以计算单个单元格的函数调用,然后像上面对整个矩阵所做的那样从多个线程调用它。注意不要指望它运行得很快,线程的重量很重,您将遭受 CPU 缓存中的错误共享。由于 CPU 内部的并行性以及来自 Hotspot 的 SIMD 指令越来越多的使用,您已经拥有的版本比您可能意识到的更多并发。
  • 考虑一下这是一个准备考试的练习 :) 感谢您的合作 :)

标签: java multithreading matrix concurrency matrix-multiplication


【解决方案1】:

您可以按如下方式创建n 线程(注意:numberOfThreads 是您要创建的线程数。这将是单元格的数量):

List<Thread> threads = new ArrayList<>(numberOfThreads);

for (int x = 0; x < numberOfThreads; x++) {
   Thread t = new Thread(new MatrixMultiplier(multiply));
   t.start();
   threads.add(t);
}

for (Thread t : threads) {
   t.join();
}

【讨论】:

    【解决方案2】:

    请使用新的 Executor framework 创建线程,而不是手动进行管道。

    ExecutorService executor = Executors.newFixedThreadPool(numberOfThreadsInPool);
    for (int i = 0; i < numberOfThreads; i++) {
      Runnable worker = new Thread(new MatrixMultiplier(multiply));;
      executor.execute(worker);
    }
    executor.shutdown();
    while (!executor.isTerminated()) {
    }
    

    【讨论】:

      【解决方案3】:

      使用此代码,我认为我可以解决我的问题。我没有在方法中使用同步,但我认为在这种情况下没有必要。

      import java.util.Scanner;
      
      class MatrixProduct extends Thread {
          private int[][] A;
          private int[][] B;
          private int[][] C;
          private int rig, col;
          private int dim;
      
          public MatrixProduct(int[][] A, int[][] B, int[][] C, int rig, int col, int dim_com) {
              this.A = A;
              this.B = B;
              this.C = C;
              this.rig = rig;
              this.col = col;
              this.dim = dim_com;
          }
      
          public void run() {
              for (int i = 0; i < dim; i++) {
                  C[rig][col] += A[rig][i] * B[i][col];
              }
              System.out.println("Thread " + rig + "," + col + " complete.");
          }
      }
      
      public class MatrixMultiplication {
          public static void main(String[] args) {
              Scanner In = new Scanner(System.in);
      
              System.out.print("Row of Matrix A: ");
              int rA = In.nextInt();
              System.out.print("Column of Matrix A: ");
              int cA = In.nextInt();
              System.out.print("Row of Matrix B: ");
              int rB = In.nextInt();
              System.out.print("Column of Matrix B: ");
              int cB = In.nextInt();
              System.out.println();
      
              if (cA != rB) {
                  System.out.println("We can't do the matrix product!");
                  System.exit(-1);
              }
              System.out.println("The matrix result from product will be " + rA + " x " + cB);
              System.out.println();
              int[][] A = new int[rA][cA];
              int[][] B = new int[rB][cB];
              int[][] C = new int[rA][cB];
              MatrixProduct[][] thrd = new MatrixProduct[rA][cB];
      
              System.out.println("Insert A:");
              System.out.println();
              for (int i = 0; i < rA; i++) {
                  for (int j = 0; j < cA; j++) {
                      System.out.print(i + "," + j + " = ");
                      A[i][j] = In.nextInt();
                  }
              }
              System.out.println();
              System.out.println("Insert B:");
              System.out.println();
              for (int i = 0; i < rB; i++) {
                  for (int j = 0; j < cB; j++) {
                      System.out.print(i + "," + j + " = ");
                      B[i][j] = In.nextInt();
                  }
              }
              System.out.println();
      
              for (int i = 0; i < rA; i++) {
                  for (int j = 0; j < cB; j++) {
                      thrd[i][j] = new MatrixProduct(A, B, C, i, j, cA);
                      thrd[i][j].start();
                  }
              }
      
              for (int i = 0; i < rA; i++) {
                  for (int j = 0; j < cB; j++) {
                      try {
                          thrd[i][j].join();
                      } catch (InterruptedException e) {
                      }
                  }
              }
      
              System.out.println();
              System.out.println("Result");
              System.out.println();
              for (int i = 0; i < rA; i++) {
                  for (int j = 0; j < cB; j++) {
                      System.out.print(C[i][j] + " ");
                  }
                  System.out.println();
              }
          }
      }
      

      【讨论】:

        【解决方案4】:

        请考虑Matrix.javaMain.java,如下所示。

        public class Matrix extends Thread {
            private static int[][] a;
            private static int[][] b;
            private static int[][] c;
        
            /* You might need other variables as well */
            private int i;
            private int j;
            private int z1;
        
            private int s;
            private int k;
        
            public Matrix(int[][] A, final int[][] B, final int[][] C, int i, int j, int z1) { // need to change this, might
                // need some information
                a = A;
                b = B;
                c = C;
                this.i = i;
                this.j = j;
                this.z1 = z1; // a[0].length
            }
        
            public void run() {
                synchronized (c) {
                    // 3. How to allocate work for each thread (recall it is the run function which
                    // all the threads execute)
        
                    // Here this code implements the allocated work for perticular thread
                    // Each element of the resulting matrix will generate by a perticular thread
                    for (s = 0, k = 0; k < z1; k++)
                        s += a[i][k] * b[k][j];
                    c[i][j] = s;
                }
            }
        
            public static int[][] returnC() {
                return c;
            }
        
            public static int[][] multiply(final int[][] a, final int[][] b) {
                /*
                 * check if multipication can be done, if not return null allocate required
                 * memory return a * b
                 */
                final int x = a.length;
                final int y = b[0].length;
        
                final int z1 = a[0].length;
                final int z2 = b.length;
        
                if (z1 != z2) {
                    System.out.println("Cannnot multiply");
                    return null;
                }
        
                final int[][] c = new int[x][y];
                int i, j;
        
                // 1. How to use threads to parallelize the operation?
                // Every element in the resulting matrix will be determined by a different
                // thread
        
                // 2. How may threads to use?
                // x * y threads are used to generate the result.
                for (i = 0; i < x; i++)
                    for (j = 0; j < y; j++) {
                        try {
                            Matrix temp_thread = new Matrix(a, b, c, i, j, z1);
                            temp_thread.start();
        
                            // 4. How to synchronize?
        
                            // synchronized() is used with join() to guarantee that the perticular thread
                            // will be accessed first
                            temp_thread.join();
                        } catch (InterruptedException e) {
                            e.printStackTrace();
                        }
                    }
                return Matrix.returnC();
            }
        }
        

        您可以使用Main.java 给出两个需要相乘的矩阵。

        class Main {
            public static int[][] a = {
                    {1, 1, 1},
                    {1, 1, 1},
                    {1, 1, 1}};
        
            public static int[][] b = {
                    {1},
                    {1},
                    {1}};
        
            public static void print_matrix(int[][] a) {
                for (int i = 0; i < a.length; i++) {
                    for (int j = 0; j < a[i].length; j++)
                        System.out.print(a[i][j] + " ");
                    System.out.println();
                }
            }
        
            public static void main(String[] args) {
                int[][] x = Matrix.multiply(a, b);
                print_matrix(x); // see if the multipication is correct
            }
        }
        

        【讨论】:

          【解决方案5】:

          简单来说,大家需要做的是,

          1) 创建 n 个(结果矩阵中的单元数)线程。分配他们的角色。 (例如:考虑 M X N,其中 M 和 N 是矩阵。“thread1”负责将 M 的 row_1 元素与 N 的 column_1 元素相乘并存储结果。这是结果矩阵 cell_1 的值。)

          2) 启动每个线程的进程。 (通过 start() 方法)

          3) 等到所有线程完成它们的进程并存储每个单元格的结果值。因为这些过程应该在显示结果矩阵之前完成。 (您可以通过 join() 方法和其他可能性来做到这一点)

          4) 现在,您可以显示结果矩阵了。

          注意:

          1) 由于在此示例中,共享资源(M 和 N)仅用于只读目的,因此您无需使用“同步”方法来访问它们。

          2) 可以看到,在这个程序中,有一组线程在运行,它们都需要自己达到一个特定的状态,然后才能继续整个程序的下一步。这种多线程编程模型被称为Barrier

          【讨论】:

            【解决方案6】:

            根据每个单元格的线程在 Eclipse 中尝试以下代码。效果很好,你可以检查一下。

            class ResMatrix {
                static int[][] arrres = new int[2][2];
            }
            
            class Matrix {
                int[][] arr = new int[2][2];
            
                void setV(int v) {
                    //int tmp = v;
                    for (int i = 0; i < 2; i++) {
                        for (int j = 0; j < 2; j++) {
                            arr[i][j] = v;
                            v = v + 1;
                        }
                    }
                }
            
                int[][] getV() {
                    return arr;
                }
            }
            
            class Mul extends Thread {
                public int row;
                public int col;
                Matrix m;
                Matrix m1;
            
                Mul(int row, int col, Matrix m, Matrix m1) {
                    this.row = row;
                    this.col = col;
                    this.m = m;
                    this.m1 = m1;
                }
            
                public void run() {
                    //System.out.println("Started Thread: " + Thread.currentThread().getName());
                    int tmp = 0;
                    for (int i = 0; i < 2; i++) {
                        tmp = tmp + this.m.getV()[row][i] * this.m1.getV()[i][col];
                    }
                    ResMatrix.arrres[row][col] = tmp;
                    System.out.println("Started Thread END: " + Thread.currentThread().getName());
                }
            }
            
            public class Test {
                //static int[][] arrres =new int[2][2];
                public static void main(String[] args) throws InterruptedException {
                    Matrix mm = new Matrix();
                    mm.setV(1);
                    Matrix mm1 = new Matrix();
                    mm1.setV(2);
            
                    for (int i = 0; i < 2; i++) {
                        for (int j = 0; j < 2; j++) {
                            Mul mul = new Mul(i, j, mm, mm1);
                            mul.start();
                            // mul.join();
                        }
                    }
            
                    for (int i = 0; i < 2; i++) {
                        for (int j = 0; j < 2; j++) {
                            System.out.println("VALUE: " + ResMatrix.arrres[i][j]);
                        }
                    }
                }
            }
            

            【讨论】:

              【解决方案7】:

              在我的解决方案中,我为每个工作人员分配了numRowForThread 的行数等于:(matA 的行数)/(线程数)。

              public class MatMulConcur {
                  private final static int NUM_OF_THREAD = 1;
                  private static Mat matC;
              
                  public static Mat matmul(Mat matA, Mat matB) {
                      matC = new Mat(matA.getNRows(), matB.getNColumns());
                      return mul(matA, matB);
                  }
              
                  private static Mat mul(Mat matA, Mat matB) {
                      int numRowForThread;
                      int numRowA = matA.getNRows();
                      int startRow = 0;
              
                      Worker[] myWorker = new Worker[NUM_OF_THREAD];
              
                      for (int j = 0; j < NUM_OF_THREAD; j++) {
                          if (j < NUM_OF_THREAD - 1) {
                              numRowForThread = (numRowA / NUM_OF_THREAD);
                          } else {
                              numRowForThread = (numRowA / NUM_OF_THREAD) + (numRowA % NUM_OF_THREAD);
                          }
                          myWorker[j] = new Worker(startRow, startRow + numRowForThread, matA, matB);
                          myWorker[j].start();
                          startRow += numRowForThread;
                      }
              
                      for (Worker worker : myWorker) {
                          try {
                              worker.join();
                          } catch (InterruptedException e) {
              
                          }
                      }
                      return matC;
                  }
              
                  private static class Worker extends Thread {
                      private int startRow, stopRow;
                      private Mat matA, matB;
              
                      public Worker(int startRow, int stopRow, Mat matA, Mat matB) {
                          super();
                          this.startRow = startRow;
                          this.stopRow = stopRow;
                          this.matA = matA;
                          this.matB = matB;
                      }
              
                      @Override
                      public void run() {
                          for (int i = startRow; i < stopRow; i++) {
                              for (int j = 0; j < matB.getNColumns(); j++) {
                                  double sum = 0;
                                  for (int k = 0; k < matA.getNColumns(); k++) {
                                      sum += matA.get(i, k) * matB.get(k, j);
                                  }
                                  matC.set(i, j, sum);
                              }
                          }
                      }
                  }
              }
              

              class Mat 在哪里,我使用了这个实现:

              public class Mat {
                  private double[][] mat;
              
                  public Mat(int n, int m) {
                      mat = new double[n][m];
                  }
              
                  public void set(int i, int j, double v) {
                      mat[i][j] = v;
                  }
              
                  public double get(int i, int j) {
                      return mat[i][j];
                  }
              
                  public int getNRows() {
                      return mat.length;
                  }
              
                  public int getNColumns() {
                      return mat[0].length;
                  }
              }
              

              【讨论】:

                猜你喜欢
                • 1970-01-01
                • 1970-01-01
                • 1970-01-01
                • 2015-01-23
                • 1970-01-01
                • 1970-01-01
                • 1970-01-01
                • 1970-01-01
                • 1970-01-01
                相关资源
                最近更新 更多