呜呼呼呼呼呼呼呼
FFT
只说FFT在优化多项式乘法的应用。
首先我们有两个系数表达的多项式f,g
f(x)=a1xk+a2xk−1....akx+ak+1
g(x)=a1xp+a2xp−1....apx+ap+1
现在我们要求得F=f∗g
显然我们可以爆算,时间复杂度O(n2)
然而对n较大的时候,上面显然超时,所以我们就得学习一个更优秀的解决问题的办法:FFT
FFT的具体思路:先把系数表达转换成点值表达,然后可以O(n)地把点值表示相乘,最后通过IDFT转化回系数表达
前置芝士
下面这一部分其实可以跳过,不会讲FFT具体操作
点值表示
就是用n+1个点来表示一个多项式
至于为啥是n+1,我们有n+1个未知数,为了得到唯一解必须要这么多。
虚数
形如a+bi的数,其中a,b是实数,且b=0,i2=−1。
在FFT中我们需要这三个运算
假设x,y是虚数其中x=a+bi,y=c+di
x+y=(a+c)+(b+d)i
x−y=(a−c)+(b−d)i
x∗y=(ac−bd)+(ad+bc)i
欧拉公式
eix=cosx+isinx
单位负根
就是xn=1的解
就是在坐标系上把r=1的圆均分成n份,比如这是当n=4时候的情况

他有啥性质
wnn=1
wnk=w2n2k
w2nk+n=−w2nk
具体操作
系数表示转化成点值表示
这也叫作DFT
现在我们又一个多项式
f(x)=a7x7+a6x6+a5x5+a4x4+a3x3+a2x2+a1x+a0
如果我们硬求n+1个点出来,复杂度是O(n2)的
这时候就有一个很巧妙的东西了,我们把奇数项和偶数项单独提出来
f(x)=a6x6+a4x4+a2x2+a0+a7x7+a5x5+a3x3+a1x
f(x)=a6x6+a4x4+a2x2+a0+x(a7x6+a5x4+a3x2+a1)
然后令g,h
g(x)=a6x3+a4x2+a2x1+a0
h(x)=a7x3+a5x2+a3x1+a1
那么有f(x)=g(x2)+xh(x2)
然而这还不够,根据单位复根
w2nk=−w2nk+n
那么有f(w2nk+n)=g((w2nk)2)−w2nkh((w2nk)2
因此我们每次可以只算前一半的值,然后可以根据这个公式得到后一半的值
通过分治,可以时间复杂度变为O(nlogn)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
| struct comp{ double a,b; comp(double aa=0,double bb=0){a=aa; b=bb;} }a[maxn],b[maxn];
comp operator +(const comp a,const comp b){return comp(a.a+b.a,a.b+b.b);} comp operator -(const comp a,const comp b){return comp(a.a-b.a,a.b-b.b);} comp operator *(const comp a,const comp b){return comp(a.a*b.a-a.b*b.b,a.a*b.b+a.b*b.a);}
void fft(int len,comp *a) { if(len==1) return ; comp a1[(len>>1)+1],a2[(len>>1)+1]; int i; for(i=0;i<=len;i+=2) a1[i>>1]=a[i],a2[i>>1]=a[i+1]; fft(len/2,a1); fft(len/2,a2); comp w=comp(1,0); comp k=comp(cos(pi*2.0/len),sin(pi*2.0/len)); for(i=0;i<len/2;i++,w=w*k) a[i]=a1[i]+w*a2[i],a[i+len/2]=a1[i]-w*a2[i]; }
|
点值表示转化成系数表示
这也叫做IDFT
现在我们算出来很多y的值,如果单纯来解方程的话,时间复杂度n3原地爆炸
这里先说公式
离散傅里叶变换
bk=i=0∑n−1ai∗wnki
离散傅里叶逆变换
ak=n1i=0∑n−1bi∗wn−ki
证明:
首先我们单纯地把bi带入,同时约去n1
\begin{eqnarray}
\sum_{i=0}^{n-1}\limits b_iw_n^{-ki} &=& \sum_{i=0}^{n-1}\limits w_{n}^{-ki}
\sum_{j=0}^{n-1}\limits w_{n}^{ij}a_j\\
& =&
\sum^{n-1}_{j=0}a_j\sum^{n-1}_{i=0}w_n^{-ki}*w_n^{ij}
\\
& =&
\sum^{n-1}_{j=0}a_j\sum^{n-1}_{i=0} w_n^{i(j-k)}
\end{eqnarray}
考虑i=0∑n−1wni(j−k)
如果j=k 那么显然的原式=n
否则这就是个公比为wn(j−k)的等比数列,考虑求和
那么就有s=a1∗1−q1−qn显然a1=1
说明
(wnj−k)n=(wnn)j−k=1=qn
那么s=0
原式就变成了
\begin{eqnarray}
\frac{1}{n}\sum_{i=0}^{n-1}\limits b_iw_n^{-ki}&=&\frac{1}{n}\sum^{n-1}\limits_{j=0}a_j\sum^{n-1}_{i=0} \limits w_n^{i(j-k)}
\\ &=& \frac{1}{n} \sum^{n-1}_{j=0}a_j\sum^{n-1}_{i=0} w_n^{i(j-k)}
\\ &=&
\frac{1}{n}*n*a_k\\&=&a_k
\end{eqnarray}
因此得证
这东西有啥用
我们发现除了分母又一个n1和wn上面有个负号外其他的都和DFT一样,因此我们考虑用上面那个代码,传一个−1的参数,然后做完了最后除n就是我们想要的系数了
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| void fft(int len,comp *a,int t) { if(len==1) return ; comp a1[(len>>1)+1],a2[(len>>1)+1]; int i; for(i=0;i<=len;i+=2) a1[i>>1]=a[i],a2[i>>1]=a[i+1]; fft(len/2,a1,t); fft(len/2,a2,t); comp w=comp(1,0); comp k=comp(cos(pi*2.0/len),t*sin(pi*2.0/len)); for(i=0;i<len/2;i++,w=w*k) a[i]=a1[i]+w*a2[i],a[i+len/2]=a1[i]-w*a2[i]; }
for(i=0;i<=n+m;i++) cout<<int(a[i].a/lmt+0.49)<<" ";
|
代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
| #include<stdio.h> #include<iostream> #include<cmath> #include<cstring> #include<queue> #include<stack> #include<vector> #include<set> #include<map> #include<algorithm>
using namespace std;
const int maxn=4201000; const double pi=3.1415926535897932; int n,m,lmt=1;
struct comp{ double a,b; comp(double sa=0,double sb=0){a=sa; b=sb;} }a[maxn],b[maxn];
comp operator +(const comp a,const comp b){return comp(a.a+b.a,a.b+b.b);} comp operator -(const comp a,const comp b){return comp(a.a-b.a,a.b-b.b);} comp operator *(const comp a,const comp b){return comp(a.a*b.a-a.b*b.b,a.a*b.b+a.b*b.a);}
void fft(int len,comp *a,int t) { if(len==1) return ; comp a1[(len>>1)+1],a2[(len>>1)+1]; int i; for(i=0;i<=len;i+=2) a1[i>>1]=a[i],a2[i>>1]=a[i+1]; fft(len/2,a1,t); fft(len/2,a2,t); comp w=comp(1,0); comp k=comp(cos(pi*2.0/len),t*sin(pi*2.0/len)); for(i=0;i<len/2;i++,w=w*k) a[i]=a1[i]+w*a2[i],a[i+len/2]=a1[i]-w*a2[i]; }
int main() { ios::sync_with_stdio(false); register int i,j; cin>>n>>m; for(i=0;i<=n;i++) cin>>a[i].a; for(i=0;i<=m;i++) cin>>b[i].a; while(lmt<=(n+m)) lmt*=2 fft(lmt,a,1); fft(lmt,b,1); for(i=0;i<=lmt;i++) a[i]=a[i]*b[i]; fft(lmt,a,-1); for(i=0;i<=n+m;i++) cout<<int(a[i].a/lmt+0.49)<<" "; cout<<endl; return 0; }
|
优化
首先我们手工推一下FFT过程
1 2 3
| 0 1 2 3 4 5 6 7 0 2 4 6 1 3 5 7 0 4 2 6 1 5 3 7
|
观察到每行的差值不一样(废话)
然后第一行的二进制是这样的
1 2 3
| 000 001 010 011 100 101 110 111 000 010 100 110 001 011 101 111 000 100 010 110 001 101 001 111
|
发现第一层和最后一层就像是翻转了,实际上就是这样,证明没那个必要,那我们就可以在FFT之前得到这个数组,然后FFT就不需要递归进行了
因为它的分治过程
1 2 3
| 0 4 2 6 1 5 3 7 0 2 4 6 1 3 5 7 0 4 2 6 1 5 3 7
|
每一层FFT的时候,都是左边*右边,因此结果是一样的
所以不需要复制,我们考虑一下怎么构造这个数组
由于是从小到大枚举,我们理所应当地认为i/2已经处理好了(除了0,但是不用管,因此我们得到了i/2的翻转值,由于我们,这里是整数除法),然后还需要向右移一位,一是因为那一位没有用,而是我们还有一最后一个二进制位没有翻转
1
| for(i=0;i<lmt;i++) pla[i]=(pla[i>>1]>>1)|((i&1)?lmt>>1:0);
|
还有一些其他的优化,但是提升远没有这个明显(除了三次变两次优化)
整个代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
| #include<iostream> #include<cmath> #include<stdio.h> #include<cstring> #include<queue> #include<stack> #include<vector> #include<set> #include<map> #include<algorithm>
using namespace std;
const int maxn=2100010; const double pi=3.1415926535897; int n,m,lmt; int pla[maxn];
struct comp{ double x,y; comp(double xx=0,double yy=0) {x=xx;y=yy;} }a[maxn],b[maxn];
comp operator +(const comp &a,const comp &b){return comp(a.x+b.x,a.y+b.y);} comp operator -(const comp &a,const comp &b){return comp(a.x-b.x,a.y-b.y);} comp operator *(const comp &a,const comp &b){return comp(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}
void fft(comp *a,int t) { int i,j,len,p; for(i=0;i<lmt;i++) if(i<pla[i]) swap(a[i],a[pla[i]]); for(p=2;p<=lmt;p<<=1) { len=p>>1; comp k=comp(cos(2.0*pi/p),1.0*t*sin(2.0*pi/p)); for(i=0;i<lmt;i+=p) { comp w=comp(1,0); for(j=i;j<i+len;j++) { comp looker=a[j+len]*w; a[j+len]=a[j]-looker; a[j]=a[j]+looker; w=w*k; } } } }
int main() { ios::sync_with_stdio(false); register int i,j; cin>>n>>m; for(i=0;i<=n;i++) cin>>a[i].x; for(i=0;i<=m;i++) cin>>b[i].x; lmt=1; while(lmt<=(n+m)) lmt<<=1; for(i=0;i<lmt;i++) pla[i]=(pla[i>>1]>>1)|((i&1)?lmt>>1:0); fft(a,1); fft(b,1); for(i=0;i<lmt;i++) a[i]=a[i]*b[i]; fft(a,-1); for(i=0;i<=n+m;i++) cout<<int((a[i].x)/lmt+0.49)<<" "; cout<<endl; return 0; }
|
NTT
前置芝士
原根
原根有这个东西
(a,m)=1使$a^l \equiv 1(\mod m) 成立的最小的l就是a关于模m的阶,记做ord_ma$
如果∃(g,m)=1并且ordmg=ϕ(m)那么就说明g是m的一个原根
此时{g,g2,...,gϕ(m)}构成了一个模m的既约剩余系
那他有啥意义,当我们取一个质数的时候ϕ(m)=m−1
这个时候就具有了单位复根的性质
而且空间相较于FFT少了一半,同时也变成了整数乘法不会损失
代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75
| #include<iostream> #include<cmath> #include<stdio.h> #include<cstring> #include<queue> #include<stack> #include<vector> #include<set> #include<map> #include<algorithm>
#define int long long
using namespace std;
const int maxn=2100100; const int modp=998244353;
int n,m,a[maxn],b[maxn]; int pla[maxn]; int lmt=1;
inline int ksm(int xx,int y) { long long x=1ll*xx; long long ans=1ll; while(y) { if(y&1) ans=(ans*x)%modp; x=(x*x)%modp; y>>=1; } return ans%modp; }
inline void ntt(int *a,int t) { int i,len,p,j; for(i=0;i<lmt;i++) if(i<pla[i]) swap(a[i],a[pla[i]]); for(p=2;p<=lmt;p<<=1) { len=p>>1; int k=ksm(t,(modp-1)/p); for(i=0;i<lmt;i+=p) { int w=1; for(j=i;j<i+len;j++) { int looker=(w*a[j+len])%modp; a[j+len]=(a[j]-looker+modp)%modp; a[j]=(a[j]+looker)%modp; w=(w*k)%modp; } } } }
signed main() { ios::sync_with_stdio(false); register int i,j; cin>>n>>m; for(i=0;i<=n;i++) cin>>a[i]; for(i=0;i<=m;i++) cin>>b[i]; while(lmt<=(n+m)) lmt<<=1; for(i=0;i<lmt;i++) pla[i]=(pla[i>>1]>>1)|((i&1)?lmt>>1:0); ntt(a,3); ntt(b,3); for(i=0;i<lmt;i++) a[i]=(1ll*a[i]*b[i])%modp; ntt(a,ksm(3,modp-2)); lmt=ksm(lmt,modp-2); for(i=0;i<=n+m;i++) cout<<(a[i]*lmt)%modp<<" "; cout<<endl; return 0; }
|
优化
预处理原根和原根的逆元的1<<p次方