Skip to content

FFT(快速傅里叶变换)与 NTT(快速数论变换)

多项式全家桶其之一。

参考资料1 参考资料2

前置知识:

复数(请移步高中必修二或 OI wiki)。

初中范围的多项式。

符号与约定

为了好看(以及好写),用 a0,a1,a2,an 表示一个有序 n 元组,也用这个表示数列(某种意义上,数列也是有序 n 元组吧)。

ab 表示存在整数 k 使得 ak=b

单位根

前置知识:复数。

众所周知,任意一元 n 次方程必有 n 个根。只是初中范围(以及高中“复数”这一章以外)内只研究实根,而在遇到 a 之类的东西都会直接省去。实际上这是能产生复数根的。

考虑 xn=1 这个方程。由于复数的运算法则(模长相乘,辐角相加),显然这 n 个解恰好把复平面单位圆均分成 n 份(即辐角分别为 2kπn,k[0,n1])。我们称 xn=1 的根为 n 次单位根,记除去 1 之外辐角最小的 n 次单位根为 ωn\omega_n),显然其它单位根就是 ωni

单位根有三个比较显然的引理:

  1. 消去引理:ωknkm=ωnm,根据单位根几何意义显然。
  2. 折半引理:(ωnk+n2)2=(ωnk)2=ωn2k=ωn2k,这个只在 n 为偶数时成立,带消去引理即可。
  3. 求和引理:i=0n1(ωnk)i=0,这玩意显然就等于 i=0n1ωni,是一个等比数列求和,答案即为 ωnn1ωn11=11ωn11=0

另外根据三角函数基础,显然 ωn=cos(2πn)+sin(2πn)i

多项式

前置知识:初中范围的多项式。

此处只讨论一元 n 次多项式。即 F(x)=i=0naixi

多项式有两种表示方法:

  1. 系数表示:即列举出每一项的系数 a0,a1,an,其中 aixi 项的系数。
  2. 点值表示:根据经典结论,n+1 个互不相同的点就能确定一个一元 n 次多项式,那么可以用 (x0,F(x0)),(x1,F(x1)),(x2,F(x2)),(xn+1,F(xn+1)) 唯一确定一个多项式。并且这个神秘性质在定义域取到复数域,甚至是某个质数的剩余系时仍然成立,我证不来。

FFT

相关概念

  • 离散傅里叶变换(Discrete Fourier Transform,DFT),将系数表示转化为某个点值表示的方法(实际定义比较复杂,涉及到一些信号学知识,此处可以这样理解)。
  • 快速傅里叶变换(即 快速(离散)傅里叶变换,Fast (Discrete) Fourier Transform,FFT),复杂度较低的 DFT。

下文的 n 均指多项式次数 +1(因为 0 次项也要占一个位置,不这样搞会导致有一堆奇怪的 +1),称这个数为多项式“项数”。

算法引入

假如我们要算多项式的积(H(x)=G(x)×F(x)),直接用系数表示做需要 O(n2) 做。

但是假如我们知道了 GF 的点值表示,根据 H(x)=G(x)×F(x),就能 O(n) 求出 H 的点值表示(容易发现 G,F,H 的项数不同,其中 Fn 项,Gm 项,Hn+m1 项。一个办法是把三者都看成 n+m 项式)。如果我们还有能从点值表示推出系数表示的方法,就能得到 H(x) 的系数表示了。

并且容易发现多项式乘法就是对系数的卷积,那么这玩意还能用来快速求卷积。

算法流程

首先我们肯定不能随便找一些 xi,这样显然不好优化。FFT 干的事情是先找到最小的 k 满足 2kn,把 n 项多项式视为一个 2k 项多项式,然后用 ω2ki 作为 xi 来表示点值。下文会叙述为什么要这样搞。

首先 FFT 的核心思想是分治,对于一个多项式 F 的系数表示 A=a0,a1,a2,an1n=2k),可以把它按奇偶性分成 A0=a0,a2,a4,an2A1=a1,a3,a5,an1,设这两个系数表示对应的多项式分别为 F0,F1,容易发现 F(x)=F0(x2)+xF1(x2)(显然)。于是就能想到用 F0,F1 的点值表示来算出 F 的点值表示。

但这里有个问题,F0,F1 的点值表示都只有 n2 个点,这样似乎无论怎么搞都只能得到 Fn2 个点。这时候你就知道为什么要用单位根作为 xi 了。

k 为某个 [0,n2) 之中的整数,记 k=k+n2,容易发现:

F(ωnk)=F0((ωnk)2)+ωnkF1((ωnk)2)=F0(ωn2k)+ωnkF1(ωn2k)F(ωnk)=F0((ωnk)2)+ωnkF1((ωnk)2)=F0(ωn2k)ωnkF1(ωn2k)

其中第二行是因为 ωnk+n2=ωnk

然后就能用 F0,F1 的点值表示推出 F 的点值表示了。对于 n=1 的边界条件,显然 F(ω10)=F(1)=a0x=a0,于是就有一个简单的递归写法了。需要封装复数的加减乘法。

非递归写法

容易发现这玩意常数很大(因为每次需要把序列按奇偶性分开,可能需要开一些辅助数组,内存访问就不太连续),而这种 2k 每次都能恰好地把序列分成两半,这和 zkw 线段树十分类似,可以考虑是否存在非递归写法。

容易发现,如果把它看做 zkw 的结构,我们只需要求出所有 n=1 在递归树上对应的下标,其它就能直接倒序循环求出。

n=8 为例,考虑递归过程:

a[0]   a[1]   a[2]   a[3]   a[4]   a[5]   a[6]   a[7] 
a[0]   a[2]   a[4]   a[6] | a[1]   a[3]   a[5]   a[7] 
a[0]   a[4] | a[2]   a[6] | a[1]   a[5] | a[3]   a[7] 
a[0] | a[4] | a[2] | a[6] | a[1] | a[5] | a[3] | a[7] 

其实已经能发现一些端倪了,第 i 层是按照二进制从下到上第 i 为分开的,那么从二进制方向考虑每个点的下标:

0   4   2   6   1   5   3   7
000 100 010 110 001 101 011 111

容易发现这些东西的二进制表示翻转后就恰好是 07,从每次“以当前最低位为依据分开”也能推出来。于是直接以这样的“逆二进制序”处理即可。

于是就能得到非递归写法了。

代码等会和后面的整合一下。

逆 FFT(IFFT)

根据 F(ωni)=j=0n1ajωnij,容易发现这玩意是个线性变换。

那么就能写成矩阵 y=a×Vn,其中 (Vn)i,j=ωnij

然后用神秘方法就能得到 Vn1,使得 a=y×Vn1

直接说结论 (Vn1)i,j=ωnijn因为我不会证

以及容易发现,ωn1=cos(2πn)+sin(2πn)i(考虑 ωn1×ωn=1,模长相乘辐角相加)。然后其余部分和 FFT 完全相同,最后除以 n(注意这个 “n” 是前文所说 2k)即可。

模板题参考代码
#include<bits/stdc++.h>
#define forup(i,s,e) for(int i=(s),E123123123=(e);i<=E123123123;++i)
#define fordown(i,s,e) for(int i=(s),E123123123=(e);i>=E123123123;--i)
#define mem(a,b) memset(a,b,sizeof(a))
#ifdef DEBUG
#define msg(args...) fprintf(stderr,args)
#else
#define msg(...) void();
#endif
using namespace std;
using i64=long long;
using ld=long double;
#define gc getchar()
int read(){
    int x=0,f=1;char c;
    while(!isdigit(c=gc)) if(c=='-') f=-1;
    while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=gc;}
    return x*f;
}
#undef gc
const int N=1<<21;
const ld pi=acosl(-1),eps=1e-7;
int n,m,p,w;
struct Complex{
    ld x,y;
    Complex(ld _x=0,ld _y=0):x(_x),y(_y){}
    Complex operator +(const Complex &r){return Complex{x+r.x,y+r.y};}
    Complex operator -(const Complex &r){return Complex{x-r.x,y-r.y};}
    Complex operator *(const Complex &r){return Complex{x*r.x-y*r.y,x*r.y+y*r.x};}
};
Complex f[N],g[N];
int rev[N];
void trans(Complex *f,int type){
    forup(i,0,p-1) if(i<rev[i]) swap(f[i],f[rev[i]]);
    for(int len=1;len<p;len<<=1){
        Complex wn(cosl(pi/len),type*sinl(pi/len));
        for(int i=0;i<p;i+=(len<<1)){
            Complex nw(1,0);
            forup(j,0,len-1){
                Complex x=f[i+j],y=nw*f[i+len+j];
                f[i+j]=x+y;
                f[i+len+j]=x-y;
                nw=nw*wn;
            }
        }
    }
    if(type==-1) forup(i,0,p-1) f[i].x/=p;
}
signed main(){
    n=read();m=read();
    ++n;++m;
    forup(i,0,n-1) f[i].x=read();
    forup(i,0,m-1) g[i].x=read();
    p=1;w=0;
    while(p<n+m) p<<=1,++w;
    forup(i,0,p-1){
        rev[i]=(rev[i>>1]>>1)|((i&1)<<(w-1));
    }
    trans(f,1);
    trans(g,1);
    forup(i,0,p-1){
        f[i]=f[i]*g[i];
    }
    trans(f,-1);
    forup(i,0,n+m-2){
        if(fabs(f[i].x<eps)){
            printf("0 ");
        }else{
            printf("%.0Lf ",f[i].x);
        }
    }
}

NTT

前置知识:数论基础。

事实上如果你真的写了 FFT 模板题就会发现有一堆精度问题,而且假如用来卷积这玩意不能取模。

于是就有人发明了没有精度问题(但是需要取模)的类似算法:快速数论变换(Number Theoretic Transform,NTT)。

很多 OIer 在第一次做需要取模的计数题时都会疑惑“为什么模数是 998244353”,就是因为有这个算法的存在。

998244353 有非常好的性质:

  • 它是质数,剩余系对乘法是封闭的。
  • φ(998244353)=998244352=223×7×17
  • 它有原根,其中一个是 3

下文默认在模 998244353 意义下进行。

原根(简要理解)就是说一个 p 的简化剩余系内的数 a,若满足 0n<φ(p),所有 anmodp 互不相同(或者 0<n<φ(p),an1(modp),这两个定义显然等价),就称 ap 的原根。

考虑 FFT 中为什么要用 ωni 作为点值表示的下标(下文记 xi 表示点值表示的第 i 个下标)。就是因为它具有消去引理(ωknkm=ωnm)和折半引理((ωnk+n2)2=(ωnk)2=ωn2k=ωn2k),那么原根有没有类似的性质呢?

记原根为 g(当 p=998244353 时,g=3),又记 gn=gp1n(钦定 np1)。于是有:

  • gknkm=gkm(p1)kn=gm(p1)n=gnm
  • (gnk+n2)2=gn2k+n=g(2k+n)(p1)n=g2k(p1)n+gp1=g2k(p1)n=gn2k=gn2k

然后就能用这玩意做类似于 FFT 的操作了。

注意到一个问题,刚刚钦定了 np1

哎,您猜怎么着☝️🤓,998244352223 的因子(8.3×106),可以把 n 调整为 2 的幂次就能做了。

注意到一点区别,从 F0,F1 推到 F 的式子(k[0,n2),k=k+n2):

F(gnk)=F0((gnk)2)+gnkF1((gnk)2)=F0(gn2k)+gnkF1(gn2k)F(gnk)=F0((gnk)2)+gnkF1((gnk)2)=F0(gn2k)+gp12gnkF1(gn2k)

注意到 gp12=gp1=1,根据二次探测引理,gp12=±1。但是若这玩意等于 1 那么 g 就不是原根了(和定义不符),所以这东西必定等于 1(即 p1)。

参考代码

其它地方差不多,inv33 的逆元。

void trans(int *f,int type){
    forup(i,0,p-1) if(i<rev[i]) swap(f[i],f[rev[i]]);
    for(int len=1;len<p;len<<=1){
        int wn=ksm(type==1?3:inv3,(mod-1)/(len<<1));
        for(int i=0;i<p;i+=(len<<1)){
            int nw=1;
            forup(j,0,len-1){
                int x=f[i+j],y=1ll*nw*f[i+len+j]%mod;
                f[i+j]=(x+y)%mod;
                f[i+len+j]=(x-y+mod)%mod;
                nw=1ll*nw*wn%mod;
            }
        }
    }
    if(type==-1){
        int inv=ksm(p,mod-2);
        forup(i,0,p-1) f[i]=1ll*f[i]*inv%mod;
    }
}

任意模数 NTT

感觉没什么用就先不写了,留坑待补。

分治 FFT/NTT(半在线卷积)

还有另一种算法也叫分治 NTT,把若干个多项式乘在一起时每次把最短的两个拉出来 NTT,那个很简单就不说了。

有些时候我们需要计算形如下式的函数:

f(n)={c,n=0j=1if(ij)g(j),n1

其中 c 是常数。

这时候就不能直接将 f,g 进行卷积了,因为 f 还没有求出来。

这时候可以类似 cdq 分治优化 DP,对序列进行分治。按分治树的中序遍历,每次统计左边对右边的贡献。

具体可以看代码。

模板题参考代码
#include<bits/stdc++.h>
#define forup(i,s,e) for(int i=(s),E123123123=(e);i<=E123123123;++i)
#define fordown(i,s,e) for(int i=(s),E123123123=(e);i>=E123123123;--i)
#define mem(a,b) memset(a,b,sizeof(a))
#ifdef DEBUG
#define msg(args...) fprintf(stderr,args)
#else
#define msg(...) void()
#endif
using namespace std;
using i64=long long;
using pii=pair<int,int>;
#define fi first
#define se second
#define mkp make_pair
#define gc getchar()
int read(){
    int x=0,f=1;char c;
    while(!isdigit(c=gc)) if(c=='-') f=-1;
    while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=gc;}
    return x*f;
}
#undef gc
const int N=1<<18|10,mod=998244353,inv3=332748118;
int ksm(int a,int b){
    int c=1;
    while(b){
        if(b&1) c=1ll*a*c%mod;
        a=1ll*a*a%mod;
        b>>=1;
    }
    return c;
}
int n,m,g[N],f[N],h[N];
int rev[N];
void NTT(int *f,int n,int type){
    int w=31^__builtin_clz(n);
    forup(i,0,n-1){
        rev[i]=(rev[i>>1]>>1)|((i&1)<<(w-1));
        if(i<rev[i]) swap(f[i],f[rev[i]]);
    }
    for(int len=1;len<n;len<<=1){
        int wn=ksm(type==1?3:inv3,(mod-1)/(len<<1));
        for(int i=0;i<n;i+=(len<<1)){
            int nw=1;
            forup(j,0,len-1){
                int x=f[i+j],y=1ll*nw*f[i+len+j]%mod;
                f[i+j]=(x+y)%mod;
                f[i+len+j]=(x+mod-y)%mod;
                nw=1ll*nw*wn%mod;
            }
        }
    }
    if(type==-1){
        int inv=ksm(n,mod-2);
        forup(i,0,n-1) f[i]=1ll*f[i]*inv%mod;
    }
}
int a[N],b[N],c[N];
void solve(int l,int r){
    if(l==r) return;
    int mid=(l+r)>>1;
    solve(l,mid);
    int len=(r-l+1)<<1;
    forup(i,0,len){
        a[i]=b[i]=c[i]=0;
    }
    forup(i,l,r){
        if(i<=mid) a[i-l]=f[i];
        b[i-l]=g[i-l];
    }
    NTT(a,len,1);NTT(b,len,1);
    forup(i,0,len-1){
        c[i]=1ll*a[i]*b[i]%mod;
    }
    NTT(c,len,-1);
    forup(i,mid+1,r){
        (f[i]+=c[i-l])%=mod;
    }
    solve(mid+1,r);
}
signed main(){
    n=read();
    forup(i,1,n-1){
        g[i]=read();
    }
    m=n;
    n=1;
    while(n<m){n<<=1;}
    f[0]=1;
    solve(0,n-1);
    forup(i,0,m-1){
        printf("%d ",f[i]);
    }puts("");
}

Comments