题目:https://loj.ac/problem/3055
先写了暴力。本来想的是 n<=300 的那个在树上暴力维护好整个字符串, x=1 的那个用主席树维护好字符串和 nxt 数组。但 x=1 的部分会 TLE ,而且似乎不太对的样子。
#include<cstdio> #include<cstring> #include<algorithm> #include<vector> #define ll long long #define pb push_back #define ls Ls[cr] #define rs Rs[cr] using namespace std; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return fx?ret:-ret; } int n; namespace S1{ const int N=305,M=N*N; int fa[N],q[N],s[M],nxt[M]; ll sm[N]; vector<int> c[N],nt[N]; void add(int x,int y,int m=0,int ch=0) { sm[y]=sm[x]; fa[y]=x; if(!m)return; int top=0, cr=x; while(cr)q[++top]=cr,cr=fa[cr]; int tot=0; for(int i=top;i;i--) { cr=q[i]; for(int j=0,lm=c[cr].size();j<lm;j++) s[++tot]=c[cr][j], nxt[tot]=nt[cr][j]; } c[y].resize(m); nt[y].resize(m); int i,j; if(!tot){ s[1]=c[y][0]=ch;i=2;j=2;} else { i=tot+1;j=1;} for(;j<=m;j++,i++) { s[i]=ch; cr=nxt[i-1]; while(cr&&s[cr+1]!=ch)cr=nxt[cr]; if(s[cr+1]==ch)nxt[i]=cr+1; else nxt[i]=0; c[y][j-1]=ch; nt[y][j-1]=nxt[i]; sm[y]+=nxt[i]; } } void solve() { int op,x; char ch[5]; for(int i=1;i<=n;i++) { op=rdn();x=rdn(); if(op==1) { scanf("%s",ch); add(i-1,i,x,ch[0]-'a'+1);} else add(x,i); printf("%lld\n",sm[i]); } } } namespace S2{ const int N=1e5+5,M=2e6+5; int rt[N],tot,Ls[M],Rs[M],cd[N]; ll sm[N]; struct Node{ int c,nxt;}a[M]; int ins(int l,int r,int &cr,int pr,int p,int ch) { if(!cr){cr=++tot;ls=Ls[pr];rs=Rs[pr];} if(l==r){a[cr].c=ch;return cr;} int mid=l+r>>1; if(p<=mid)return ins(l,mid,ls,Ls[pr],p,ch); return ins(mid+1,r,rs,Rs[pr],p,ch); } Node qry(int l,int r,int cr,int p) { if(l==r)return a[cr]; int mid=l+r>>1; if(p<=mid)return qry(l,mid,ls,p); return qry(mid+1,r,rs,p); } void add(int cr,int pr,int m,int ch) { sm[cr]=sm[pr]; cd[cr]=cd[pr]; for(int i=1,d;i<=m;i++) { cd[cr]++; d=ins(1,n,rt[cr],rt[pr],cd[cr],ch); int p=qry(1,n,rt[cr],cd[cr]-1).nxt; while(p&&qry(1,n,rt[cr],p+1).c!=ch) p=qry(1,n,rt[cr],p).nxt; if(p+1!=cd[cr]&&qry(1,n,rt[cr],p+1).c==ch)//!= a[d].nxt=p+1; else a[d].nxt=0; sm[cr]+=a[d].nxt; } } void solve() { int op,x; char ch[5]; for(int i=1;i<=n;i++) { op=rdn();x=rdn(); if(op==1) { scanf("%s",ch); add(i,i-1,x,ch[0]-'a'+1);} else {sm[i]=sm[x];rt[i]=rt[x];cd[i]=cd[x];} printf("%lld\n",sm[i]); } } } int main() { n=rdn(); if(n<=300){S1::solve();return 0;} if(n<=1e5){S2::solve();return 0;} return 0; }