矩阵乘法是线性代数里面很常用的一种计算方式,但当矩阵的阶太大时,人为计算就很麻烦了,因此对矩阵乘法问题的解决是算法很重要的方面。

矩阵的表达方式

首先,我们要先解决矩阵的表示方式。毫无疑问的,矩阵乘法应该用二维数组表示,但二维数组不能作为函数的参数传递,因此我们应该换一种方式表示二维数组。这里可以用双重指针表示二维数组,代码如下:

int **a1 = (int **)malloc(N*sizeof(int *));	//N是矩阵阶数,这里为a1申请一个空间,空间大小是N个指针,表示a1的每个元素都是指针
//为a1的每个元素申请空间
for(int i=0;i<N;i++){
		a1[i]=(int *)malloc(N*sizeof(int));
}
//可以用a1[i][j]表示第i行第j列的元素

解决了矩阵的表达方式,下面介绍三种矩阵的乘法算法。

普通法

最普通的矩阵乘法,就是按照定义的方法去解决。代码如下。

int multi(int **a1,int **a2,int n){
	//定义一个新的矩阵(二维数组) 
	int **x = (int **)malloc(n*sizeof(int *));
	for(int i=0;i<n;i++){
		x[i]=(int *)malloc(n*sizeof(int));
	}
	for(int i=0;i<n;i++){
		for(int j=0;j<n;j++){
			x[i][j] = 0;
			for(int k=0;k<n;k++){
				x[i][j]+=a1[i][k]*a2[k][j];
			}
		}
	}
	return x;
}

可以看到,这种方法的时间复杂度为O(n3),显然是不可取的。

分块法

算法上机(二)矩阵乘法和Strassen’s 算法
如图所示,这种方法的代码与下一种方法实现类似,不po出来
时间复杂度分析:T(n)=8T(n/2)+O(n);

Strassen’s 算法

Strassen’s 算法的思路见下图
算法上机(二)矩阵乘法和Strassen’s 算法
算法上机(二)矩阵乘法和Strassen’s 算法
对其时间复杂度分析:T(n)=7T(n/2)+O(n);
显然,虽然O的分析是相同的,但其系数由8变成了7,实际时间会少很多,当n够大时,甚至可以减半
代码如下:

#include <stdio.h>
#include <stdlib.h>
#define N 16

int S1[N/2][N/2];
int S2[N/2][N/2];
int S3[N/2][N/2];
int S4[N/2][N/2];
int S5[N/2][N/2];
int S6[N/2][N/2];
int S7[N/2][N/2];

int ** add (int **a1,int **a2,int n){
	int **c = (int **)malloc(n*sizeof(int *));
	for(int i=0;i<n;i++){
		c[i]=(int *)malloc(n*sizeof(int));
	}
	
	for(int i=0;i<n;i++){
		for(int j=0;j<n;j++){
			c[i][j] = a1[i][j]+a2[i][j];
		}
	}
	return c;
}

int ** sub (int **a1,int **a2,int n){
	int **c = (int **)malloc(n*sizeof(int *));
	for(int i=0;i<n;i++){
		c[i]=(int *)malloc(n*sizeof(int));
	}
	
	for(int i=0;i<n;i++){
		for(int j=0;j<n;j++){
			c[i][j] = a1[i][j]-a2[i][j];
		}
	}
	return c;
}



int ** Strassen (int **a1,int **a2,int n){
	int **x = (int **)malloc(n*sizeof(int *));
	for(int i=0;i<n;i++){
		x[i]=(int *)malloc(n*sizeof(int));
	}
	if(n==2){
		x[0][0] = a1[0][0]*a2[0][0]+a1[0][1]*a2[1][0];
		x[0][1] = a1[0][0]*a2[0][1]+a1[0][1]*a2[1][1];
		x[1][0] = a1[1][0]*a2[0][0]+a1[1][1]*a2[1][0];
		x[1][1] = a1[1][0]*a2[0][1]+a1[1][1]*a2[1][1];
		return x;
	}
	
	
	
	int ** a = (int **)malloc(n/2*sizeof(int *));
	int ** b = (int **)malloc(n/2*sizeof(int *));
	int ** c = (int **)malloc(n/2*sizeof(int *));
	int ** d = (int **)malloc(n/2*sizeof(int *));
	int ** e = (int **)malloc(n/2*sizeof(int *));
	int ** f = (int **)malloc(n/2*sizeof(int *));
	int ** g = (int **)malloc(n/2*sizeof(int *));
	int ** h = (int **)malloc(n/2*sizeof(int *));
	int ** I = (int **)malloc(n/2*sizeof(int *));
	int ** J = (int **)malloc(n/2*sizeof(int *));
	int ** k = (int **)malloc(n/2*sizeof(int *));
	int ** l = (int **)malloc(n/2*sizeof(int *));
	int ** s1 = (int **)malloc(n/2*sizeof(int *));
	int ** s2 = (int **)malloc(n/2*sizeof(int *));
	int ** s3 = (int **)malloc(n/2*sizeof(int *));
	int ** s4 = (int **)malloc(n/2*sizeof(int *));
	int ** s5 = (int **)malloc(n/2*sizeof(int *));
	int ** s6 = (int **)malloc(n/2*sizeof(int *));
	int ** s7 = (int **)malloc(n/2*sizeof(int *));
	
	
	
	for(int i=0;i<n/2;i++){
		a[i] = (int *)malloc(n/2*sizeof(int));
		b[i] = (int *)malloc(n/2*sizeof(int));
		c[i] = (int *)malloc(n/2*sizeof(int));
		d[i] = (int *)malloc(n/2*sizeof(int));
		e[i] = (int *)malloc(n/2*sizeof(int));
		f[i] = (int *)malloc(n/2*sizeof(int));
		g[i] = (int *)malloc(n/2*sizeof(int));
		h[i] = (int *)malloc(n/2*sizeof(int));
		I[i] = (int *)malloc(n/2*sizeof(int));
		J[i] = (int *)malloc(n/2*sizeof(int));
		k[i] = (int *)malloc(n/2*sizeof(int));
		l[i] = (int *)malloc(n/2*sizeof(int));
		s1[i] = (int *)malloc(n/2*sizeof(int));
		s2[i] = (int *)malloc(n/2*sizeof(int));
		s3[i] = (int *)malloc(n/2*sizeof(int));
		s4[i] = (int *)malloc(n/2*sizeof(int));
		s5[i] = (int *)malloc(n/2*sizeof(int));
		s6[i] = (int *)malloc(n/2*sizeof(int));
		s7[i] = (int *)malloc(n/2*sizeof(int));
	}
	
	
		
	for(int i=0;i<n/2;i++){
		for(int j=0;j<n/2;j++){
			a[i][j] = a1[i][j];
			b[i][j] = a1[i][j+n/2];
			c[i][j] = a1[i+n/2][j];
			d[i][j] = a1[i+n/2][j+n/2];
			e[i][j] = a2[i][j];
			f[i][j] = a2[i][j+n/2];
			g[i][j] = a2[i+n/2][j];
			h[i][j] = a2[i+n/2][j+n/2];
		}
	}
	
	
	s1 = Strassen(a,sub(f,h,n/2),n/2);
	s2 = Strassen(add(a,b,n/2),h,n/2);
	s3 = Strassen(add(c,d,n/2),e,n/2);
	s4 = Strassen(d,sub(g,e,n/2),n/2);
	s5 = Strassen(add(a,d,n/2),add(e,h,n/2),n/2);
	s6 = Strassen(sub(b,d,n/2),add(g,h,n/2),n/2);
	s7 = Strassen(sub(a,c,n/2),add(e,f,n/2),n/2);
	
	I = add(s5,add(s6,sub(s4,s2,n/2),n/2),n/2);
	J = add(s1,s2,n/2);
	k = add(s3,s4,n/2);
	l = add(sub(sub(s1,s7,n/2),s3,n/2),s5,n/2);
	
	for(int i =0;i<n/2;i++){
		for(int j=0;j<n/2;j++){
			x[i][j] = I[i][j];
			x[i][j+n/2] = J[i][j];
			x[i+n/2][j] = k[i][j];
			x[i+n/2][j+n/2] = l[i][j];
		} 
	}
	
	if(n==N){
		for(int i=0;i<n/2;i++){
			for(int j=0;j<n/2;j++){
				S1[i][j] = s1[i][j];
				S2[i][j] = s2[i][j];
				S3[i][j] = s3[i][j];
				S4[i][j] = s4[i][j];
				S5[i][j] = s5[i][j];
				S6[i][j] = s6[i][j];
				S7[i][j] = s7[i][j];
			}
		}
	}
	return x;
}




void print (int a [N/2][N/2],int n){
	for(int i=0;i<n;i++){
		for(int j=0;j<n;j++){
			printf("%d ",a[i][j]);
		}
		printf("\n");
	}
}

int main(void){
	int **a1 = (int **)malloc(N*sizeof(int *));
	int **a2 = (int **)malloc(N*sizeof(int *));
	for(int i=0;i<N;i++){
		a1[i]=(int *)malloc(N*sizeof(int));
		a2[i]=(int *)malloc(N*sizeof(int));
	}
	for(int i=0;i<N;i++){
		for(int j=0;j<N;j++){
			a1[i][j] = 1;
			a2[i][j] = 1;
		}
	}
	
	Strassen(a1,a2,N);
	print (S1,N/2);
	print (S2,N/2);
	print (S3,N/2);
	print (S4,N/2);
	print (S5,N/2);
	print (S6,N/2);
	print (S7,N/2);
	
}


相关文章:

  • 2021-11-23
  • 2021-04-16
  • 2021-12-08
  • 2021-12-12
  • 2021-07-20
  • 2022-12-23
  • 2021-11-23
猜你喜欢
  • 2021-11-23
  • 2021-12-21
  • 2021-08-29
  • 2021-07-08
  • 2021-11-23
  • 2021-11-23
相关资源
相似解决方案