快速傅里叶变换
2017年8月3日网上关于快速傅里叶变换的内容讲的不少,但<算法导论>讲的东西才是经典啊.
两个 n 次多项式相加的最直接方法所需时间为 \\Theta(n),但相乘的最直接方法所需为 \\Theta(n^2).用快速傅里叶变换(FFT),能将多项式相乘的时间复杂度降低为 \\Theta(nlgn)
多项式的表示
如果一个多项式 A(x) 的最高次的非零系数是 a_k,则称 A(x) 的次数是 k ,记 degree(A) = k.也就是说 k 为多项式中最高次的次数.
系数表示
对这个多项式:
A(x) = a_{n}x^{n}+a_{n-1}x^{n-1}+\\cdots+a_{1}x^1+a_0
其系数表达式为 \\{a_0, a_1, \\cdots, a_n\\}
点值表示
一个 n 次多项式 A(x) 的点值表示就是由 n + 1 个不同的点值所组成的集合
{(x_0, y_0), (x_1, y_1),\\cdots,(x_n, y_n)}
比如,对于一个二次函数,只需要知道三个不同的点就可以知道这个二次函数的图像了.从这里也可以看出,一个多项式可以有很多不同的点值表示.
<算法导论>上的说法是次数界为 n 的多项式,对于次数界为 n 的多项式,其次数是 0 ~ n – 1的任何整数.而 n 次多项式 是指最高次次数为 n 的多项式
求值的逆(从一个多项式的点值表达确定其系数表达形式)成为插值.
单位复数根
n次单位复数根 是满足 \\omega^n=1的复数w.n次单位复数根恰好有n个(公式中的i是值虚数单位,即i=\\sqrt{-1}):
e^{\\frac{2{\\pi}ik}{n}},(k=0,1,\\cdots,n-1)
由欧拉公式:
e^{ix}=\\cos{x}+i\\sin{x}
记 \\omega_n=e^{\\frac{2{\\pi}i}{n}},则对应的n个n次单位根可表示为:
\\omega_n^0,\\omega_n^1,\\cdots,\\omega_n^{n-1}
对多项式的操作,用点值表达很便利.
对于加法
如果 C(x) = A(x) + B(x),则对任意点 x_k,满足 C(x_k) = A(x_k) + B(x_k).就是说如果:
A 的点值表达为:{(x_0, y_0), (x_1, y_1),\\cdots,(x_n, y_n)}
B 的点值表达为:{(x_0, y_0^1), (x_1, y_1^1),\\cdots,(x_n, y_n^1)}
则C 的点值表达为:{(x_0,y_0+y_0^1), (x_1,y_1+ y_1^1),\\cdots,(x_n,y_n+ y_n^1)}
时间复杂度:\\Theta(n)
对于乘法
如果 C(x) = A(x) B(x),则对任意点 x_k,满足 C(x_k) = A(x_k)B(x_k),且有 degree(C)=degree(A)+degree(B).
同样的C的点值表达为:{(x_0,y_0y_0^1), (x_1,y_1 y_{1}^1),\\cdots,(x_{2n-1},y_{2n-1} y_{2n-1}^1)}
系数表示的多项式快速乘法
FFT 实例
两个大整数相乘:
#include <bits/stdc++.h>
using namespace std;
const double pi = acos(-1.0);
const int N = 50005;
int rev(int id, int len)
{
int res = 0;
for (int i = 0; (1 << i) < len; i++) {
res <<= 1;
if (id & (1 << i))
res |= 1;
}
return res;
}
void fft(vector<complex<double> > &a, int len, int dft)
{
vector<complex<double> > tmp;
tmp.resize(len);
for (int i = 0; i < len; i++) {
tmp[rev(i, len)] = a[i];
}
for (int s = 1; (1 << s) <= len; s++) {
int m = (1 << s);
complex<double> wm = complex<double>(cos(dft * 2 * pi / m), sin(dft * 2 * pi / m));
for (int k = 0; k < len; k += m) {
complex<double> w = complex<double>(1, 0);
for (int j = 0; j < (m >> 1); j++) {
complex<double> t = w * tmp[k + j + (m >> 1)];
complex<double> u = tmp[k + j];
tmp[k + j] = u + t;
tmp[k + j + (m >> 1)] = u - t;
w = w * wm;
}
}
}
if (dft == -1) {
for (int i = 0; i < len; i++) {
tmp[i] = complex<double>(tmp[i].real() / len, tmp[i].imag() / len);
}
}
a = tmp;
}
vector<complex<double> > a, b;
vector<int> ans;
char stra[N], strb[N];
int main()
{
while (~scanf("%s %s", stra, strb)) {
a.clear();
b.clear();
ans.clear();
int lena = strlen(stra), lenb = strlen(strb);
int la = 0, lb = 0;
while ((1 << la) < lena)
la++;
while ((1 << lb) < lenb)
lb++;
int len = (1 << (max(la, lb) + 1));
for (int i = 0; i < len; i++) {
if (i < lena) {
a.push_back(complex<double>(stra[lena - i - 1] - '0', 0));
}
if (i < lenb)
b.push_back(complex<double>(strb[lenb - i - 1] - '0', 0));
}
a.resize(len);
b.resize(len);
fft(a, len, 1);
fft(b, len, 1);
for (int i = 0; i < len; i++) {
a[i] = a[i] * b[i];
}
fft(a, len, -1);
for (int i = 0; i < len; i++) {
ans.push_back(int(a[i].real() + 0.5));
}
for (int i = 0; i < len - 1; i++) {
ans[i + 1] += ans[i] / 10;
ans[i] %= 10;
}
bool flag = 0;
for (int i = len - 1; i >= 0; i--) {
if (ans[i]) {
printf("%d", ans[i]);
flag = 1;
} else if (flag || i == 0) {
printf("0");
}
}
putchar('\n');
}
return 0;
}