平时有关线性递推的题,很多都可以利用矩阵乘法来解决。 时间复杂度一般是O(K3logn)因此对矩阵的规模限制比较大。

下面介绍一种利用利用Cayley-Hamilton theorem加速矩阵乘法的方法。

 

Cayley-Hamilton theorem:

记矩阵A的特征多项式为f(x)。 则有f(A)=0.

证明可以看 维基百科 https://en.wikipedia.org/wiki/Cayley–Hamilton_theorem#A_direct_algebraic_proof

另外我在高等代数的课本上找到了 证明(和维基百科里的第一种证明方法是一样的)

利用Cayley-Hamilton theorem 优化矩阵线性递推

利用Cayley-Hamilton theorem 优化矩阵线性递推

 

 

 

下面介绍几个 利用可以这个定理解决的题目:

1. project euler 258    

利用Cayley-Hamilton theorem 优化矩阵线性递推

显然可以用矩阵乘法来做。下面讲一下怎么利用Cayley-Hamilton theorem 来优化: 详细的论述可以参考这篇

设M为K阶矩阵,主要思想就是将$M^n$表示为 $b_0 M^0\ +\ b_1 M^1\ +\ \cdots\ b_{K-1}M^{K-1}$这样的形式.

根据Cayley-Hamilton theorem $M^K\ =\ a_0 M^0\ +\ a_1 M^1\ +\ \cdots\ a_{K-1}M^{K-1}$    

由于转移矩阵的特殊性,不难证明$a_i$恰好是线性递推公式里的系数。

假设我们已经将$M^n$表示为 $b_0 M^0\ +\ b_1 M^1\ +\ \cdots\ b_{K-1}M^{K-1}$这样的形式,不难得到$M^{n+1}$的表示法。只要将$M^n$乘个M之后得到的项中$M^K$拆成小于K次的线性组合就好了。  这样我们可以预处理出$M^0\ M^1\ \cdots\ M^{2K-2}$的表示法。

对于次数更高的, $M^{i+j}=M^i*M^j$  可以看成是两个多项式的乘法。 利用快速幂 可以在O(K2logn)的时间求出$M^n$的表示法.

利用Cayley-Hamilton theorem 优化矩阵线性递推

 

另外有一个优化常数的trick, 可以预处理出$M^1$ $M^2$  $M^4$  $M^8$....  $M^{2^r}$这些项, 对于$M^n$只要根据二进制位相应的乘上这些项就好了。 这样做比直接做快速幂快一倍(少了一半的多项式乘法操作)。

 

参考代码:

 1 //ans=12747994 
 2 #include <cstdio>
 3 #include <iostream>
 4 #include <queue>
 5 #include <algorithm>
 6 #include <cstring>
 7 #include <set>
 8 using namespace std;
 9 
10 #define N 2000
11 typedef long long ll;
12 
13 const int Mod=20092010;
14 int a[N],f[N<<1];
15 int k=2000;
16 
17 
18 
19 //基本思想是把A^n 表示成A^0  A^1 A^2 ... A^(k-1)的线性组合
20 //A^(p+q)可看成两个多项式相乘,只要实现预处理出A^0  A^1 A^2 ... A^(2k-2)的多项式表示法 
21 //A^k可以根据特征多项式的性质得到 ,A^(n+1)次可以从A^n次得到  根据这个来预处理 
22 struct Poly
23 {
24     int b[N];
25 }P[N<<1];
26 
27 
28 Poly operator * (const Poly &A,const Poly &B)
29 {
30     Poly ans; memset(ans.b,0,sizeof(ans.b));
31     for (int i=0;i<=2*k-2;i++)
32     {
33         int res=0;
34         for (int j=max(0,i-k+1);j<k && j<=i;j++)
35         {
36             res+=1ll*A.b[j]*B.b[i-j]%Mod;
37             if (res>=Mod) res-=Mod;
38         }
39         if (i<k) {ans.b[i]=res; continue;}
40         
41         //把次数大于等于k的搞成小于k 
42         for (int j=0;j<k;j++)   
43         {
44             ans.b[j]+=1ll*res*P[i].b[j]%Mod;
45             if (ans.b[j]>=Mod) ans.b[j]-=Mod;
46         }
47     }
48     return ans;
49 }
50 
51 Poly Power_Poly(ll p)
52 {
53     if (p<=2*k-2) return P[p];
54     
55     Poly ans=P[0],A=P[1];
56     for (;p;p>>=1)
57     {
58         if (p&1) ans=ans*A;
59         A=A*A;
60     }
61     return ans;
62 }
63 
64 int main()
65 {
66     freopen("in.in","r",stdin);
67     freopen("out.out","w",stdout);
68     
69     //f[n]=a[k-1]f[n-1]....a[0]f[n-k]
70     a[0]=a[1]=1; ll n; n=1e18;
71     for (int i=0;i<k;i++) P[i].b[i]=1;
72     
73     //P[k]=a[0]P[0]+a[1]P[1]+....a[k-1]P[k-1]
74     for (int i=0;i<k;i++) P[k].b[i]=a[i]; 
75     
76     //Calculate P[k+1]...P[2k-2]
77     //using P[n+1]=a[0]*b[k-1]+ (a[1]*b[k-1]+b[0]) + (a[2]*b[k-1]+b[1]) +...(a[k-1]*b[k-1]+b[k-2])
78     for (int j=k+1;j<=2*k-2;j++)
79     {
80         P[j].b[0]=1ll*a[0]*P[j-1].b[k-1]%Mod;
81         for (int i=1;i<k;i++)
82             P[j].b[i]=(1ll*a[i]*P[j-1].b[k-1]%Mod+P[j-1].b[i-1])%Mod;
83     }
84     
85     Poly tmp=Power_Poly(n-k+1); int ans=0;
86     
87     for (int i=0;i<k;i++) f[i]=1;
88     for (int i=k;i<=2*k-2;i++) f[i]=(f[i-1999]+f[i-2000])%Mod;
89     
90     //A^n*X=b[0]*A^0*X+b[1]*A^1*X+...b[k-1]*A^(k-1)*X      A^i*X= {f[k-1+i]  f[k-2+i]...  f[0+i]}
91     for (int i=0;i<k;i++)
92     {
93         ans+=1ll*tmp.b[i]*f[k-1+i]%Mod;
94         if (ans>=Mod) ans-=Mod;
95     }
96     printf("%d\n",ans);
97     return 0;
98 }
View Code

 

 


 

2.设利用Cayley-Hamilton theorem 优化矩阵线性递推,求利用Cayley-Hamilton theorem 优化矩阵线性递推的值。其中利用Cayley-Hamilton theorem 优化矩阵线性递推利用Cayley-Hamilton theorem 优化矩阵线性递推

     利用Cayley-Hamilton theorem 优化矩阵线性递推

题目链接:http://www.51nod.com/onlineJudge/questionCode.html#!problemId=1229 

 

首先这个题可以用扰动法 搞出一个关于k的递推公式,在O(k2)的时间解决。

具体可以参考这篇。 虽然不是同一个式子,但是方法是一样的,扰动法在《具体数学》上也有介绍。因此本文不再赘述。

给个AC代码供参考:

 

  1 #include <cstdio>
  2 #include <iostream>
  3 #include <queue>
  4 #include <algorithm>
  5 #include <cstring>
  6 #include <set>
  7 using namespace std;
  8 
  9 #define N 2010
 10 typedef long long ll;
 11 const int Mod=1000000007;
 12 
 13 int k;
 14 ll n,r;
 15 int f[N],inv[N];
 16 int fac[N],fac_inv[N];
 17 
 18 int C(int x,int y)
 19 {
 20     if (y==0) return 1;
 21     if (y>x) return 0;
 22     
 23     int res=1ll*fac[x]*fac_inv[y]%Mod;
 24     return 1ll*res*fac_inv[x-y]%Mod;
 25 }
 26 
 27 int Power(ll a,ll p)
 28 {
 29     int res=1; a%=Mod;
 30     for (;p;p>>=1)
 31     {
 32         if (p&1) res=1ll*res*a%Mod;
 33         a=a*a%Mod;
 34     }
 35     return res;
 36 }
 37 
 38 int Solve1()
 39 {
 40     f[0]=n;  int t=n+1;
 41     for (int i=1;i<=k;i++)
 42     {
 43         f[i]=t=1ll*t*(n+1)%Mod;
 44         for (int j=0;j<i;j++)
 45         {
 46             f[i]+=Mod-1ll*C(i+1,j)*f[j]%Mod;
 47             if (f[i]>=Mod) f[i]-=Mod;
 48         }
 49         f[i]--; if (f[i]<0) f[i]+=Mod;
 50         f[i]=1ll*f[i]*inv[i+1]%Mod;
 51     }
 52     return f[k];
 53 }
 54 
 55 
 56 int Solve2()
 57 {
 58     f[0]=Power(r,n+1)-r%Mod;
 59     if (f[0]<0) f[0]+=Mod;
 60     f[0]=1ll*f[0]*Power(r-1,Mod-2)%Mod;
 61     
 62     for (int i=1;i<=k;i++)
 63     {
 64         f[i]=1ll*Power(n+1,i)*Power(r,n+1)%Mod;
 65         f[i]-=r%Mod; if (f[i]<0) f[i]+=Mod;
 66         
 67         int tmp=0;
 68         for (int j=0;j<i;j++)
 69         {
 70             tmp+=1ll*C(i,j)*f[j]%Mod;
 71             if (tmp>=Mod) tmp-=Mod;
 72         }
 73         f[i]-=1ll*(r%Mod)*tmp%Mod;
 74         if (f[i]<0) f[i]+=Mod;
 75         f[i]=1ll*f[i]*Power(r-1,Mod-2)%Mod;
 76         //cout<<i<<" "<<f[i]<<endl;
 77     }
 78     return f[k];
 79 }
 80 
 81 
 82             
 83 int main()
 84 {
 85     //freopen("in.in","r",stdin);
 86     //freopen("out.out","w",stdout);
 87     
 88     inv[1]=1; for (int i=2;i<N;i++) inv[i]=1ll*(Mod-Mod/i)*inv[Mod%i]%Mod;
 89     fac[0]=1; for (int i=1;i<N;i++) fac[i]=1ll*fac[i-1]*i%Mod;
 90     fac_inv[0]=1; for (int i=1;i<N;i++) fac_inv[i]=1ll*fac_inv[i-1]*inv[i]%Mod;
 91     
 92     int T; scanf("%d",&T);
 93     while (T--)
 94     {
 95         cin >> n >> k >> r;
 96         if (r==1) printf("%d\n",Solve1());
 97         else printf("%d\n",Solve2());
 98     }
 99     
100     return 0;
101 }
View Code

相关文章: