一、简介

前置知识:多项式乘法与 FFT

FFT 涉及大量 double 类型数据操作和 \(\sin,\cos\) 运算,会产生误差。快速数论变换(Number Theoretic Transform,简称 NTT)在 FFT 的基础上,优化了常数及误差。

NTT 其实就是把 FFT 中的单位根换成了原根。

NTT 解决的是多项式乘法带模数的情况,可以说有些受模数的限制,多项式系数应为整数。

二、原根 与 NTT

「算法笔记」基础数论 2 中提及了原根的部分内容。

对于质数 \(p\),若 \(g\)\(p\) 的原根,则 \(g^i\bmod p\,(0\leq i<p)\) 互不相同。

考虑可以表示为 \(p=a\cdot 2^k+1\) 的质数 \(p\)。NTT 的模数一般选取这样符合要求的 \(p\)。比较常见的 \(p\)\(998244353=119\cdot 2^{23}+1\)\(1004535809=479\cdot 2^{21}+1\),它们的原根都是 \(3\)

NTT 与 FFT 几乎一样,只不过 FFT 中代入的是 \(\omega_n^k\),而 NTT 中代入的是 \({(g^{\frac{p-1}{n}})}^k\)

\({(g^{\frac{p-1}{n}})}^k\) 满足 FFT 中所用到的 \(\omega_n^k\) 拥有的性质。

结论:\(\omega_n^k\equiv {(g^{\frac{p-1}{n}})}^k\pmod p\),可以把 \({(g^{\frac{p-1}{n}})}^k\) 看成是 \(\omega_n^k\) 的等价。证明略。

由于 \(p\) 可以表示为 \(p=a\cdot 2^k+1\) 的形式,并且多项式项数 \(n\) 已被我们补为 \(2\) 的幂次,所以 \(\frac{p-1}{n}\) 一定为整数(注意 \(n\leq 2^k\),不然会出问题)。

代码只需在 FFT 的基础上稍作修改即可。复杂度同样为 \(\mathcal{O}(n\log n)\)

//Luogu P3803 
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=3e6+5,mod=998244353;
int n,m,a[N],b[N],len,r[N],inv;
int mul(int x,int n,int mod){
    int ans=mod!=1;
    for(x%=mod;n;n>>=1,x=x*x%mod)
        if(n&1) ans=ans*x%mod;
    return ans;
}
void NTT(int a[N],int n,int opt){    //opt=1/-1: DFT/IDFT
    for(int i=0;i<n;i++)
        if(i<r[i]) swap(a[i],a[r[i]]);
    for(int k=2;k<=n;k<<=1){
        int m=k>>1,x=mul(3,(mod-1)/k,mod),w=1,v; 
        if(opt==-1) x=mul(x,mod-2,mod);
        for(int i=0;i<n;i+=k,w=1)
            for(int j=i;j<i+m;j++) v=w*a[j+m]%mod,a[j+m]=(a[j]-v+mod)%mod,a[j]=(a[j]+v)%mod,w=w*x%mod;
    }
    if(opt==-1){
        inv=mul(len,mod-2,mod);
        for(int i=0;i<n;i++) a[i]=a[i]*inv%mod;
    }
} 
signed main(){
    scanf("%lld%lld",&n,&m);
    for(int i=0;i<=n;i++)
        scanf("%lld",&a[i]);
    for(int i=0;i<=m;i++)
        scanf("%lld",&b[i]);
    n=n+m+1;
    for(len=1;len<n;len<<=1); 
    for(int i=0;i<len;i++)
        r[i]=(r[i>>1]>>1)|((i&1)?len>>1:0);
    NTT(a,len,1),NTT(b,len,1);
    for(int i=0;i<len;i++) a[i]=a[i]*b[i]%mod;
    NTT(a,len,-1);
    for(int i=0;i<n;i++)
        printf("%lld%c",a[i],i==n-1?'\n':' ');
    return 0;
}

Update:改了改后的板子→link

相关文章:

  • 2022-12-23
  • 2022-12-23
  • 2022-12-23
  • 2021-11-30
  • 2021-09-29
  • 2021-10-30
猜你喜欢
  • 2022-12-23
  • 2022-12-23
  • 2021-05-04
  • 2022-01-22
  • 2022-02-26
  • 2022-02-20
相关资源
相似解决方案