手撕FFT

手撕FFT

多项式简介

算法导论提供了全部理论基础:

先说分治:

我们在相乘时,按照未知项的奇偶性分开:

$A(x) = A^0(x) + A^1(x) $;

B ( x ) = B 0 ( x ) + B 1 ( x ) B(x) = B^0(x) + B^1(x) B(x)=B0(x)+B1(x);

A B = ( A 0 + x A 1 ) ( B 0 + x B 1 ) = A 0 B 0 + x ( A 1 B 0 + A 0 B 1 ) + x 2 A 1 B 1 AB = (A^0 + xA^1)(B^0 + xB^1) = A^0B^0 + x(A^1B^0 + A^0B^1) + x^2A^1B^1 AB=(A0+xA1)(B0+xB1)=A0B0+x(A1B0+A0B1)+x2A1B1;

由上式可得,我们可以通过分治算法把两个多项式折半,再计算四次多项式乘法并相加合并。

但此时 T ( n ) = 4 T ( n / 2 ) + f ( n ) T(n) = 4T(n/2) + f(n) T(n)=4T(n/2)+f(n),所以复杂度仍为 O ( n 2 ) O(n^2) O(n2);

但是 ( a x + b ) ( c x + d ) = a c x 2 + ( a d + b c ) x + b d (ax + b)(cx + d) = acx^2 + (ad + bc)x + bd (ax+b)(cx+d)=acx2+(ad+bc)x+bd,实际上只需要三次乘法就可以,所以我们可以使用这个方法减少一次乘法运算,此时 T ( n ) = 3 T ( n / 2 ) + f ( n ) T(n) = 3T(n/2) + f(n) T(n)=3T(n/2)+f(n);

我们得知多项式可以使用点值表示和插值表示两种形式;

我们使用拉格朗日插值求解方法可以将复杂度优化到 n 2 n^2 n2:

  • 选取 n n n个 x i x^i xi,带入点值,复杂度为 O ( n 2 ) O(n^2) O(n2);
  • 计算点值的卷积,复杂度为 O ( n ) O(n) O(n);
  • 插值计算系数向量,这一步是 O ( n 2 ) O(n^2) O(n2);

我们在此基础上通过选取复数单位根继续优化:

  • 考虑方程 z n = 1 z^n = 1 zn=1,因此在一个三角函数周期上取得n个方程复数根;
  • 相消定理,其实就是周期函数,为了限制右上角次数;
  • 折半定理,n次单位根的平方集合等于n/2次单位根的集合,显然成立,得到结论;
  • 求和引理,就是凑够了就是0;

再说DFT:

DFT就是将次数界为n的多项式A(x)在n次单位复数根上求值的过程;

y = D F T ( a ) y = DFT(a) y=DFT(a)

因此我们使用FFT利用单位根的特殊性质把DFT优化到 O ( n l o g n ) O(nlogn) O(nlogn):

  • 在分治中我们要计算的是 A 0 ( x 2 ) A^0(x^2) A0(x2),根据折半定理 ( ω 0 ) 2 . . . ( ω k ) 2 . . . (\omega^0)^2...(\omega^k)^2... (ω0)2...(ωk)2...,两两重复,所以是n/2个n/2次单位根;
  • 然后合并答案:计算只需 y i = y i 0 + ω i y i 1 , y ( i + n / 2 ) = y i 0 − ω i y i i yi = yi^0 + \omega^iyi^1, y(i + n/2) = yi^0 - \omega^iyi^i yi=yi0+ωiyi1,y(i+n/2)=yi0−ωiyii;
  • T ( n ) = 2 T ( n / 2 ) + f ( n ) , O ( n l o g n ) T(n) = 2T(n/2) + f(n), O(nlogn) T(n)=2T(n/2)+f(n),O(nlogn);

因为按照奇偶性计算,所以使用蝴蝶操作,将所有系数按照位置排列再迭代合并。

位反转排序

cpp 复制代码
for(int i = 0; i < n; i++){
    rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
    if(i < rev[i]){
        swap(A[i], A[rev[i]]);
    }
}
  • 位反转数组 :使用位操作计算rev[i],将索引i的二进制表示进行反转。
  • 交换 :如果i小于rev[i],则交换A[i]A[rev[i]],实现数组的位反转排序。这是FFT算法中的关键步骤,有助于提高计算效率。

例子:位反转排序

假设我们有一个数组的长度为8(n=8n = 8n=8),其索引为0到7。我们的目标是将这些索引进行位反转。

1. 原始索引及其二进制表示
索引:   0   1   2   3   4   5   6   7
二进制: 000 001 010 011 100 101 110 111
2. 位反转过程

对于每个索引,我们将其二进制表示进行反转:

  • 0 -> 000 -> 000 -> 0
  • 1 -> 001 -> 100 -> 4
  • 2 -> 010 -> 010 -> 2
  • 3 -> 011 -> 110 -> 6
  • 4 -> 100 -> 001 -> 1
  • 5 -> 101 -> 101 -> 5
  • 6 -> 110 -> 011 -> 3
  • 7 -> 111 -> 111 -> 7
3. 反转结果

反转后的索引数组是:

索引:   0   4   2   6   1   5   3   7

应用位反转排序的FFT

假设我们有一个复数数组 AAA:

A: [A[0], A[1], A[2], A[3], A[4], A[5], A[6], A[7]]

经过位反转排序后,数组会变为:

A: [A[0], A[4], A[2], A[6], A[1], A[5], A[3], A[7]]

蝶形计算的基本形式

对于输入的两个复数 xxx 和 yyy,蝶形计算可以表示为:

输出 1 = x + ω ⋅ y 输出1=x+ω⋅y 输出1=x+ω⋅y

输出 2 = x − ω ⋅ y 输出2=x−ω⋅y 输出2=x−ω⋅y

其中, ω \omega ω 是旋转因子,通常是一个复数,表示特定的相位旋转,依赖于当前的计算阶段。

内循环进行蝶形运算

cpp 复制代码
for(int i = 0; i < n; i += mid << 1){
  • i循环遍历A,每次跳过mid << 1(即2 * mid),这保证了在进行蝶形运算时不会重叠。

计算蝶形操作

cpp 复制代码
for(int j = 0; j < mid; j++, omega *= temp){
  • 内部循环用于进行蝶形操作,j从0到mid-1,更新omega为当前的旋转因子。
cpp 复制代码
complex<double>x = A[i + j], y = omega * A[i + j + mid];
  • 取出当前需要计算的两个元素,x为前半部分,y为后半部分乘以旋转因子。
cpp 复制代码
A[i + j] = x + y;
A[i + j + mid] = x - y;
  • 更新数组A的值:
    • A[i + j]存储前半部分和后半部分的和(频域的合成)。
    • A[i + j + mid]存储前半部分和后半部分的差(频域的分离)。

函数 invert

cpp 复制代码
int invert(int n){
    int bit = 1;
    while((1 << bit) < n) bit++;
    return (1 << bit);
}
  • 该函数返回大于等于n的最小的2的幂次。
  • 通过位运算计算出2的幂次,确保FFT算法能够处理的长度是2的幂次。

函数 FFT

cpp 复制代码
void FFT(complex<double> *A, int n, int inv){
    int bit = 1;
    while((1 << bit) < n) bit++;
    for(int i = 0; i < n; i++){
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
        if(i < rev[i]){
            swap(A[i], A[rev[i]]);
        }
    }
    
    for(int mid = 1; mid < n; mid <<= 1){
        complex<double> temp(cos(Pi / mid), inv * sin(Pi / mid));
        for(int i = 0; i < n; i += mid << 1){
            complex<double> omega(1, 0);
            for(int j = 0; j < mid; j++, omega *= temp){
                complex<double>x = A[i + j], y = omega * A[i + j + mid];
                A[i + j] = x + y;
                A[i + j + mid] = x - y;
            }
        }
    }
}
  • 参数:

    • A:输入的复数数组。
    • n:数组长度。
    • inv:指示是进行正向FFT还是逆向FFT(1表示正向,-1表示逆向)。
  • 功能:

    1. 计算并存储rev数组,用于位反转。
    2. 使用蝶形操作对复数进行FFT计算。temp是旋转因子,根据当前的mid值计算出。
    3. 通过循环进行合并和计算,最终得到频域结果。
C 复制代码
#include <cstdio>
#include <complex>
using namespace std;
const int N = 1e7 + 1;
const double Pi = acos(-1);
int n, m, rev[N];
complex<double> F[N], G[N], H[N];

int invert(int n){
	int bit = 1;
	while((1 << bit) < n)bit++;
	return (1 << bit);
}

int getint(){
	int x = 0, f = 1; char c = getchar();
	while(c < '0' || c > '9'){
		if(c == '-')f = -1;
		c = getchar();
	}
	while(c >= '0' && c <= '9'){
		x = (x << 1) + (x << 3) + c - '0';
		c = getchar();
	}
	return x * f;
}

void FFT(complex<double> *A, int n, int inv){
	int bit = 1;
	while((1 << bit) < n)bit++;
	for(int i = 0; i < n; i++){
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
		if(i < rev[i]){
			swap(A[i], A[rev[i]]);
		}
	}
	
	for(int mid = 1; mid < n; mid <<= 1){
		complex<double> temp(cos(Pi / mid), inv * sin(Pi / mid));
		for(int i = 0; i < n; i += mid << 1){
			complex<double> omega(1, 0);
			for(int j = 0; j < mid; j++, omega *= temp){
				complex<double>x = A[i + j], y = omega * A[i + j + mid];
				A[i + j] = x + y;
				A[i + j + mid] = x - y;
 			}
		}
	}
}

int main(){
	scanf("%d %d", &n, &m);
	for(int i = 0; i <= n; i++)F[i].real(getint());
	for(int i = 0; i <= m; i++)G[i].real(getint());
	//printf("get done\n");
	FFT(F, invert(n + m), 1);
	FFT(G, invert(n + m), 1);
	
	for(int i = 0; i <= invert(n + m); i++){
		H[i] = F[i] * G[i];
	}
	
	FFT(H, invert(n + m), -1);
	
	for(int i = 0; i <= n + m; i++){
		printf("%d ", (int)(H[i].real() / invert(n + m) + 0.5));
	}
}
相关推荐
jiao_mrswang40 分钟前
leetcode-18-四数之和
算法·leetcode·职场和发展
qystca1 小时前
洛谷 B3637 最长上升子序列 C语言 记忆化搜索->‘正序‘dp
c语言·开发语言·算法
薯条不要番茄酱1 小时前
数据结构-8.Java. 七大排序算法(中篇)
java·开发语言·数据结构·后端·算法·排序算法·intellij-idea
今天吃饺子1 小时前
2024年SCI一区最新改进优化算法——四参数自适应生长优化器,MATLAB代码免费获取...
开发语言·算法·matlab
是阿建吖!1 小时前
【优选算法】二分查找
c++·算法
王燕龙(大卫)1 小时前
leetcode 数组中第k个最大元素
算法·leetcode
不去幼儿园2 小时前
【MARL】深入理解多智能体近端策略优化(MAPPO)算法与调参
人工智能·python·算法·机器学习·强化学习
Mr_Xuhhh2 小时前
重生之我在学环境变量
linux·运维·服务器·前端·chrome·算法
盼海3 小时前
排序算法(五)--归并排序
数据结构·算法·排序算法
网易独家音乐人Mike Zhou6 小时前
【卡尔曼滤波】数据预测Prediction观测器的理论推导及应用 C语言、Python实现(Kalman Filter)
c语言·python·单片机·物联网·算法·嵌入式·iot