快速傅里叶变换

2017年8月3日

网上关于快速傅里叶变换的内容讲的不少,但<算法导论>讲的东西才是经典啊.

两个 n 次多项式相加的最直接方法所需时间为 \\Theta(n),但相乘的最直接方法所需为 \\Theta(n^2).用快速傅里叶变换(FFT),能将多项式相乘的时间复杂度降低为 \\Theta(nlgn)

多项式的表示

如果一个多项式 A(x) 的最高次的非零系数是 a_k,则称 A(x) 的次数是 k ,记 degree(A) = k.也就是说 k 为多项式中最高次的次数.

系数表示

  对这个多项式:

A(x) = a_{n}x^{n}+a_{n-1}x^{n-1}+\\cdots+a_{1}x^1+a_0

  其系数表达式为 \\{a_0, a_1, \\cdots, a_n\\}

点值表示

  一个 n 次多项式 A(x) 的点值表示就是由 n + 1 个不同的点值所组成的集合

{(x_0, y_0), (x_1, y_1),\\cdots,(x_n, y_n)}

比如,对于一个二次函数,只需要知道三个不同的点就可以知道这个二次函数的图像了.从这里也可以看出,一个多项式可以有很多不同的点值表示.

  <算法导论>上的说法是次数界为 n 的多项式,对于次数界为 n 的多项式,其次数是 0 ~ n – 1的任何整数.而 n 次多项式 是指最高次次数为 n 的多项式

  求值的逆(从一个多项式的点值表达确定其系数表达形式)成为插值.

单位复数根

  n次单位复数根 是满足 \\omega^n=1的复数w.n次单位复数根恰好有n个(公式中的i是值虚数单位,即i=\\sqrt{-1}):

e^{\\frac{2{\\pi}ik}{n}},(k=0,1,\\cdots,n-1)

  由欧拉公式:

e^{ix}=\\cos{x}+i\\sin{x}

  记 \\omega_n=e^{\\frac{2{\\pi}i}{n}},则对应的n个n次单位根可表示为:

\\omega_n^0,\\omega_n^1,\\cdots,\\omega_n^{n-1}

  对多项式的操作,用点值表达很便利.

对于加法

如果 C(x) = A(x) + B(x),则对任意点 x_k,满足 C(x_k) = A(x_k) + B(x_k).就是说如果:

 A 的点值表达为:{(x_0, y_0), (x_1, y_1),\\cdots,(x_n, y_n)}
 B 的点值表达为:{(x_0, y_0^1), (x_1, y_1^1),\\cdots,(x_n, y_n^1)}
则C 的点值表达为:{(x_0,y_0+y_0^1), (x_1,y_1+ y_1^1),\\cdots,(x_n,y_n+ y_n^1)}
 时间复杂度:\\Theta(n)

对于乘法

如果 C(x) = A(x) B(x),则对任意点 x_k,满足 C(x_k) = A(x_k)B(x_k),且有 degree(C)=degree(A)+degree(B)
 同样的C的点值表达为:{(x_0,y_0y_0^1), (x_1,y_1 y_{1}^1),\\cdots,(x_{2n-1},y_{2n-1} y_{2n-1}^1)}

系数表示的多项式快速乘法

FFT 实例

两个大整数相乘:

#include <bits/stdc++.h>
using namespace std;
const double pi = acos(-1.0);
const int N = 50005;
int rev(int id, int len)
{
  int res = 0;
  for (int i = 0; (1 << i) < len; i++) {
    res <<= 1;
    if (id & (1 << i))
      res |= 1;
  }
  return res;
}
void fft(vector<complex<double> > &a, int len, int dft)
{
  vector<complex<double> > tmp;
  tmp.resize(len);
  for (int i = 0; i < len; i++) {
    tmp[rev(i, len)] = a[i];
  }
  for (int s = 1; (1 << s) <= len; s++) {
    int m = (1 << s);
    complex<double> wm = complex<double>(cos(dft * 2 * pi / m), sin(dft * 2 * pi / m));
    for (int k = 0; k < len; k += m) {
      complex<double> w = complex<double>(1, 0);
      for (int j = 0; j < (m >> 1); j++) {
        complex<double> t = w * tmp[k + j + (m >> 1)];
        complex<double> u = tmp[k + j];
        tmp[k + j] = u + t;
        tmp[k + j + (m >> 1)] = u - t;
        w = w * wm;
      }
    }
  }
  if (dft == -1) {
    for (int i = 0; i < len; i++) {
      tmp[i] = complex<double>(tmp[i].real() / len, tmp[i].imag() / len);
    }
  }
  a = tmp;
}
vector<complex<double> > a, b;
vector<int> ans;
char stra[N], strb[N];
int main()
{
  while (~scanf("%s %s", stra, strb)) {
    a.clear();
    b.clear();
    ans.clear();
    int lena = strlen(stra), lenb = strlen(strb);
    int la = 0, lb = 0;
    while ((1 << la) < lena)
      la++;
    while ((1 << lb) < lenb)
      lb++;
    int len = (1 << (max(la, lb) + 1)); 
    for (int i = 0; i < len; i++) {
      if (i < lena) {
        a.push_back(complex<double>(stra[lena - i - 1] - '0', 0));
      }
      if (i < lenb)
        b.push_back(complex<double>(strb[lenb - i - 1] - '0', 0));
    }
    a.resize(len);
    b.resize(len);
    fft(a, len, 1);
    fft(b, len, 1);
    for (int i = 0; i < len; i++) {
      a[i] = a[i] * b[i];
    }
    fft(a, len, -1);
    for (int i = 0; i < len; i++) {
      ans.push_back(int(a[i].real() + 0.5));
    }
    for (int i = 0; i < len - 1; i++) {
      ans[i + 1] += ans[i] / 10;
      ans[i] %= 10;
    }
    bool flag = 0;
    for (int i = len - 1; i >= 0; i--) {
      if (ans[i]) {
        printf("%d", ans[i]);
        flag = 1;
      } else if (flag || i == 0) {
        printf("0");
      }
    }
    putchar('\n');
  }
  return 0;
}

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注