FFT(快速傅里叶变换)与 NTT(快速数论变换)
多项式全家桶其之一。
前置知识:
复数(请移步高中必修二或 OI wiki)。
初中范围的多项式。
符号与约定
为了好看(以及好写),用
单位根
前置知识:复数。
众所周知,任意一元
考虑 \omega_n
),显然其它单位根就是
单位根有三个比较显然的引理:
- 消去引理:
,根据单位根几何意义显然。 - 折半引理:
,这个只在 为偶数时成立,带消去引理即可。 - 求和引理:
,这玩意显然就等于 ,是一个等比数列求和,答案即为 。
另外根据三角函数基础,显然
多项式
前置知识:初中范围的多项式。
此处只讨论一元
多项式有两种表示方法:
- 系数表示:即列举出每一项的系数
,其中 为 项的系数。 - 点值表示:根据经典结论,
个互不相同的点就能确定一个一元 次多项式,那么可以用 唯一确定一个多项式。并且这个神秘性质在定义域取到复数域,甚至是某个质数的剩余系时仍然成立,我证不来。
FFT
相关概念
- 离散傅里叶变换(Discrete Fourier Transform,DFT),将系数表示转化为某个点值表示的方法(实际定义比较复杂,涉及到一些信号学知识,此处可以这样理解)。
- 快速傅里叶变换(即 快速(离散)傅里叶变换,Fast (Discrete) Fourier Transform,FFT),复杂度较低的 DFT。
下文的
算法引入
假如我们要算多项式的积(
但是假如我们知道了
并且容易发现多项式乘法就是对系数的卷积,那么这玩意还能用来快速求卷积。
算法流程
首先我们肯定不能随便找一些
首先 FFT 的核心思想是分治,对于一个多项式
但这里有个问题,
取
其中第二行是因为
然后就能用
非递归写法
容易发现这玩意常数很大(因为每次需要把序列按奇偶性分开,可能需要开一些辅助数组,内存访问就不太连续),而这种
容易发现,如果把它看做 zkw 的结构,我们只需要求出所有
以
a[0] a[1] a[2] a[3] a[4] a[5] a[6] a[7]
a[0] a[2] a[4] a[6] | a[1] a[3] a[5] a[7]
a[0] a[4] | a[2] a[6] | a[1] a[5] | a[3] a[7]
a[0] | a[4] | a[2] | a[6] | a[1] | a[5] | a[3] | a[7]
其实已经能发现一些端倪了,第
容易发现这些东西的二进制表示翻转后就恰好是
于是就能得到非递归写法了。
代码等会和后面的整合一下。
逆 FFT(IFFT)
根据
那么就能写成矩阵
然后用神秘方法就能得到
直接说结论 因为我不会证
以及容易发现,
模板题参考代码
#include<bits/stdc++.h>
#define forup(i,s,e) for(int i=(s),E123123123=(e);i<=E123123123;++i)
#define fordown(i,s,e) for(int i=(s),E123123123=(e);i>=E123123123;--i)
#define mem(a,b) memset(a,b,sizeof(a))
#ifdef DEBUG
#define msg(args...) fprintf(stderr,args)
#else
#define msg(...) void();
#endif
using namespace std;
using i64=long long;
using ld=long double;
#define gc getchar()
int read(){
int x=0,f=1;char c;
while(!isdigit(c=gc)) if(c=='-') f=-1;
while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=gc;}
return x*f;
}
#undef gc
const int N=1<<21;
const ld pi=acosl(-1),eps=1e-7;
int n,m,p,w;
struct Complex{
ld x,y;
Complex(ld _x=0,ld _y=0):x(_x),y(_y){}
Complex operator +(const Complex &r){return Complex{x+r.x,y+r.y};}
Complex operator -(const Complex &r){return Complex{x-r.x,y-r.y};}
Complex operator *(const Complex &r){return Complex{x*r.x-y*r.y,x*r.y+y*r.x};}
};
Complex f[N],g[N];
int rev[N];
void trans(Complex *f,int type){
forup(i,0,p-1) if(i<rev[i]) swap(f[i],f[rev[i]]);
for(int len=1;len<p;len<<=1){
Complex wn(cosl(pi/len),type*sinl(pi/len));
for(int i=0;i<p;i+=(len<<1)){
Complex nw(1,0);
forup(j,0,len-1){
Complex x=f[i+j],y=nw*f[i+len+j];
f[i+j]=x+y;
f[i+len+j]=x-y;
nw=nw*wn;
}
}
}
if(type==-1) forup(i,0,p-1) f[i].x/=p;
}
signed main(){
n=read();m=read();
++n;++m;
forup(i,0,n-1) f[i].x=read();
forup(i,0,m-1) g[i].x=read();
p=1;w=0;
while(p<n+m) p<<=1,++w;
forup(i,0,p-1){
rev[i]=(rev[i>>1]>>1)|((i&1)<<(w-1));
}
trans(f,1);
trans(g,1);
forup(i,0,p-1){
f[i]=f[i]*g[i];
}
trans(f,-1);
forup(i,0,n+m-2){
if(fabs(f[i].x<eps)){
printf("0 ");
}else{
printf("%.0Lf ",f[i].x);
}
}
}
NTT
前置知识:数论基础。
事实上如果你真的写了 FFT 模板题就会发现有一堆精度问题,而且假如用来卷积这玩意不能取模。
于是就有人发明了没有精度问题(但是需要取模)的类似算法:快速数论变换(Number Theoretic Transform,NTT)。
很多 OIer 在第一次做需要取模的计数题时都会疑惑“为什么模数是
- 它是质数,剩余系对乘法是封闭的。
- 它有原根,其中一个是
。
下文默认在模
原根(简要理解)就是说一个
考虑 FFT 中为什么要用
记原根为
然后就能用这玩意做类似于 FFT 的操作了。
注意到一个问题,刚刚钦定了
哎,您猜怎么着☝️🤓,
注意到一点区别,从
注意到
参考代码
其它地方差不多,inv3
是
void trans(int *f,int type){
forup(i,0,p-1) if(i<rev[i]) swap(f[i],f[rev[i]]);
for(int len=1;len<p;len<<=1){
int wn=ksm(type==1?3:inv3,(mod-1)/(len<<1));
for(int i=0;i<p;i+=(len<<1)){
int nw=1;
forup(j,0,len-1){
int x=f[i+j],y=1ll*nw*f[i+len+j]%mod;
f[i+j]=(x+y)%mod;
f[i+len+j]=(x-y+mod)%mod;
nw=1ll*nw*wn%mod;
}
}
}
if(type==-1){
int inv=ksm(p,mod-2);
forup(i,0,p-1) f[i]=1ll*f[i]*inv%mod;
}
}
任意模数 NTT
感觉没什么用就先不写了,留坑待补。
分治 FFT/NTT(半在线卷积)
还有另一种算法也叫分治 NTT,把若干个多项式乘在一起时每次把最短的两个拉出来 NTT,那个很简单就不说了。
有些时候我们需要计算形如下式的函数:
其中
这时候就不能直接将
这时候可以类似 cdq 分治优化 DP,对序列进行分治。按分治树的中序遍历,每次统计左边对右边的贡献。
具体可以看代码。
模板题参考代码
#include<bits/stdc++.h>
#define forup(i,s,e) for(int i=(s),E123123123=(e);i<=E123123123;++i)
#define fordown(i,s,e) for(int i=(s),E123123123=(e);i>=E123123123;--i)
#define mem(a,b) memset(a,b,sizeof(a))
#ifdef DEBUG
#define msg(args...) fprintf(stderr,args)
#else
#define msg(...) void()
#endif
using namespace std;
using i64=long long;
using pii=pair<int,int>;
#define fi first
#define se second
#define mkp make_pair
#define gc getchar()
int read(){
int x=0,f=1;char c;
while(!isdigit(c=gc)) if(c=='-') f=-1;
while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=gc;}
return x*f;
}
#undef gc
const int N=1<<18|10,mod=998244353,inv3=332748118;
int ksm(int a,int b){
int c=1;
while(b){
if(b&1) c=1ll*a*c%mod;
a=1ll*a*a%mod;
b>>=1;
}
return c;
}
int n,m,g[N],f[N],h[N];
int rev[N];
void NTT(int *f,int n,int type){
int w=31^__builtin_clz(n);
forup(i,0,n-1){
rev[i]=(rev[i>>1]>>1)|((i&1)<<(w-1));
if(i<rev[i]) swap(f[i],f[rev[i]]);
}
for(int len=1;len<n;len<<=1){
int wn=ksm(type==1?3:inv3,(mod-1)/(len<<1));
for(int i=0;i<n;i+=(len<<1)){
int nw=1;
forup(j,0,len-1){
int x=f[i+j],y=1ll*nw*f[i+len+j]%mod;
f[i+j]=(x+y)%mod;
f[i+len+j]=(x+mod-y)%mod;
nw=1ll*nw*wn%mod;
}
}
}
if(type==-1){
int inv=ksm(n,mod-2);
forup(i,0,n-1) f[i]=1ll*f[i]*inv%mod;
}
}
int a[N],b[N],c[N];
void solve(int l,int r){
if(l==r) return;
int mid=(l+r)>>1;
solve(l,mid);
int len=(r-l+1)<<1;
forup(i,0,len){
a[i]=b[i]=c[i]=0;
}
forup(i,l,r){
if(i<=mid) a[i-l]=f[i];
b[i-l]=g[i-l];
}
NTT(a,len,1);NTT(b,len,1);
forup(i,0,len-1){
c[i]=1ll*a[i]*b[i]%mod;
}
NTT(c,len,-1);
forup(i,mid+1,r){
(f[i]+=c[i-l])%=mod;
}
solve(mid+1,r);
}
signed main(){
n=read();
forup(i,1,n-1){
g[i]=read();
}
m=n;
n=1;
while(n<m){n<<=1;}
f[0]=1;
solve(0,n-1);
forup(i,0,m-1){
printf("%d ",f[i]);
}puts("");
}