你说我怎么这么菜呢?
T1:
考场上没看到下取整,同时把那个sigma看成了次数,以为要在mod phi(p)下算,这玩意能做?
然后果断爆零......
其实puts(2的逆元)+puts("1")有25分......
正解看官方题解吧。
那个鬼畜的容斥大概就是在n+1个断点中钦定k个断点,数字相对于剩下的m+1-k个集合的位置无关,所以对于这m+1-k个集合可以随意放置。
代码:
1 #include<cstdio> 2 #include<algorithm> 3 typedef long long int lli; 4 const int maxn=1e5+1e2,maxl=262145; 5 const int mod=998244353,g=3; 6 7 lli fac[maxn],inv[maxn]; 8 lli lft[maxl],rit[maxl],tar[maxl]; 9 lli ans; 10 int n,m; 11 12 inline lli fastpow(lli base,int tim) { 13 lli ret = 1; 14 while( tim ) { 15 if( tim & 1 ) ret = ret * base % mod; 16 if( tim >>= 1 ) base = base * base % mod; 17 } 18 return ret; 19 } 20 inline void NTT(lli* dst,int n,int ope=1) { 21 for(int i=0,j=0;i<n;i++) { 22 if( i < j ) std::swap(dst[i],dst[j]); 23 for(int t=n>>1;(j^=t)<t;t>>=1); 24 } 25 for(int len=2;len<=n;len<<=1) { 26 const int h = len >> 1; 27 lli per = fastpow(g,mod/len); 28 if( !~ope ) per = fastpow(per,mod-2); 29 for(int st=0;st<n;st+=len) { 30 lli w = 1; 31 for(int pos=0;pos<h;pos++) { 32 const lli u = dst[st+pos] , v = dst[st+pos+h] * w % mod; 33 dst[st+pos] = ( u + v ) % mod , 34 dst[st+pos+h] = ( u - v + mod ) % mod , 35 w = w * per % mod; 36 } 37 } 38 } 39 if( !~ope ) { 40 const lli inv = fastpow(n,mod-2); 41 for(int i=0;i<n;i++) dst[i] = dst[i] * inv % mod; 42 } 43 } 44 45 inline lli c(int n,int m) { 46 return fac[n] * inv[m] % mod * inv[n-m] % mod; 47 } 48 inline void init() { 49 for(int i=0;i<=n+1;i++) { 50 lft[i] = ( i & 1 ) ? mod - c(n+1,i) : c(n+1,i); 51 rit[i] = fastpow(i,n); 52 } 53 } 54 inline void getans() { 55 init(); int len = 1; 56 while( len <= ( n << 1 ) + 1 ) len <<= 1; 57 NTT(lft,len) , NTT(rit,len); 58 for(int i=0;i<len;i++) tar[i] = lft[i] * rit[i] % mod; 59 NTT(tar,len,-1); 60 for(int i=0;i<=n;i++) ans = ( ans + tar[i+1] * fastpow(i,m) % mod ) % mod; 61 ans = ans * inv[n] % mod; 62 } 63 inline void pre() { 64 int lim = n + 1; 65 *fac = 1; 66 for(int i=1;i<=lim;i++) fac[i] = fac[i-1] * i % mod; 67 inv[lim] = fastpow(fac[lim],mod-2); 68 for(int i=lim;i;i--) inv[i-1] = inv[i] * i % mod; 69 } 70 71 int main() { 72 scanf("%d%d",&n,&m); 73 pre() , getans(); 74 printf("%lld\n",ans); 75 return 0; 76 }