Note -「Fast Fourier Transform」

cirnovsky /

☆ 0x00 前言

这篇主要是针对有基础的同学的,基础知识相信大家都学过。

这篇blog的诞生是因为myh要学FFT,甚至疯狂到了要找别人语音解答的地步。

然后我就想起远古时候WGY好像学过这么个东西,就写了篇blog出来给myh各位看,顺便复习一下。

说一下学了这东西的感悟吧。我觉得只要学了一次函数和三角函数就能看懂这篇。

在那个远古时代,周x健还没讲一次函数,我天天抱着啃没啃懂……

后来,周x健讲了一次函数,我又在物竞那嫖了些三角函数。

然后整个人心态都变了,几个小时下来感觉FFT也就那样。(甚至不如NTT有用

所以大家不要畏难,在座各位的数学都果断吊打WGY对吧。

李琰之前也在OJ发过FFT的note,但是递归转循环基本就是放了个代码,还留下了不少坑。

我这里也填了不少李琰的坑,但是一些基础的单位根性质的证明之类的东西我这里就懒得给了,自己考李琰的古吧。

因为时间原因一些可以从几何意义上来理解的东西我没给图,大家可以自己手%一下。

(好罢主要是给myh看的手动@ybmyh)

☆ 0x01 点值表示法

众所周知,一个

F(x)=i=0nai×xiF(x)=\sum_{i=0}^{n}a_{i}\times x^{i}

形式的 nn 次多项式可以在平面直角坐标系中被 n+1n+1 个点唯一的表示出来。

点值表示的两个多项式可以在线性时间复杂度中解决相乘,就是对应的 yy 乘起来。

但是暴力的把系数表示法转化为点值表示法依然是 Θ(n2)\Theta(n^{2}) 的。怎么办呢?

☆ 0x02 复数和单位根

说过不讲,我就放在这里方便我翻。

(a+bi)+(c+di)=(a+c)+(b+d)i(a+bi)+(c+di)=(a+c)+(b+d)i

(a+bi)(c+di)=(ac)+(bd)i(a+bi)-(c+di)=(a-c)+(b-d)i

(a+bi)×(c+di)=ac+adi+bcibd=(acbd)+(ad+bc)i(a+bi)\times(c+di)=ac+adi+bci-bd=(ac-bd)+(ad+bc)i

除法可以不用,其实也不用讲,自然而然的东西。

方程 xn=1x^{n}=1 的根,称作单位根用 ωnk\omega_{n}^{k} 表示。

kk 表示第 kknn 次单位根,从0开始标号,ωn0,ωn1,,ωnn1,\omega^{0}_{n},\omega^{1}_{n},\cdots,\omega^{n-1}_{n},。其中 ωn0=1\omega^{0}_{n}=1

虽然这样说,但是 knk\geq n 以及 k<0k < 0 的情况是被允许的。

原因看到后面就知道了。

从几何意义上来理解单位根即复数的坐标系单位圆的 nn 等分点。

复数相乘的性质:模长相乘,辐角相加。

模长指一个复数到原点的距离,t=a+bit=a+bi 的模长记作 t|t|

辐角指原点到点的连线与 xx 轴的正方向的夹角,记作 arg(a+bi)\arg(a+bi)

接下来列举需要用到的公式。

\label2ωnk=e2kπni=cos2kπni+isin2kπnωn0=1ωnk=ωkmodnωnk×ωnj=ωnk+j(ωn1)k=ωnkωpnpk=ωnkωnk+n2=ωnk\begin{align*}\label{2} & \omega^{k}_{n}=e^{\frac{2k\pi}{n}i}=\cos\frac{2k\pi}{n}i+i\sin\frac{2k\pi}{n} \tag{2.1} \\ & \omega^{0}_{n}=1 \tag{2.2} \\ & \omega^{k}_{n}=\omega^{k\operatorname{mod} n} \tag{2.3} \\ & \omega^{k}_{n}\times\omega^{j}_{n}=\omega^{k+j}_{n} \tag{2.4} \\ & (\omega^{1}_{n})^{k}=\omega^{k}_{n} \tag{2.5} \\ & \omega^{pk}_{pn}=\omega^{k}_{n} \tag{2.6} \\ & \omega^{k+\frac{n}{2}}_{n}=-\omega^{k}_{n} \tag{2.7} \\ \end{align*}

靠这段在vsc上显示不出来我自毙

☆ 0x03 继续研究多项式

我们设一个多项式

F(x)=i=0n1ai×xiF(x)=\sum_{i=0}^{n-1}a_{i}\times x^{i}

保证 n=2p+1n=2^{p}+1

我们按 ii 的奇偶性把 FF 分为两个部分

F(x)=i=0n1ai×xi=i=0n21a2i×x2i+i=0n21a2i+1×x2i+1F(x)=\sum_{i=0}^{n-1}a_{i}\times x^{i}=\sum_{i=0}^{\frac{n}{2}-1}a_{2i}\times x^{2i}+\sum_{i=0}^{\frac{n}{2}-1}a_{2i+1}\times x^{2i+1}

继续定义

L(x)=i=0n21a2i×xiL(x)=\sum_{i=0}^{\frac{n}{2}-1}a_{2i}\times x^{i}

R(x)=i=0n21a2i+1×xiR(x)=\sum_{i=0}^{\frac{n}{2}-1}a_{2i+1}\times x^{i}

也就是说

F(x)=i=0n1ai×xi=i=0n21a2i×x2i+i=0n21a2i+1×x2i+1=i=0n21a2i×(x2)i+i=0n21a2i+1×(x2)i×x=i=0n21a2i×x2i+i=0n21a2i+1×x2i+1=L(x2)+xR(x2)F(x)=\sum_{i=0}^{n-1}a_{i}\times x^{i}=\sum_{i=0}^{\frac{n}{2}-1}a_{2i}\times x^{2i}+\sum_{i=0}^{\frac{n}{2}-1}a_{2i+1}\times x^{2i+1}=\sum_{i=0}^{\frac{n}{2}-1}a_{2i}\times (x^{2})^{i}+\sum_{i=0}^{\frac{n}{2}-1}a_{2i+1}\times (x^{2})^{i}\times x=\sum_{i=0}^{\frac{n}{2}-1}a_{2i}\times x^{2i}+\sum_{i=0}^{\frac{n}{2}-1}a_{2i+1}\times x^{2i+1}=L(x^{2})+xR(x^{2})

我们可以代入一个数进去。一般我们肯定想着代个看起来可爱的数字。

看看,这就是我等凡人与傅里叶这等神仙的区别。人家代入的是什么?没错,单位根!(不然我TM前面罗列一大堆单位根的性质干嘛)

F(x)=L(x2)+xR(x2)F(x)=L(x^{2})+xR(x^{2})

F(ωnk)=L(ωn2k)+ωnkR(ωn2k)F(\omega^{k}_{n})=L(\omega^{2k}_{n})+\omega^{k}_{n}R(\omega^{2k}_{n})

F(ωnk)=L(ω2×12n2k)+ωnkR(ω2×12n2k)F(\omega^{k}_{n})=L(\omega^{2k}_{2\times \frac{1}{2}n})+\omega^{k}_{n}R(\omega^{2k}_{2\times \frac{1}{2}n})

由公式 (2.6)

F(\omega^{k}_{n})=L(\omega^{k}_{\frac{1}{2}n})+\omega^{k}_{n}R(\omega^{k}_{\frac{1}{2}n}) \tag{3.1}

回到

F(x)=L(x2)+xR(x2)F(x)=L(x^{2})+xR(x^{2})

此时我们代入 ωnk+n2\omega^{k+\frac{n}{2}}_{n}

同理可得

F(\omega^{k+\frac{n}{2}}_{n})=L(\omega^{k}_{\frac{n}{2}})-R(\omega^{k}_{\frac{n}{2}}) \tag{3.2}

可以发现 (3.1) 和 (3.2) 只差了符号,也就是说只要知道了 L(ωn2k)L(\omega^{k}_{\frac{n}{2}})R(ωn2k)R(\omega^{k}_{\frac{n}{2}}) 我们就可以同时得到 F(ωnk)F(\omega^{k}_{n})F(ωnk+n2)F(\omega^{k+\frac{n}{2}}_{n})。然后就递归求解。

这样我们就可以在 Θ(nlog2n)\Theta(n\log_{2}n) 求取多项式的点值表示了。

算法名叫DFT。

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <queue>
#include <complex>

using namespace std;

const double PI = acos(-1);
const int MAXN = 1e6 + 3e5 + 5;
int n;

struct Complex {
	double real;
	double imag;
	Complex(double t_real = 0, double t_imag = 0) { real = t_real, imag = t_imag; }
	Complex operator + (Complex const& rhs) const {
		return Complex(real + rhs.real, imag + rhs.imag);
	}
	Complex operator - (Complex const& rhs) const {
		return Complex(real - rhs.real, imag - rhs.imag);
	}
	Complex operator * (Complex const& rhs) const {
		return Complex(real * rhs.real - imag * rhs.imag, real * rhs.imag + imag * rhs.real);
	}
	void to_real(const double t_real) {
		real = t_real;
	}
	void to_imag(const double t_imag) {
		imag = t_imag;
	}
	double to_real() {
		return real;
	}
	double to_imag() {
		return imag;
	}
} F[MAXN << 2], t[MAXN << 2];

void dft(Complex *f, int __n) {
	if (__n == 1) return ;
	Complex *L = f;
	Complex *R = f + (__n >> 1);
	for (int k = 0; k < __n; ++k) t[k] = f[k];
	for (int k = 0; k < (__n >> 1); ++k) L[k] = t[k << 1], R[k] = t[k << 1 | 1];	
	dft(L, (__n >> 1));
	dft(R, (__n >> 1));
	Complex omega;
	Complex now;
	omega.to_real(cos(2 * PI / __n));
	omega.to_imag(sin(2 * PI / __n));
	now.to_real(1);
	now.to_imag(0);
	for (int k = 0; k < (__n >> 1); ++k) {
		t[k] = L[k] + now * R[k];
		t[k + (__n >> 1)] = L[k] - now * R[k];
		now = now * omega;
	}
	for (int k = 0; k < __n; ++k) f[k] = t[k];
}

signed main() {
	scanf("%d", &n);
	int temp = n;
	double x;
	for (n = 1; n < temp; n <<= 1) ;
	for (int i = 0; i < temp; ++i) scanf("%lf", &x), F[i].to_real(x), F[i].to_imag(0);
	dft(F, n);
	for (int i = 0; i < n; ++i) printf("(%.2lf %.2lf)\n", F[i].to_real(), F[i].to_imag());
	return 0;
}

subarashi

但是我们现在求到的只是一堆没用的点值,还需要求到的点值表示还原成系数表示。

结论是把代入的 ωnk\omega^{k}_{n} 换成 ωnk\omega^{-k}_{n} 然后除以 nn

即DFT的逆运算IDFT。

IDFT的证明比较繁琐,涉及到分类讨论。由于我最近被数学作业的多答案讨论和智障珠的60种情况毒瘤了,故懒得给出证明。反正我相信看我blog的人人均会单位根反演

我们记 F(F(x))\mathcal{F}(F(x))F(x)F(x) 的离散傅里叶变换/傅里叶变换,F(F(x))\mathcal{F'}(F(x))F(x)F(x) 的逆离散傅里叶变换/傅里叶变换。

用看起来很高大上很nb的数学语言表示就是

G=F(F(x))G=\mathcal{F}(F(x))

n×fk=i=0n1ωnkigin\times f_{k}=\sum_{i=0}^{n-1}\omega^{-ki}_{n}g_{i}

其中 fif_{i}gig_{i} 分别为 FFGG 的第 ii 项系数。

我们只需要改一下代码就好了。

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <queue>
#include <complex>

using namespace std;

const double PI = acos(-1);
const int MAXN = 1e6 + 3e5 + 5;
int n, m;
struct Complex {
	double real;
	double imag;
	Complex(double t_real = 0, double t_imag = 0) { real = t_real, imag = t_imag; }
	Complex operator + (Complex const& rhs) const {
		return Complex(real + rhs.real, imag + rhs.imag);
	}
	Complex operator - (Complex const& rhs) const {
		return Complex(real - rhs.real, imag - rhs.imag);
	}
	Complex operator * (Complex const& rhs) const {
		return Complex(real * rhs.real - imag * rhs.imag, real * rhs.imag + imag * rhs.real);
	}
	void to_real(const double t_real) {
		real = t_real;
	}
	void to_imag(const double t_imag) {
		imag = t_imag;
	}
	double to_real() {
		return real;
	}
	double to_imag() {
		return imag;
	}
} A[MAXN << 2], B[MAXN << 2], t[MAXN << 2];

void dft(Complex *f, int __n, int flag) {
	if (__n == 1) return ;
	Complex *L = f;
	Complex *R = f + (__n >> 1);
	for (int k = 0; k < __n; ++k) t[k] = f[k];
	for (int k = 0; k < (__n >> 1); ++k) L[k] = t[k << 1], R[k] = t[k << 1 | 1];	
	dft(L, (__n >> 1), flag);
	dft(R, (__n >> 1), flag);
	Complex omega;
	Complex now;
	omega.to_real(cos(2 * PI / __n));
	omega.to_imag(sin(2 * PI / __n) * flag);
	now.to_real(1);
	now.to_imag(0);
	for (int k = 0; k < (__n >> 1); ++k) {
		t[k] = L[k] + now * R[k];
		t[k + (__n >> 1)] = L[k] - now * R[k];
		now = now * omega;
	}
	for (int k = 0; k < __n; ++k) f[k] = t[k];
}

signed main() {
	scanf("%d %d", &n, &m);
	double x;
	for (int i = 0; i <= n; ++i) scanf("%lf", &x), A[i].to_real(x), A[i].to_imag(0);
	for (int i = 0; i <= m; ++i) scanf("%lf", &x), B[i].to_real(x), B[i].to_imag(0);
	for (m += n, n = 1; n <= m; n <<= 1) ;
	dft(A, n, 1);
	dft(B, n, 1);
	for (int i = 0; i < n; ++i) A[i] = A[i] * B[i];
	dft(A, n, -1);
	for (int i = 0; i <= m; ++i) printf("%d ", (int)(A[i].real / n + 0.49));
	return 0;
}

递归版常数过大,我们想想能不能把递归转为循环(迭代)。

这里给一个结论,给出一个序列。比如 0 1 2 3 4 5 6 7\texttt{0 1 2 3 4 5 6 7}

对其进行DFT后:0 4 2 6 1 5 3 7\texttt{0 4 2 6 1 5 3 7}。多试几组可以发现DFT后每个位置数是原序列对应位置上的数的二进制反转后的结果。

0 1 2 3 4 5 6 7\texttt{0 1 2 3 4 5 6 7}

(000) (001) (010) (011) (100) (101) (110) (111)\texttt{(000) (001) (010) (011) (100) (101) (110) (111)}

(000) (100) (010) (110) (001) (101) (011) (111)\texttt{(000) (100) (010) (110) (001) (101) (011) (111)}

0 4 2 6 1 5 3 7\texttt{0 4 2 6 1 5 3 7}

证明也好证,留作思考吧。

然后我们就可以预处理出序列DFT后的位置,然后向上合并。就不用从上至下递归了。

具体来说,我们设 revirev_{i} 为数字 ii 的二进制翻转,limlim 为最多的二进制位数。

翻转操作相当于把当前数的二进制最后一位接到之前部分翻转的前面。

之前部分的翻转即 revishr1shr1rev_{i\operatorname{shr} 1}\operatorname{shr} 1

其中 shr\operatorname{shr} 相当于右移操作,shl\operatorname{shl} 同理。

然后判一下最后一位,是1的话就让 2lim12^{lim-1}revishr1shr1rev_{i\operatorname{shr} 1}\operatorname{shr} 1 按位与。因为 2p2^{p} 的二进制位始终是1后面跟着 pp 个0。

这里建议自己手推一下。

int lim = 0;
while ((1 << lim) < n) ++lim;
for (int i = 0; i < n; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lim - 1));

完整代码

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <queue>
#include <complex>

using namespace std;

const double PI = acos(-1);
const int MAXN = 1e6 + 3e5 + 5;
int n, m;

struct Complex {
	double real;
	double imag;
	Complex(double t_real = 0, double t_imag = 0) { real = t_real, imag = t_imag; }
	Complex operator + (Complex const& rhs) const {
		return Complex(real + rhs.real, imag + rhs.imag);
	}
	Complex operator - (Complex const& rhs) const {
		return Complex(real - rhs.real, imag - rhs.imag);
	}
	Complex operator * (Complex const& rhs) const {
		return Complex(real * rhs.real - imag * rhs.imag, real * rhs.imag + imag * rhs.real);
	}
	void to_real(const double t_real) {
		real = t_real;
	}
	void to_imag(const double t_imag) {
		imag = t_imag;
	}
	double to_real() {
		return real;
	}
	double to_imag() {
		return imag;
	}
} A[MAXN << 2], B[MAXN << 2], t[MAXN << 2];
int rev[MAXN << 2];

void get_rev() {
	int lim = 0;
	while ((1 << lim) < n) ++lim;
	for (int i = 0; i < n; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lim - 1));
}

void fft(Complex *f, int __n, int flag) {
	for (int i = 0; i < n; ++i) if (i < rev[i]) swap(f[i], f[rev[i]]);
		for (int mid = 1; mid < lim; mid <<= 1) {
			Complex omega;
			omega.to_real(cos(PI / mid));
			omega.to_imag(sin(PI / mid) * flag);
			for (int i = 0; i < n; i += (mid << 1)) {
				Complex now;
				now.to_real(1);
				now.to_imag(0);
				for (int j = 0; j < mid; ++j) {
					Complex first = f[i + j];
					Complex second = now * f[i + j + mid];
					f[i + j] = first + second;
					f[i + j + mid] = first - second;
					now = now * omega;
				}
			}
		}
}

signed main() {
	scanf("%d %d", &n, &m);
	double x;
	for (int i = 0; i <= n; ++i) scanf("%lf", &x), A[i].to_real(x), A[i].to_imag(0);
	for (int i = 0; i <= m; ++i) scanf("%lf", &x), B[i].to_real(x), B[i].to_imag(0);
	for (m += n, n = 1; n <= m; n <<= 1) ;
	get_rev();
	fft(A, n, 1);
	fft(B, n, 1);
	for (int i = 0; i < n; ++i) A[i] = A[i] * B[i];
	fft(A, n, -1);
	for (int i = 0; i <= m; ++i) printf("%d ", (int)(A[i].real / n + 0.49));
	return 0;
}