题目大意:给定一段长为\(m\)的数\(S\),求有多少个长为\(n\)的数不包含子串\(S\)
\(\text{KMP}\)、计数、矩阵乘法
分析:
首先由于允许前导\(0\),一共有\(10^n\)个串。反着来,我们考虑有多少个串包含子串\(S\)
我们记\(f(n,s)\)表示长为\(n\),后缀最长能匹配\(S\)长为\(s\)的前缀的串个数
考虑\(f(n,s)\)会对哪些位置产生贡献,我们枚举第\(n+1\)个位置为\(c\)
如果\(S[s+1]=c\),那么\(f(n,s)\)的值应当被累加到\(f(n+1,s+1)\)上
如果\(S[s+1]\neq c\),那么我们应当用\(\text{KMP}\)算法不断跳\(\text{fail}\),找到转移位置。为了便于转移,以及优化运行时间,用类似\(\text{AC}\)自动机补全\(\text{Trie}\)树的方法建出转移图
设补全后的转移数组为\(ch\),两者可以统一
\(f(n,s)\)会对\(f(n+1,ch[s][c])\)产生贡献,其中\(c\in[0,9]\)
先考虑计数,一个比较\(naive\)的想法是求\(\sum_n f(n,m)\),这样会有重复计数
也就是说有可能同一个串包含子串\(S\)多次
不妨规定第一次包含子串\(S\)时计数,那么已经包含子串\(S\)之后,后面的所有位置都可以任取了。对于任意\(s=m\)的\(f(n,s)\),没必要将它的贡献累计到后面。
暴力算法:
求出\(f(n,s)\quad s\in[0,m]\),令\(ans=ans*10+f(n,m)\),对于所有\(f(n,s) \quad s\in[0,m)\)进行转移,计算它对于位置\(n+1\)的贡献
这样是\(O(n)\)的
可以用矩阵乘法优化
假设我们有长为\(m+1\)的数组\(f\),表示\(f(n,s)\quad s\in[0,m]\),由上分析,我们可以用\(f[m]\)表示答案(从\(0\)开始),枚举\(s\in[0,m),c\in[0,9]\),把转移矩阵第\(s\)行第\(ch[s][c]\)列\(+1\)
最后把转移矩阵第\(m\)行第\(m\)列置为\(10\)(第一次包含子串\(S\),后面有\(k\)位任取,答案要乘\(10^k\)),快速幂转移即可
#include <cstdio>
#include <cstring>
using namespace std;
const int maxm = 32;
int n,m,mod,ans,ch[maxm][10],fail[maxm];
inline int mul(int a,int b){return (1ll * a * b) % mod;}
inline int add(int a,int b){return (a + b) % mod;}
inline int sub(int a,int b){return (((a - b) % mod) + mod) % mod;}
inline int qpow(int a,int b){
int res = 1,base = a;
while(b){
if(b & 1)res = mul(res,base);
base = mul(base,base);
b >>= 1;
}
return res;
}
struct matrix{
int f[maxm][maxm];
int x,y;
void clear(){
memset(f,0,sizeof(f));
x = y = 0;
}
matrix operator * (const matrix &rhs)const{
matrix res;res.clear();
res.x = x,res.y = rhs.y;
for(int i = 0;i < x;i++)
for(int k = 0;k < y;k++)
for(int j = 0;j < rhs.y;j++)
res.f[i][j] = add(res.f[i][j],mul(f[i][k],rhs.f[k][j]));
return res;
}
}w,org;
inline matrix qpow(matrix base,int b){
matrix res;res.clear();
res.x = res.y = base.x;
for(int i = 0;i < res.x;i++)res.f[i][i] = 1;
while(b){
if(b & 1)res = res * base;
base = base * base;
b >>= 1;
}
return res;
}
inline int idx(char c){return c - '0';}
char str[maxm];
int main(){
scanf("%d %d %d",&n,&m,&mod);
scanf("%s",str + 1);
for(int u = 0;u < m;u++)
ch[u][idx(str[u + 1])] = u + 1;
for(int u = 1;u <= m;u++)
for(int c = 0;c < 10;c++)
if(ch[u][c])fail[ch[u][c]] = ch[fail[u]][c];
else ch[u][c] = ch[fail[u]][c];
w.x = w.y = m + 1;
for(int s = 0;s < m;s++)
for(int c = 0;c < 10;c++)
w.f[s][ch[s][c]]++;
w.f[m][m] = 10;
org.x = 1,org.y = m + 1;
org.f[0][0] = 1;
org = org * qpow(w,n);
ans = qpow(10,n);
ans = sub(ans,org.f[0][m]);
printf("%d\n",ans);
return 0;
}