矩阵乘法是线性代数里面很常用的一种计算方式,但当矩阵的阶太大时,人为计算就很麻烦了,因此对矩阵乘法问题的解决是算法很重要的方面。
矩阵的表达方式
首先,我们要先解决矩阵的表示方式。毫无疑问的,矩阵乘法应该用二维数组表示,但二维数组不能作为函数的参数传递,因此我们应该换一种方式表示二维数组。这里可以用双重指针表示二维数组,代码如下:
int **a1 = (int **)malloc(N*sizeof(int *)); //N是矩阵阶数,这里为a1申请一个空间,空间大小是N个指针,表示a1的每个元素都是指针
//为a1的每个元素申请空间
for(int i=0;i<N;i++){
a1[i]=(int *)malloc(N*sizeof(int));
}
//可以用a1[i][j]表示第i行第j列的元素
解决了矩阵的表达方式,下面介绍三种矩阵的乘法算法。
普通法
最普通的矩阵乘法,就是按照定义的方法去解决。代码如下。
int multi(int **a1,int **a2,int n){
//定义一个新的矩阵(二维数组)
int **x = (int **)malloc(n*sizeof(int *));
for(int i=0;i<n;i++){
x[i]=(int *)malloc(n*sizeof(int));
}
for(int i=0;i<n;i++){
for(int j=0;j<n;j++){
x[i][j] = 0;
for(int k=0;k<n;k++){
x[i][j]+=a1[i][k]*a2[k][j];
}
}
}
return x;
}
可以看到,这种方法的时间复杂度为O(n3),显然是不可取的。
分块法
如图所示,这种方法的代码与下一种方法实现类似,不po出来
时间复杂度分析:T(n)=8T(n/2)+O(n);
Strassen’s 算法
Strassen’s 算法的思路见下图
对其时间复杂度分析:T(n)=7T(n/2)+O(n);
显然,虽然O的分析是相同的,但其系数由8变成了7,实际时间会少很多,当n够大时,甚至可以减半
代码如下:
#include <stdio.h>
#include <stdlib.h>
#define N 16
int S1[N/2][N/2];
int S2[N/2][N/2];
int S3[N/2][N/2];
int S4[N/2][N/2];
int S5[N/2][N/2];
int S6[N/2][N/2];
int S7[N/2][N/2];
int ** add (int **a1,int **a2,int n){
int **c = (int **)malloc(n*sizeof(int *));
for(int i=0;i<n;i++){
c[i]=(int *)malloc(n*sizeof(int));
}
for(int i=0;i<n;i++){
for(int j=0;j<n;j++){
c[i][j] = a1[i][j]+a2[i][j];
}
}
return c;
}
int ** sub (int **a1,int **a2,int n){
int **c = (int **)malloc(n*sizeof(int *));
for(int i=0;i<n;i++){
c[i]=(int *)malloc(n*sizeof(int));
}
for(int i=0;i<n;i++){
for(int j=0;j<n;j++){
c[i][j] = a1[i][j]-a2[i][j];
}
}
return c;
}
int ** Strassen (int **a1,int **a2,int n){
int **x = (int **)malloc(n*sizeof(int *));
for(int i=0;i<n;i++){
x[i]=(int *)malloc(n*sizeof(int));
}
if(n==2){
x[0][0] = a1[0][0]*a2[0][0]+a1[0][1]*a2[1][0];
x[0][1] = a1[0][0]*a2[0][1]+a1[0][1]*a2[1][1];
x[1][0] = a1[1][0]*a2[0][0]+a1[1][1]*a2[1][0];
x[1][1] = a1[1][0]*a2[0][1]+a1[1][1]*a2[1][1];
return x;
}
int ** a = (int **)malloc(n/2*sizeof(int *));
int ** b = (int **)malloc(n/2*sizeof(int *));
int ** c = (int **)malloc(n/2*sizeof(int *));
int ** d = (int **)malloc(n/2*sizeof(int *));
int ** e = (int **)malloc(n/2*sizeof(int *));
int ** f = (int **)malloc(n/2*sizeof(int *));
int ** g = (int **)malloc(n/2*sizeof(int *));
int ** h = (int **)malloc(n/2*sizeof(int *));
int ** I = (int **)malloc(n/2*sizeof(int *));
int ** J = (int **)malloc(n/2*sizeof(int *));
int ** k = (int **)malloc(n/2*sizeof(int *));
int ** l = (int **)malloc(n/2*sizeof(int *));
int ** s1 = (int **)malloc(n/2*sizeof(int *));
int ** s2 = (int **)malloc(n/2*sizeof(int *));
int ** s3 = (int **)malloc(n/2*sizeof(int *));
int ** s4 = (int **)malloc(n/2*sizeof(int *));
int ** s5 = (int **)malloc(n/2*sizeof(int *));
int ** s6 = (int **)malloc(n/2*sizeof(int *));
int ** s7 = (int **)malloc(n/2*sizeof(int *));
for(int i=0;i<n/2;i++){
a[i] = (int *)malloc(n/2*sizeof(int));
b[i] = (int *)malloc(n/2*sizeof(int));
c[i] = (int *)malloc(n/2*sizeof(int));
d[i] = (int *)malloc(n/2*sizeof(int));
e[i] = (int *)malloc(n/2*sizeof(int));
f[i] = (int *)malloc(n/2*sizeof(int));
g[i] = (int *)malloc(n/2*sizeof(int));
h[i] = (int *)malloc(n/2*sizeof(int));
I[i] = (int *)malloc(n/2*sizeof(int));
J[i] = (int *)malloc(n/2*sizeof(int));
k[i] = (int *)malloc(n/2*sizeof(int));
l[i] = (int *)malloc(n/2*sizeof(int));
s1[i] = (int *)malloc(n/2*sizeof(int));
s2[i] = (int *)malloc(n/2*sizeof(int));
s3[i] = (int *)malloc(n/2*sizeof(int));
s4[i] = (int *)malloc(n/2*sizeof(int));
s5[i] = (int *)malloc(n/2*sizeof(int));
s6[i] = (int *)malloc(n/2*sizeof(int));
s7[i] = (int *)malloc(n/2*sizeof(int));
}
for(int i=0;i<n/2;i++){
for(int j=0;j<n/2;j++){
a[i][j] = a1[i][j];
b[i][j] = a1[i][j+n/2];
c[i][j] = a1[i+n/2][j];
d[i][j] = a1[i+n/2][j+n/2];
e[i][j] = a2[i][j];
f[i][j] = a2[i][j+n/2];
g[i][j] = a2[i+n/2][j];
h[i][j] = a2[i+n/2][j+n/2];
}
}
s1 = Strassen(a,sub(f,h,n/2),n/2);
s2 = Strassen(add(a,b,n/2),h,n/2);
s3 = Strassen(add(c,d,n/2),e,n/2);
s4 = Strassen(d,sub(g,e,n/2),n/2);
s5 = Strassen(add(a,d,n/2),add(e,h,n/2),n/2);
s6 = Strassen(sub(b,d,n/2),add(g,h,n/2),n/2);
s7 = Strassen(sub(a,c,n/2),add(e,f,n/2),n/2);
I = add(s5,add(s6,sub(s4,s2,n/2),n/2),n/2);
J = add(s1,s2,n/2);
k = add(s3,s4,n/2);
l = add(sub(sub(s1,s7,n/2),s3,n/2),s5,n/2);
for(int i =0;i<n/2;i++){
for(int j=0;j<n/2;j++){
x[i][j] = I[i][j];
x[i][j+n/2] = J[i][j];
x[i+n/2][j] = k[i][j];
x[i+n/2][j+n/2] = l[i][j];
}
}
if(n==N){
for(int i=0;i<n/2;i++){
for(int j=0;j<n/2;j++){
S1[i][j] = s1[i][j];
S2[i][j] = s2[i][j];
S3[i][j] = s3[i][j];
S4[i][j] = s4[i][j];
S5[i][j] = s5[i][j];
S6[i][j] = s6[i][j];
S7[i][j] = s7[i][j];
}
}
}
return x;
}
void print (int a [N/2][N/2],int n){
for(int i=0;i<n;i++){
for(int j=0;j<n;j++){
printf("%d ",a[i][j]);
}
printf("\n");
}
}
int main(void){
int **a1 = (int **)malloc(N*sizeof(int *));
int **a2 = (int **)malloc(N*sizeof(int *));
for(int i=0;i<N;i++){
a1[i]=(int *)malloc(N*sizeof(int));
a2[i]=(int *)malloc(N*sizeof(int));
}
for(int i=0;i<N;i++){
for(int j=0;j<N;j++){
a1[i][j] = 1;
a2[i][j] = 1;
}
}
Strassen(a1,a2,N);
print (S1,N/2);
print (S2,N/2);
print (S3,N/2);
print (S4,N/2);
print (S5,N/2);
print (S6,N/2);
print (S7,N/2);
}