快速数论变换(NTT)
一.相关概念:
1.剩余系: 所有整数模正整数n得到的结果组成的集合称为n的剩余系,n的剩余系即小于n的非负整数,记为
2.简化剩余系: 在n的剩余系中与n互质的元素的集合,称为n的简化剩余系,记为
3.欧拉函数: n的简化剩余系中元素的个数,称为欧拉函数,记为
3.原根: 对于互质的两个正整数g和n,如果g模n的阶为,则称x为n的原根.换句话说,即对于,使得,但则称g为n的原根.
如何求原根?
求n的原根,似乎只能暴力枚举,然后检测是否等于1,此处p为,x为的质因数.如果存在一个p使得等于1,则g不是原根,否则g为原根.
一般原根都很小,所以暴力枚举即可.
二.NTT
在FFT中,我们选择n次单位复根作为x的值,因为它们符合消去引理、折半引理和求和引理,于是可以采用分治的方法将时间复杂度将为.但是因为使用了复数,所以在精度上可能会有些误差.而且有的题目,多项式相乘是带模的乘法,此时我们不能使用FFT了.那能不能取整数值来作为x的值呢?其实,只要该值也像主n次单位复根那样符合消去引理、折半引理和求和引理呢,则他们也是采用分治来加速的.
在数论当中,我们发现某些数的简化剩余系在乘法运算下构成的群与n次单位复根在乘法运算的群有相似的性质.如果存在原根,则的简化剩余系.
设, ,也为2的幂,且
则易知满足:
1.消去引理
根据定义即可证明.
2.折半引理:
根据消去引理可以证明.
3.求和引理:
当m>2且k不是i的倍数时,下式成立:
根据等比数列求和公式可以证明.
所以,我们可以取m的原根,像FFT一样,在的时间内完成多项式乘法.
使用NTT时有些限制,如果多项式乘法不要求取模,则我们要找足够大的质数m,并且,要保证n大于多项式的项数,m还要大于多项式的系数.
如果多项式乘法是带模乘法,则只能用NTT,不能使用FFT.此时,若m为质数,则要求,n要大于多项式的系数.若m不为质数,则需要用中国剩余定理来做.
三.模板
使用NTT计算多项式的乘积
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
using namespace std;
#define LL long long int
#define MAXN 400005
LL num1[MAXN],num2[MAXN];
int n,m,len;
const LL G=3;
const LL MOD=998244353;
LL w,wn;
LL ksm(LL x,LL k){
LL res=1;
while(k)
{
if(k&1)
res=res*x%MOD;
x=x*x%MOD;
k>>=1;
}
return res;}
void change(LL *num,int len){
for(int i=1,j=len/2;i<len-1;i++)
{
if(i<j)swap(num[i],num[j]);
int k=len/2;
while(j>=k)
{
j-=k;
k/=2;
}
if(j<k)j+=k;
}}
void ntt(LL *num,int len,int flg){
for(int i=2;i<=len;i<<=1)
{
if(flg==1)w=ksm(G,(MOD-1)/i);
else w=ksm(G,MOD-1-(MOD-1)/i);
for(int j=0;j<len;j+=i)
{
wn=1;
for(int k=0;k<i/2;k++)
{
LL u=num[j+k],t=num[j+k+i/2]*wn%MOD;
num[j+k]=(u+t)%MOD;
num[j+k+i/2]=(u-t+MOD)%MOD;
wn=wn*w%MOD;
}
}
}
if(flg==-1)
{
LL ni=ksm(len,MOD-2);
for(int i=0;i<len;i++)
num[i]=num[i]*ni%MOD;
}
}
int main()
{
while(~scanf("%d%d",&n,&m))
{
memset(num1,0,sizeof num1);
memset(num2,0,sizeof num2);
n++,m++;
for(int i=0;i<n;i++)
{
scanf("%lld",&num1[i]);
}
for(int i=0;i<m;i++)
scanf("%lld",&num2[i]);
len=1;
while(len<n+m-1) len<<=1;
change(num1,len);
change(num2,len);
ntt(num1,len,1);
ntt(num2,len,1);
for(int i=0;i<len;i++)
num1[i]=num1[i]*num2[i];
change(num1,len);
ntt(num1,len,-1);
int pos=0;
for(int i=0;i<n+m-1;i++)
printf("%lld ",num1[i]);
printf("\n");
}
return 0;
}