FFT--快速傅里叶变换入门

FFT—快速傅里叶变换入门

快速傅里叶变换(英语:Fast Fourier Transform, FFT),是快速计算序列的离散傅里叶变换(DFT)或其逆变换的方法[1]傅里叶分析将信号从原始域(通常是时间或空间)转换到频域的表示或者逆过来转换。FFT会通过把DFT矩阵分解稀疏(大多为零)因子之积来快速计算此类变换。[2] 因此,它能够将计算DFT的复杂度从只用DFT定义计算需要的 $O(n^2)$降低到$O(n \log n)$,其中$n$为数据大小

---[维基百科-快速傅里叶变换](https://zh.wikipedia.org/wiki/快速傅里叶变换)

Preface

如何计算两个多项式$\sum_{i=0}^n a_i x^i$,$\sum_{i=0}^n b_i x^i$相乘后新的多项式?

非常自然的,我们可以将一个多项式中的每个系数与另一个多项式的各个系数逐个相乘后相加后得到新的多项式。这样做时间复杂度为$O(n^2)$

但是如果我们会多项式插值的话还有另一种方法,我们先用$n+1$个不同的点表示出一个多项式(根据高斯消元法$n+1$个点可以确定一个唯一的$n$次多项式)。假设这里的两个多项式分别过$(x_i,{y_1}_i)$和$(x_i,{y_2}_i)$,那么它们相乘后的多项式就过$(x_i,{y_1}_i \times {y_2}_i)$这个点。但是如果我们想使用插值将点值转化为系数的话,由于相乘最高次数已经达到了$2 \times n$,所以在一开始也得求出更多的点值后相乘并插值。求出这么多个点值的时间复杂度是$O(n^2)$,同样的插值的话我们使用拉格朗日插值法的话依旧需要$O(n^2)$的时间复杂度

所以这种方法的时间复杂度依旧为$O(n^2)$。但是我们可以通过选择一些特定的点和精妙的算法,将求点值过程和插值过程的时间复杂度降为$O(n\log n)$,总的时间复杂度也降为$O(n \log n)$.

学习资料

太懒了不想再码一遍长长的前置知识和算法过程,就列出我学习FFT时认为非常有用的资料:

  1. Cormen, T. H., Leiserson, C. E., Rivest, R. L., & Stein, C. (2009). Introduction to algorithms. MIT press.

  2. rvalue(2017), [学习笔记] 多项式与快速傅里叶变换(FFT)基础, retrieved from 2017/08/13, https://www.cnblogs.com/rvalue/p/7351400.html

  3. 复数学习部分:

    rvalue(2018),[学习笔记&教程] 信号, 集合, 多项式, 及各种卷积性变换 (FFT,NTT,FWT,FMT), retrieved from 2018/12/14, https://www.cnblogs.com/rvalue/p/10120174.html,

  4. 维基百科-快速傅里叶变换, https://zh.wikipedia.org/wiki/快速傅里叶变换

一些重点

这里列出一些我在自学时一开始有困惑以及认为比较重要的知识点

  • 复数部分

    我们用坐标系表示复数$w=a+bi$,横坐标表示实部,纵坐标表示虚部。使用这种表示后我们又有了关于一个复数的另外两个参数:

    • 辐角:

      原点出发连接到一个复数$w$所在的点形成的射线与$x$轴的正半轴形成的夹角$\phi$

    • 模长:

      原点出发连接到一个复数$w$所在的点形成的有向线段的长度,等于$\sqrt{a^2+b^2}$

      两个复数相乘的法则是:辐角相加,模长相乘

      至于这个法则的证明见学习资料中第三篇中的复数讲解部分,不想码了。。。

      根据这个法则后我们可以构造一个具有神奇性质的圆,这个圆以原点为圆心,单位长度$1$为半径(其实就是一个单位圆lah)

      这个圆上的点所表示出的复数,根据上文的相乘法则,它们的幂次总是也是这个单位圆上.

      由此我们引入另一个概念:$n$次单位复数根。$w^n = 1$,这个方程的复数根$w$称为$n$次单位复数根。

      显然这样的根有$n$个,它们均匀分布在单位圆上,辐角为$k \frac{2 \pi}{n}$

      再根据复数的指数表示法(欧拉公式), n次单位根$w^k_n = e^{\frac{k \times 2 \pi}{n} i}$

      根据单位圆上的几何定义和指数定义,我们可以得出一下几个性质:

    • $w^0_n = w^n_n =1, w^{\frac{2}{n}}_n =-1$(点在$(-1,0)$)

    • $w^{k+\frac{2}{n}}_n = e^{(k+\frac{2}{n})\frac{2 \pi}{n}i} = e^{k \frac{2 \pi}{n} i} \times e^{\frac{2}{n} \frac{2 \pi}{n} i}$ = $w^k_n w^{\frac{2}{n}}_n = -w^k_n$
    • 消去定理&折半定理

      • 消去定理

        根据指数定义,$w^{dk}_{dn} = w^k_n $

      • 折半定理

        $(w^{k+\frac{2}{n}}_{n})^2 = w^{2k+n}_n = w^{2k}_n w^n_n = (w^k_n)^2$

      我觉得看得懂复数部分的话这些证明都不是问题

    • 库利-图基算法

      重头戏,分治的关键

      我们将多项式次数补至$2$的次幂,对于多项式$A(x)$

      $A_{even} = a_0 + a_2 x +… +a_{n-2}x^{\frac{n}{2}-1}$

      $A_{odd} = a_1 + a_3 x +… +a_{n-1}x^{\frac{n}{2}-1}$

      显然$A(x) = A_{even}(x^2) +xA_{odd}(x^2)$

      我们这样进行分治来进行多项式求值

      $A(w^k_n) = A_{even}(w^{2k}_n) + w^k_n A_{odd} (w^{2k}_n) = A_{even}(w^{k}_{n/2}) + w^k_n A_{odd} (w^{k}_{n/2})$

      于是我们可以将$n$缩小一半分治处理,但是这样$k$是要小于$n/2$的,对于大于$n/2$的$k$我们这样处理

      $A(w^{k+n/2}_{n}) = A_{even}(w^{2k+n}_{n}) + w^{k+n/2}_n A_{odd}(w^{2k+n}_{n})=A_{even} (w^{2k}_n) - w^k_n A_{odd} (w^{2k}_n) = A_{even}(w^{k}_{n/2}) - w^k_n A_{odd} (w^{k}_{n/2})$

  • 插值

    其实公式推导我也不太懂。。。

    • 迭代版本

      算法导论上讲得很清楚,我就根据它的伪代码写的,这里不多讲了

      但是求位逆序置换时将每个数二进制展开的话太慢了,可以使用rvalue博客中的这种方法,并不难理解

      $rev[i] = (rev[i>>1]>>1)|((i \& 1)<<(bits-1))$

      bits是总的二进制位

代码

递归版本:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <cmath>
#include <algorithm>
#define ll long long
#define ri register int
#define ull unsigned long long
#define PI acos(-1)
using namespace std;
const int MAXN=4e6+5;
const double eps=1e-9;
struct Complex{
double real,imag;
Complex () {real=imag=0;}
Complex (double x,double y) {real=x,imag=y;}
}a[MAXN],b[MAXN],c[MAXN];
Complex operator + (const Complex & x,const Complex & y){
return Complex(x.real+y.real, x.imag+y.imag);
}
Complex operator - (const Complex & x,const Complex & y){
return Complex(x.real-y.real, x.imag-y.imag);
}
Complex operator * (const Complex & x,const Complex & y){
return Complex(x.real*y.real-x.imag*y.imag, x.real*y.imag+x.imag*y.real);
}
int N,M,n,nn=1;
double chart_sin[MAXN],chart_cos[MAXN];
void DFT(Complex* now,int len){
if(len==1)return ;
int nlen=len/2;
Complex* now0=new Complex[nlen];//now0[nlen],now1[nlen];
Complex* now1=new Complex[nlen];
Complex unit = Complex(chart_cos[len],chart_sin[len]);
Complex t,w = Complex(1,0);
for(int i=0;i<len;i+=2){
now0[i/2]=now[i];
now1[i/2]=now[i+1];
}
DFT(now0,nlen);
DFT(now1,nlen);
for(int i=0;i<nlen;i++){
t=w*now1[i];
now[i]=now0[i]+t;
now[i+nlen]=now0[i]-t;
w=w*unit;
}
delete[] now0;
delete[] now1;
return;
}
void IDFT(Complex* now,int len){
if(len==1)return ;
int nlen=len/2;
Complex* now0=new Complex[nlen];//now0[nlen],now1[nlen];
Complex* now1=new Complex[nlen];
Complex unit = Complex(chart_cos[len],-chart_sin[len]);
Complex t,w = Complex(1,0);
for(int i=0;i<len;i+=2){
now0[i/2]=now[i];
now1[i/2]=now[i+1];
}
IDFT(now0,nlen);
IDFT(now1,nlen);
for(int i=0;i<nlen;i++){
t=w*now1[i];
now[i]=now0[i]+t;
now[i+nlen]=now0[i]-t;
w=w*unit;
}
delete[] now0;
delete[] now1;
return;
}
int main(){
scanf("%d %d",&N,&M);
n=N+M;
for(int i=0;i<=N;i++)scanf("%lf",&a[i].real);
for(int i=0;i<=M;i++)scanf("%lf",&b[i].real);
while(nn<=n){
nn=nn<<1;
}
for(int i=2;i<=nn;i=i<<1){
chart_sin[i]=sin(2*PI/i);
chart_cos[i]=cos(2*PI/i);
}
DFT(a,nn);
DFT(b,nn);

for(int i=0;i<=nn;i++){
c[i]=a[i]*b[i];
}
IDFT(c,nn);
for(int i=0;i<=n;i++){
printf("%d ",int(c[i].real/nn+0.5));
}
return 0;
}

迭代版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <cmath>
#include <algorithm>
#define ll long long
#define ri register int
#define ull unsigned long long
#define PI acos(-1)
using namespace std;
const int MAXN=4e6+5;
const double eps=1e-9;
struct Complex{
double real,imag;
Complex () {real=imag=0;}
Complex (double x,double y) {real=x,imag=y;}
}a[MAXN],b[MAXN],c[MAXN];
Complex operator + (const Complex & x,const Complex & y){
return Complex(x.real+y.real, x.imag+y.imag);
}
Complex operator - (const Complex & x,const Complex & y){
return Complex(x.real-y.real, x.imag-y.imag);
}
Complex operator * (const Complex & x,const Complex & y){
return Complex(x.real*y.real-x.imag*y.imag, x.real*y.imag+x.imag*y.real);
}
int N,M,n,nn=1,bin=0,rev[MAXN];
double chart_sin[MAXN],chart_cos[MAXN];
void FFT(Complex* now,int len){
int gap=1,ngap;
Complex u,t;
for(int i=0;i<nn;i++)if(i<rev[i])swap(now[i],now[rev[i]]);
for(int s=1;s<=bin;s++){//length of the DFT sequence
ngap=gap,gap=gap<<1;
Complex unit = Complex(chart_cos[gap],chart_sin[gap]);
for(int k=0;k<nn;k+=gap){//where to start DFT
Complex w = Complex(1,0);
for(int j=0;j<ngap;j++){//iterate interval
t = w*now[k+j+ngap];
u=now[k+j];
now[k+j]=u+t;
now[k+j+ngap]=u-t;
w = w*unit;
}
}
}
return ;
}
void IDFT(Complex* now,int len){
int gap=1,ngap;
Complex u,t;
for(int i=0;i<nn;i++)if(i<rev[i])swap(now[i],now[rev[i]]);
for(int s=1;s<=bin;s++){//length of the IDFT sequence
ngap=gap,gap=gap<<1;
Complex unit = Complex(chart_cos[gap],-chart_sin[gap]);
for(int k=0;k<nn;k+=gap){//where to start IDFT
Complex w = Complex(1,0);
for(int j=0;j<ngap;j++){//iterate interval
t = w*now[k+j+ngap];
u=now[k+j];
now[k+j]=u+t;
now[k+j+ngap]=u-t;
w = w*unit;
}
}
}
return ;
}
inline int get_rev(int num){
int x=0;
for(int i=1;i<=bin;i++){
x=(x+(num&1))<<1;
num=num>>1;
}
return x>>1;
}
inline void pre(){
while(nn<=n){
nn=nn<<1;
bin++;
}
for(int i=0;i<nn;i++){
rev[i]=(rev[i>>1]>>1)|((i&1)<<(bin-1));
//rev[i]=get_rev(i);
//printf("%d %d %d\n",i,rev[i],bin);
}
for(int i=1;i<=nn;i=i<<1){
chart_sin[i]=sin(2*PI/i);
chart_cos[i]=cos(2*PI/i);
}
return ;
}
int main(){
scanf("%d %d",&N,&M);
n=N+M;
for(int i=0;i<=N;i++)scanf("%lf",&a[i].real);
for(int i=0;i<=M;i++)scanf("%lf",&b[i].real);
pre();
FFT(a,nn);
FFT(b,nn);
for(int i=0;i<=nn;i++){
c[i]=a[i]*b[i];
}
IDFT(c,nn);
for(int i=0;i<=n;i++){
printf("%d ",int(c[i].real/nn+0.5));
}
return 0;
}