线性递推的题目区域赛里还是挺多的,还是有必要学一下
~ BM(Berlekamp-Massey)算法 ~
有一个$n$阶线性递推$f$,想要计算$f(m)$,有一种常用的办法是矩阵快速幂,复杂度是$O(n^3logm)$
在不少情况下这已经够用了,但是如果$n$比较大、到了$10^3$级别,这就不太适用了
而BM算法能将这个复杂度压低到$O(n^2logm)$,若加上NTT优化的话能做到$O(n^2+nlognlogm)$,十分厉害
这个算法的核心是将$f(m)$用递推的前$n$项表示
即,已知$f(0),...,f(n-1)$和递推式$f(m)=a_0f(m-1)+...+a_{n-1}f(m-n)$,该算法是求出系数$W_0,...,W_{n-1}$,使得$f(m)=W_0f(n-1)+...+W_{n-1}f(0)$
看似无从下手?实际上只要大力展开就行了
根据定义,有(只是写成$\sum$的形式而已)
\[f(m)=\sum_{i=0}^{n-1}a_i f(m-1-i)\]
而对于每一项再次展开,即
\[f(m-1-i)=\sum_{j=0}^{n-1}a_j f(m-1-i-1-j)\]
全部代入,能得到
\[f(m)=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_ia_j f(m-2-i-j)\]
把式子写的更好看一点,就是
\[f(m)=\sum_{k=0}^{2n-2}\sum_{i+j=k}a_ia_j f(m-2-k)\]
这样做之后有什么用呢?
在原本的递推式中,$f(m)$可以通过$f(m-1),...,f(m-n)$这$n$个项表示
各项展开后,就可以通过$f(m-2),...,f(m-2n)$表示
事实上,我们可以再依次对$f(m-i),2\leq i\leq n$展开,并将系数向$f(m-i-1),...,f(m-i-n)$并入,最终就能把原递推式通过$f(m-n-1),...,f(m-2n)$这$n$项表示
于是可以得到一个新的$n$阶递推式,记为$f(m)=b_0f(m-n+1),...,b_{n-1}f(m-2n)$
再用新递推式将各项展开,就可以通过$f(m-2n-2),...,f(m-4n)$表示
再用原递推式展开$f(m-2n-i),2\leq i\leq n$并向前合并系数,最终就能把原递推式通过$f(m-3n+1),...,f(m-4n)$这$n$项表示
之后都是类似的了,不再赘述
有了上面的思路,就可以用类似快速幂的方法,得到$f(m)=W_0f(m-(k-1)n+1),...,W_{n-1}f(m-kn)$这样的展开式,其中$m-kn<n$
余数$m-kn$是我们不喜欢的,但也没有必要整体再向前推,一开始计算时算出$f(0),...,f(2n-1)$就够了
按照上述思路能这样实现:
#include <cstdio> #include <cstring> using namespace std; typedef long long ll; const int MOD=1000000007; const int N=1005; int n,m; int a[N]; int f[N<<1]; int tmp[N<<1]; void mul(int *y,int *x) { memset(tmp,0,sizeof(tmp)); for(int i=0;i<n;i++) for(int j=0;j<n;j++) tmp[i+j]=(tmp[i+j]+ll(y[i])*x[j])%MOD; for(int i=0;i<n-1;i++) for(int j=0;j<n;j++) tmp[i+j+1]=(tmp[i+j+1]+ll(tmp[i])*a[j])%MOD; for(int i=0;i<n;i++) y[i]=tmp[i+n-1]; } int w[N<<1],x[N<<1]; int BM() { if(m<(n<<1)) return f[m]; for(int i=0;i<n;i++) x[i]=a[i],w[i]=a[i]; int t=(m-n)/n; int rem=m-n-t*n; while(t) { if(t&1) mul(w,x); mul(x,x); t>>=1; } int res=0; for(int i=0;i<n;i++) res=(res+ll(w[i])*f[rem+n-i-1])%MOD; return res; } int main() { scanf("%d%d",&n,&m); for(int i=0;i<n;i++) scanf("%d",&a[i]); for(int i=0;i<n;i++) scanf("%d",&f[i]); for(int i=n;i<(n<<1);i++) for(int j=1;j<=n;j++) f[i]=(f[i]+ll(a[j-1])*f[n-j])%MOD; printf("%d\n",BM()); return 0; }