今天练了不少快速幂的手,这一直是之前的一个漏洞吧,现在把洞补上。东西是挺简单的东西,当然题目多变,做起来也问题多多。
首先放一下核心代码:
//矩阵乘法 const int mod = 10000; const int maxn = 2; struct matrix { int a[maxn][maxn]; }; matrix mul(matrix A, matrix B) { matrix ret; memset(ret.a, 0, sizeof(ret.a)); for(int i = 0; i < maxn; ++i) for(int k = 0; k < maxn; ++k) if(A.a[i][k]) //注意此处的优化,一般矩阵复杂度还是O(n^3),然而当矩阵是稀疏矩阵,即存在很多0时,复杂度则甚至可能降为O(n^2); for(int j = 0; j < maxn; ++j) { ret.a[i][j] += A.a[i][k] * B.a[k][j]; if(ret.a[i][j] >= mod) ret.a[i][j] %= mod; } return ret; }
//快速幂计算。二分原理: a^k = (a^2)^(k/2) = ((a^2)^2)^(k/4); matrix expo(matrix p, int k) { if(k == 1) return p; matrix ret; memset(ret.a, 0, sizeof(ret.a)); for(int i = 0; i < maxn; ++i) ret.a[i][i] = 1; if(k == 0) return ret; while(k) { if(k & 1) ret = mul(p, ret); p = mul(p, p); k >>= 1; } return ret; }
对于此处代码可以将幂转化成二进制形式来理解:例如当k=156时,156 = 10011100 = 128 + 16 + 8 + 4; ans = a156 = a128 * a16 * a8 * a4
从右向左每一位i(i >= 0)即ai,碰见一个1就把ai乘到ans里;
1 while(k) 2 { 3 if(k & 1) 4 ret = mul(p, ret); 5 p = mul(p, p); 6 k >>= 1; 7 }
结构体形式模板:
1 #include <iostream> 2 #include <cstdio> 3 #include <cstring> 4 using namespace std; 5 typedef long long ll; 6 int n, k, mod; 7 const int maxn = 100; 8 struct matrix 9 { 10 int a[maxn][maxn]; 11 void print() 12 { 13 for(int i = 0; i < n; i++) 14 { 15 for(int j = 0; j < n; j++) 16 { 17 if(j) printf(" "); 18 printf("%d", a[i][j] % mod); 19 } 20 printf("\n"); 21 } 22 } 23 matrix& operator += (const matrix& rhs) 24 { 25 for(int i = 0; i < n; ++i) 26 for(int j = 0; j < n; ++j) 27 if(rhs.a[i][j]) 28 { 29 a[i][j] += rhs.a[i][j]; 30 if(a[i][j] >= mod) a[i][j] %= mod; 31 } 32 return *this; 33 } 34 matrix& operator *= (const matrix& rhs) 35 { 36 matrix ret; 37 memset(ret.a, 0, sizeof(ret.a)); 38 for(int i = 0; i < n; ++i) 39 for(int k = 0; k < n; ++k) 40 if(a[i][k]) 41 for(int j = 0; j < n; ++j) 42 { 43 ret.a[i][j] += a[i][k] * rhs.a[k][j]; 44 if(ret.a[i][j] >= mod) 45 ret.a[i][j] %= mod; 46 } 47 memcpy(a, ret.a, sizeof(a)); 48 return *this; 49 } 50 }; 51 matrix expo(matrix p, int k) 52 { 53 if(k == 1) return p; 54 matrix ret; 55 memset(ret.a, 0, sizeof(ret.a)); 56 for(int i = 0; i < n; ++i) 57 ret.a[i][i] = 1; 58 if(k == 0) return ret; 59 while(k) 60 { 61 if(k & 1) 62 ret *= p; 63 p *= p; 64 k >>= 1; 65 } 66 return ret; 67 }