题目链接

题目大意:给定一段长为\(m\)的数\(S\),求有多少个长为\(n\)的数不包含子串\(S\)

\(\text{KMP}\)、计数、矩阵乘法


分析:

首先由于允许前导\(0\),一共有\(10^n\)个串。反着来,我们考虑有多少个串包含子串\(S\)

我们记\(f(n,s)\)表示长为\(n\),后缀最长能匹配\(S\)长为\(s\)的前缀的串个数

考虑\(f(n,s)\)会对哪些位置产生贡献,我们枚举第\(n+1\)个位置为\(c\)

如果\(S[s+1]=c\),那么\(f(n,s)\)的值应当被累加到\(f(n+1,s+1)\)

如果\(S[s+1]\neq c\),那么我们应当用\(\text{KMP}\)算法不断跳\(\text{fail}\),找到转移位置。为了便于转移,以及优化运行时间,用类似\(\text{AC}\)自动机补全\(\text{Trie}\)树的方法建出转移图

设补全后的转移数组为\(ch\),两者可以统一

\(f(n,s)\)会对\(f(n+1,ch[s][c])\)产生贡献,其中\(c\in[0,9]\)

先考虑计数,一个比较\(naive\)的想法是求\(\sum_n f(n,m)\),这样会有重复计数

也就是说有可能同一个串包含子串\(S\)多次

不妨规定第一次包含子串\(S\)时计数,那么已经包含子串\(S\)之后,后面的所有位置都可以任取了。对于任意\(s=m\)\(f(n,s)\),没必要将它的贡献累计到后面。

暴力算法:

求出\(f(n,s)\quad s\in[0,m]\),令\(ans=ans*10+f(n,m)\),对于所有\(f(n,s) \quad s\in[0,m)\)进行转移,计算它对于位置\(n+1\)的贡献

这样是\(O(n)\)

可以用矩阵乘法优化

假设我们有长为\(m+1\)的数组\(f\),表示\(f(n,s)\quad s\in[0,m]\),由上分析,我们可以用\(f[m]\)表示答案(从\(0\)开始),枚举\(s\in[0,m),c\in[0,9]\),把转移矩阵第\(s\)行第\(ch[s][c]\)\(+1\)

最后把转移矩阵第\(m\)行第\(m\)列置为\(10\)(第一次包含子串\(S\),后面有\(k\)位任取,答案要乘\(10^k\)),快速幂转移即可

#include <cstdio>
#include <cstring>
using namespace std;
const int maxm = 32;
int n,m,mod,ans,ch[maxm][10],fail[maxm];
inline int mul(int a,int b){return (1ll * a * b) % mod;}
inline int add(int a,int b){return (a + b) % mod;}
inline int sub(int a,int b){return (((a - b) % mod) + mod) % mod;}
inline int qpow(int a,int b){
	int res = 1,base = a;
	while(b){
		if(b & 1)res = mul(res,base);
		base = mul(base,base);
		b >>= 1;
	}
	return res;
}
struct matrix{
	int f[maxm][maxm];
	int x,y;
	void clear(){
		memset(f,0,sizeof(f));
		x = y = 0;
	}
	matrix operator * (const matrix &rhs)const{	
		matrix res;res.clear();
		res.x = x,res.y = rhs.y;
		for(int i = 0;i < x;i++)
			for(int k = 0;k < y;k++)
				for(int j = 0;j < rhs.y;j++)
					res.f[i][j] = add(res.f[i][j],mul(f[i][k],rhs.f[k][j]));
		return res;
	}
}w,org;
inline matrix qpow(matrix base,int b){
	matrix res;res.clear();
	res.x = res.y = base.x;
	for(int i = 0;i < res.x;i++)res.f[i][i] = 1;
	while(b){
		if(b & 1)res = res * base;
		base = base * base;
		b >>= 1;
	}
	return res;
}
inline int idx(char c){return c - '0';} 
char str[maxm];
int main(){
	scanf("%d %d %d",&n,&m,&mod);
	scanf("%s",str + 1);
	for(int u = 0;u < m;u++)
		ch[u][idx(str[u + 1])] = u + 1;
	for(int u = 1;u <= m;u++)
		for(int c = 0;c < 10;c++)
			if(ch[u][c])fail[ch[u][c]] = ch[fail[u]][c];
			else ch[u][c] = ch[fail[u]][c];
	w.x = w.y = m + 1;
	for(int s = 0;s < m;s++)
		for(int c = 0;c < 10;c++)
			w.f[s][ch[s][c]]++;
	w.f[m][m] = 10;
	org.x = 1,org.y = m + 1;
	org.f[0][0] = 1;
	org = org * qpow(w,n);
	ans = qpow(10,n);
	ans = sub(ans,org.f[0][m]);
	printf("%d\n",ans);
	return 0;
}

相关文章:

  • 2021-12-06
猜你喜欢
  • 2021-12-21
  • 2021-07-19
  • 2022-03-06
  • 2021-10-09
  • 2021-06-16
  • 2021-12-28
  • 2021-08-09
相关资源
相似解决方案