Algorithm: Polynomial Multiplication — Fast Fourier Transform / Number-Theoretic Transform (English version)

Intro:

This blog will start with plain multiplication, go through Divide-and-conquer multiplication, and reach FFT and NTT.

The aim is to enable the reader (and myself) to fully understand the idea.

Template problem: Luogu P3803 【模板】多项式乘法(FFT) ("[Template] Polynomial Multiplication (FFT)")


Plain multiplication

Assumption: the two polynomials are $A(x)=\sum_{i=0}^{n}a_ix^i$ and $B(x)=\sum_{i=0}^{m}b_ix^i$.

Prerequisite knowledge:

Knowledge of junior high school mathematics

The simplest method is to multiply term by term and then combine like terms, written as the formula:

If $C(x)=A(x)B(x)$, then $C(x)=\sum_{i=0}^{n+m}c_ix^i$, where $c_i=\sum_{j=0}^{i}a_jb_{i-j}$ (coefficients $a_j$ with $j>n$ and $b_{i-j}$ with $i-j>m$ are taken as $0$).

This gives the plain multiplication algorithm; see the code below (the array $b$ is replaced by a single variable, and some micro-optimization tricks are used).

//This program is written by Brian Peng.
#pragma GCC optimize("Ofast","inline","no-stack-protector")
#include<bits/stdc++.h>
using namespace std;
#define Rd(a) (a=read())
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
// Reads the next decimal integer (optionally preceded by '-') from stdin.
// Any other characters before it are skipped; if EOF is reached before a
// number starts, the whole program terminates with exit(0).
// `c` keeps getchar()'s int result. Storing it in a char made the EOF test
// never fire where char is unsigned, and passed negative values to isdigit()
// (undefined behavior) where char is signed. `register` is gone (ill-formed
// since C++17).
int read(){
	int c(getchar());
	while(!isdigit(c)&&c!='-')
		if((c=getchar())==EOF)exit(0);
	bool const neg(c=='-');
	int val(neg?0:c&15);	// c&15 maps '0'..'9' to 0..9
	while(isdigit(c=getchar()))val=val*10+(c&15);
	return neg?-val:val;
}
// Prints the decimal representation of `a` to stdout (no separator).
// The magnitude is computed as unsigned, so the most negative value prints
// correctly instead of overflowing on `-a`; `register` (ill-formed since
// C++17) is dropped.
void wr(int a){
	unsigned long long mag(a);
	if(a<0)putchar('-'),mag=0ull-mag;
	char digits[24];
	int len(0);
	do digits[len++]=char('0'+mag%10),mag/=10;while(mag);
	while(len)putchar(digits[--len]);
}
// Sentinel constants kept from the author's template (unused here).
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
// Output helpers: Ps prints a space, Pe a newline (the escape was lost, so
// it used to print the letter 'n').
#define Ps Pc(' ')
#define Pe Pc('\n')
// Loops: Frn0 over [a,b), Frn1 over [a,b], Frn_ from a down to b.
// `register` removed: it is ill-formed in C++17.
#define Frn0(i,a,b) for(int i(a);i<(b);++i)
#define Frn1(i,a,b) for(int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(int i(a);i>=(b);--i)
// Bytewise fill of a whole array; redirect stdio to a".in" / a".out".
#define Mst(a,b) memset(a,b,sizeof(a))
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
// Capacity of the coefficient arrays.
#define N (2000010)
// n = deg A, m = deg B, a = coefficients of A, b = the current coefficient
// of B, c = coefficients of A*B.
int n,m,a[N],b,c[N];
signed main(){
	Rd(n),Rd(m);
	Frn1(i,0,n)Rd(a[i]);
	Frn1(i,0,m){Rd(b);Frn1(j,0,n)c[i+j]+=b*a[j];}
	Frn1(i,0,n+m)wr(c[i]),Ps;
	exit(0);
}

Time complexity: $O(nm)$ (if $m=O(n)$, then $O(n^2)$)

Memory complexity: $O(n+m)$

Results:

Algorithm: Polynomial Multiplication -- Fast Fourier Transform / Number-Theoretic Transform (English version)

Expected, so we need to optimize it.


Divide-and-conquer multiplication (Fake)

P.S. This part describes the divide-and-conquer idea behind FFT, which is still different from the actual FFT, so you can skip it if you have already mastered divide-and-conquer.

Let $n$ be the smallest positive integer power of $2$ that is strictly greater than the degrees of both $A(x)$ and $B(x)$, and write $A(x)=\sum_{i=0}^{n-1}a_ix^i$, $B(x)=\sum_{i=0}^{n-1}b_ix^i$, where the nonexistent coefficients are set to $0$.

Prerequisite knowledge:

  The idea of Divide-and-conquer

Now consider how to optimize multiplication.

Try to separate two polynomials according to the parity of the index of (x)

$A(x)=A^{[0]}(x^2)+xA^{[1]}(x^2),\quad B(x)=B^{[0]}(x^2)+xB^{[1]}(x^2)$,

where $A^{[0]}(x)=\sum_{i=0}^{n/2-1}a_{2i}x^i$, $A^{[1]}(x)=\sum_{i=0}^{n/2-1}a_{2i+1}x^i$, and $B^{[0]}(x)$ and $B^{[1]}(x)$ are defined similarly.

Therefore, the two polynomials are split into four polynomials, each of degree $<n/2$.

We write $A=A(x)$, $A^{[0]}=A^{[0]}(x^2)$, $A^{[1]}=A^{[1]}(x^2)$, and similarly for $B$ and the others;

then $AB=(A^{[0]}+xA^{[1]})(B^{[0]}+xB^{[1]})=A^{[0]}B^{[0]}+x(A^{[1]}B^{[0]}+A^{[0]}B^{[1]})+x^2A^{[1]}B^{[1]}$.

This suggests a divide-and-conquer algorithm: split each polynomial in half, recursively perform $4$ polynomial multiplications, and finally combine the results (polynomial addition is $O(n)$ anyway).

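P.S. Since $A^{[0]}=A^{[0]}(x^2)$ and $A^{[1]}=A^{[1]}(x^2)$ are polynomials in $x^2$, the combined result fills the even and odd coefficients alternately: $A^{[0]}B^{[0]}$ and $x^2A^{[1]}B^{[1]}$ give the even ones, the middle term gives the odd ones. Here is the code. (In the code, the $n$ above is stored in the variable `s`, and `vector` is used to save memory.)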

//This program is written by Brian Peng.
#pragma GCC optimize("Ofast","inline","no-stack-protector")
#include<bits/stdc++.h>
using namespace std;
#define Rd(a) (a=read())
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
// Reads the next decimal integer (optionally preceded by '-') from stdin.
// Any other characters before it are skipped; if EOF is reached before a
// number starts, the whole program terminates with exit(0).
// `c` keeps getchar()'s int result. Storing it in a char made the EOF test
// never fire where char is unsigned, and passed negative values to isdigit()
// (undefined behavior) where char is signed. `register` is gone (ill-formed
// since C++17).
int read(){
	int c(getchar());
	while(!isdigit(c)&&c!='-')
		if((c=getchar())==EOF)exit(0);
	bool const neg(c=='-');
	int val(neg?0:c&15);	// c&15 maps '0'..'9' to 0..9
	while(isdigit(c=getchar()))val=val*10+(c&15);
	return neg?-val:val;
}
// Prints the decimal representation of `a` to stdout (no separator).
// The magnitude is computed as unsigned, so the most negative value prints
// correctly instead of overflowing on `-a`; `register` (ill-formed since
// C++17) is dropped.
void wr(int a){
	unsigned long long mag(a);
	if(a<0)putchar('-'),mag=0ull-mag;
	char digits[24];
	int len(0);
	do digits[len++]=char('0'+mag%10),mag/=10;while(mag);
	while(len)putchar(digits[--len]);
}
// Sentinel constants kept from the author's template (unused here).
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
// Output helpers: Ps prints a space, Pe a newline (the escape was lost, so
// it used to print the letter 'n').
#define Ps Pc(' ')
#define Pe Pc('\n')
// Loops: Frn0 over [a,b), Frn1 over [a,b], Frn_ from a down to b.
// `register` removed: it is ill-formed in C++17.
#define Frn0(i,a,b) for(int i(a);i<(b);++i)
#define Frn1(i,a,b) for(int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(int i(a);i>=(b);--i)
// Bytewise fill of a whole array; redirect stdio to a".in" / a".out".
#define Mst(a,b) memset(a,b,sizeof(a))
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
typedef std::vector<int> Vct;
// n = deg A, m = deg B, s = padded (power-of-two) size of a and b.
int n,m,s;
Vct a,b,c;
// Element-wise sum: c[i] = a[i] + b[i] for every index of c. a and b must
// be at least as long as c; c may be the same vector as a or b.
void add(Vct&a,Vct&b,Vct&c){
	for(std::size_t i=0;i!=c.size();++i)
		c[i]=a[i]+b[i];
}
void mlt(Vct&a,Vct&b,Vct&c,int n);
// Reads deg A = n, deg B = m and their coefficients, multiplies them with
// the 4-way divide-and-conquer mlt(), and prints the n+m+1 coefficients.
signed main(){
	Rd(n),Rd(m);
	// s = smallest power of two strictly greater than max(n,m). Integer loop
	// instead of 1<<int(log2(max(n,m))+1): log2(0) is -inf, and converting
	// that to int is undefined behavior.
	s=1;
	while(s<=max(n,m))s<<=1;
	a.resize(s),b.resize(s),c.resize(s<<1);
	Frn1(i,0,n)Rd(a[i]);
	Frn1(i,0,m)Rd(b[i]);
	mlt(a,b,c,s);
	Frn1(i,0,n+m)wr(c[i]),Ps;
	exit(0);
}
// "Fake" divide-and-conquer product: splits both size-n polynomials into
// even/odd-indexed halves and does 4 half-size products,
// T(n) = 4T(n/2) + O(n) = O(n^2).
// Preconditions: n is 0 or a power of two, a and b hold at least n
// coefficients, and c has at least 2n-1 slots. Writes c[0..2n-2].
// (std::vector<int> is the same type as the file's Vct.)
void mlt(std::vector<int>&a,std::vector<int>&b,std::vector<int>&c,int n){
	// Base cases come before any allocation, because leaves are by far the
	// most frequent calls. n == 0 used to recurse forever.
	if(n<=1){
		if(n==1)c[0]=a[0]*b[0];
		return;
	}
	int const n2(n>>1);
	std::vector<int> a0(n2),a1(n2),b0(n2),b1(n2),ab0(n),ab1(n);
	for(int i=0;i<n2;++i)
		a0[i]=a[i<<1],a1[i]=a[i<<1|1],b0[i]=b[i<<1],b1[i]=b[i<<1|1];
	// Even coefficients: A0*B0 + x^2*A1*B1 (as polynomials in x^2).
	// A half-size product fills only n-1 slots, so ab0[n-1] stays 0.
	mlt(a0,b0,ab0,n2),mlt(a1,b1,ab1,n2);
	for(int i=0;i<n;++i)c[i<<1]=ab0[i]+(i?ab1[i-1]:0);
	// Odd coefficients: A1*B0 + A0*B1 (the buffers are reused).
	mlt(a0,b1,ab0,n2),mlt(a1,b0,ab1,n2);
	for(int i=0;i<n-1;++i)c[i<<1|1]=ab0[i]+ab1[i];
}

Results:

Algorithm: Polynomial Multiplication -- Fast Fourier Transform / Number-Theoretic Transform (English version)

even worse

Why is that? Because the time complexity is still $O(n^2)$.

*Proof.* $T(n)=4T(n/2)+f(n)$, where $f(n)=O(n)$ is the time for polynomial addition.

Using the Master Theorem with $a=4$, $b=2$, $\log_ba=\log_2 4=2>1$, we have $T(n)=O(n^{\log_ba})=O(n^2)$.

So, let's continue optimizing


Divide-and-conquer multiplication (Real)

Let's consider how to optimize the "fake" one.

A warm-up question: find an algorithm that multiplies the linear expressions $ax+b$ and $cx+d$ using only $3$ multiplications.

Expanding the product gives $(ax+b)(cx+d)=acx^2+(ad+bc)x+bd$, which seems to need $4$ multiplications.

Hence, if only $3$ multiplications are allowed, $ad+bc$ must cost only one.

Add all the coefficients together: $ac+ad+bc+bd=(a+b)(c+d)$,

and here is the answer! Use $3$ multiplications to compute $ac$, $bd$ and $(a+b)(c+d)$; the coefficient of $x$ is then just $ad+bc=(a+b)(c+d)-ac-bd$.

Let's go back to the original question

Since $AB=(A^{[0]}+xA^{[1]})(B^{[0]}+xB^{[1]})=A^{[0]}B^{[0]}+x(A^{[1]}B^{[0]}+A^{[0]}B^{[1]})+x^2A^{[1]}B^{[1]}$,

we can use the same trick to save one multiplication: $A^{[1]}B^{[0]}+A^{[0]}B^{[1]}=(A^{[0]}+A^{[1]})(B^{[0]}+B^{[1]})-A^{[0]}B^{[0]}-A^{[1]}B^{[1]}$.

Here is the code:

//This program is written by Brian Peng.
#pragma GCC optimize("Ofast","inline","no-stack-protector")
#include<bits/stdc++.h>
using namespace std;
#define Rd(a) (a=read())
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
// Reads the next decimal integer (optionally preceded by '-') from stdin.
// Any other characters before it are skipped; if EOF is reached before a
// number starts, the whole program terminates with exit(0).
// `c` keeps getchar()'s int result. Storing it in a char made the EOF test
// never fire where char is unsigned, and passed negative values to isdigit()
// (undefined behavior) where char is signed. `register` is gone (ill-formed
// since C++17).
int read(){
	int c(getchar());
	while(!isdigit(c)&&c!='-')
		if((c=getchar())==EOF)exit(0);
	bool const neg(c=='-');
	int val(neg?0:c&15);	// c&15 maps '0'..'9' to 0..9
	while(isdigit(c=getchar()))val=val*10+(c&15);
	return neg?-val:val;
}
// Prints the decimal representation of `a` to stdout (no separator).
// The magnitude is computed as unsigned, so the most negative value prints
// correctly instead of overflowing on `-a`; `register` (ill-formed since
// C++17) is dropped.
void wr(int a){
	unsigned long long mag(a);
	if(a<0)putchar('-'),mag=0ull-mag;
	char digits[24];
	int len(0);
	do digits[len++]=char('0'+mag%10),mag/=10;while(mag);
	while(len)putchar(digits[--len]);
}
// Sentinel constants kept from the author's template (unused here).
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
// Output helpers: Ps prints a space, Pe a newline (the escape was lost, so
// it used to print the letter 'n').
#define Ps Pc(' ')
#define Pe Pc('\n')
// Loops: Frn0 over [a,b), Frn1 over [a,b], Frn_ from a down to b.
// `register` removed: it is ill-formed in C++17.
#define Frn0(i,a,b) for(int i(a);i<(b);++i)
#define Frn1(i,a,b) for(int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(int i(a);i>=(b);--i)
// Bytewise fill of a whole array; redirect stdio to a".in" / a".out".
#define Mst(a,b) memset(a,b,sizeof(a))
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
typedef std::vector<int> Vct;
// n = deg A, m = deg B, s = padded (power-of-two) size of a and b.
int n,m,s;
Vct a,b,c;
// Element-wise sum: c[i] = a[i] + b[i] for every index of c. a and b must
// be at least as long as c; c may be the same vector as a or b.
void add(Vct&a,Vct&b,Vct&c){
	for(std::size_t i=0;i!=c.size();++i)
		c[i]=a[i]+b[i];
}
// Element-wise difference: c[i] = a[i] - b[i], same size rules as add().
void mns(Vct&a,Vct&b,Vct&c){
	for(std::size_t i=0;i!=c.size();++i)
		c[i]=a[i]-b[i];
}
void mlt(Vct&a,Vct&b,Vct&c);
// Reads deg A = n, deg B = m and their coefficients, multiplies them with
// the Karatsuba-style mlt(), and prints the n+m+1 coefficients.
signed main(){
	Rd(n),Rd(m);
	// s = smallest power of two strictly greater than max(n,m). Integer loop
	// instead of 1<<int(log2(max(n,m))+1): log2(0) is -inf, and converting
	// that to int is undefined behavior.
	s=1;
	while(s<=max(n,m))s<<=1;
	a.resize(s),b.resize(s),c.resize(s<<1);
	Frn1(i,0,n)Rd(a[i]);
	Frn1(i,0,m)Rd(b[i]);
	mlt(a,b,c);
	Frn1(i,0,n+m)wr(c[i]),Ps;
	exit(0);
}
// Karatsuba-style divide-and-conquer product of two polynomials of equal
// size n = a.size() (0 or a power of two): 3 half-size products per level,
// T(n) = 3T(n/2) + O(n) = O(n^{log2 3}).
// c must have at least 2n-1 slots; c[0..2n-2] receives the product.
// (std::vector<int> is the same type as the file's Vct.)
void mlt(std::vector<int>&a,std::vector<int>&b,std::vector<int>&c){
	int const n(a.size());
	// Base cases come before any allocation, because leaves are by far the
	// most frequent calls. An empty input used to recurse forever.
	if(n<=1){
		if(n==1)c[0]=a[0]*b[0];
		return;
	}
	int const n2(n>>1);
	std::vector<int> a0(n2),a1(n2),b0(n2),b1(n2),ab0(n),ab1(n),abm(n);
	for(int i=0;i<n2;++i)
		a0[i]=a[i<<1],a1[i]=a[i<<1|1],b0[i]=b[i<<1],b1[i]=b[i<<1|1];
	// Even coefficients: A0*B0 + x^2*A1*B1 (ab0[n-1] stays 0).
	mlt(a0,b0,ab0),mlt(a1,b1,ab1);
	for(int i=0;i<n;++i)c[i<<1]=ab0[i]+(i?ab1[i-1]:0);
	// Odd coefficients: A0*B1 + A1*B0 = (A0+A1)(B0+B1) - A0*B0 - A1*B1.
	for(int i=0;i<n2;++i)a0[i]+=a1[i],b0[i]+=b1[i];
	mlt(a0,b0,abm);
	for(int i=0;i<n-1;++i)c[i<<1|1]=abm[i]-ab0[i]-ab1[i];
}

Results

Algorithm: Polynomial Multiplication -- Fast Fourier Transform / Number-Theoretic Transform (English version)

Better than fake DC multiplication, but even worse than plain multiplication...

Let's calculate the time complexity of this algorithm:

$T(n)=3T(n/2)+f(n)$, where $f(n)=O(n)$.

Using the Master Theorem with $a=3$, $b=2$, $\log_ba=\log_2 3\approx1.58>1$, we get $T(n)=O(n^{\log_ba})=O(n^{\log_2 3})$.

Hmm...so why is it even worse than plain multiplication?

Reason 1. The constant factor of DC multiplication is too high.

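Reason 2. In test case #5, we have $n=1$ and $m=3\cdot 10^6$. The divide-and-conquer code pads both polynomials to a power of two above $\max(n,m)$, so it costs about $\max(n,m)^{\log_2 3}$, which is far worse than the $O(nm)$ of plain multiplication here...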

So, our FFT is eventually coming!


Fast Fourier Transform

Fairly Frightening Transform

Let $n$ be the smallest positive integer power of $2$ greater than $\deg A(x)+\deg B(x)$, and write $A(x)=\sum_{i=0}^{n-1}a_ix^i$, $B(x)=\sum_{i=0}^{n-1}b_ix^i$ (padding with zero coefficients).

Prerequisite knowledge:

  The idea of Divide-and-conquer

  Complex number basics

Linear algebra basics (not strictly required)

Part 1: Two representations of a polynomial

1. Coefficient expressions

For a polynomial $A(x)=\sum_{i=0}^{n-1}a_ix^i$, its coefficient expression is the vector $\boldsymbol{a}=\begin{bmatrix}a_0\\a_1\\\vdots\\a_{n-1}\end{bmatrix}$.

In coefficient expressions, the time complexities of the following operations are:

  1. Evaluation at a point: $O(n)$

  2. Addition: $O(n)$

  3. Multiplication: plain $O(n^2)$, divide-and-conquer $O(n^{\log_2 3})$

P.S. For a polynomial product $C(x)=A(x)B(x)$, the corresponding coefficient expression $\boldsymbol{c}$ is called the convolution of $\boldsymbol{a}$ and $\boldsymbol{b}$, written $\boldsymbol{c}=\boldsymbol{a}\otimes\boldsymbol{b}$.

2. Point-valued expressions

The point-valued expression of a polynomial $A(x)$ with $\deg A<n$ is a set of $n$ points $\{(x_0,y_0),(x_1,y_1),\cdots,(x_{n-1},y_{n-1})\}$ with pairwise distinct $x_i$ and $y_i=A(x_i)$.

We can convert a coefficient expression to a point-valued expression at the points $(x_0,x_1,\cdots,x_{n-1})$ with $n$ evaluations in $O(n^2)$ time, as shown:

$$\begin{bmatrix}1&x_0&x_0^2&\cdots&x_0^{n-1}\\1&x_1&x_1^2&\cdots&x_1^{n-1}\\\vdots&\vdots&\vdots&\ddots&\vdots\\1&x_{n-1}&x_{n-1}^2&\cdots&x_{n-1}^{n-1}\end{bmatrix}\begin{bmatrix}a_0\\a_1\\\vdots\\a_{n-1}\end{bmatrix}=\begin{bmatrix}y_0\\y_1\\\vdots\\y_{n-1}\end{bmatrix}$$

The matrix is written $V(x_0,x_1,\cdots,x_{n-1})$ and is called a Vandermonde matrix, so the formula simplifies to $V(x_0,x_1,\cdots,x_{n-1})\boldsymbol{a}=\boldsymbol{y}$.

Using the Lagrange interpolation formula, a point-valued expression can be converted back into a coefficient expression in $O(n^2)$ time; this process is called interpolation.

Given two polynomials in point-valued expressions over the same points $(x_0,\cdots,x_{n-1})$, the time complexities of the following operations are:

  1. Addition: $O(n)$ (add the $y_i$ values pointwise)

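  2. Multiplication: $O(n)$ (multiply the $y_i$ values pointwise; to represent the product, the number of points must exceed $\deg A+\deg B$)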

This is the central idea of FFT-based polynomial multiplication: with carefully chosen $x_i$ values, we can evaluate in $O(n\log n)$, multiply in $O(n)$, and finally interpolate in $O(n\log n)$.

So what are those $x_i$ values?

Part 2: Complex roots of unity

The $n$-th roots of unity are the $n$ complex numbers $\omega$ satisfying $\omega^n=1$, given by:

$\omega_n^k=e^{2\pi ik/n}=\cos(2\pi k/n)+i\sin(2\pi k/n)$, for $k=0,1,\cdots,n-1$.

We can plot the $n$-th roots of unity as the $n$ vertices of a regular $n$-gon inscribed in the unit circle of the complex plane. For example, the following figure shows the $8$-th roots of unity.

Algorithm: Polynomial Multiplication -- Fast Fourier Transform / Number-Theoretic Transform (English version)

There is a pattern: $\omega_n^j\omega_n^k=\omega_n^{j+k}=\omega_n^{(j+k)\bmod n}$. In particular, $\omega_n^{-1}=\omega_n^{n-1}$.

Three other important lemmas.

**Lemma 1.** For all integers $n>0$, $k\geqslant 0$, $d>0$, we have $\omega_{dn}^{dk}=\omega_n^k$.

*Proof.* $\omega_{dn}^{dk}=(e^{2\pi i/(dn)})^{dk}=(e^{2\pi i/n})^k=\omega_n^k.\ \square$

**Lemma 2.** For every even $n>0$ and every integer $k$, we have $(\omega_n^k)^2=(\omega_n^{k+n/2})^2=\omega_{n/2}^k$.

*Proof.* $(\omega_n^k)^2=\omega_n^{2k}$ and $(\omega_n^{k+n/2})^2=\omega_n^{2k+n}=\omega_n^{2k}$. Lastly, $\omega_n^{2k}=\omega_{n/2}^k$ by Lemma 1. $\square$

**Lemma 3.** For every integer $n\geqslant 1$ and every integer $k$ (possibly negative) such that $n\nmid k$, we have $\sum_{j=0}^{n-1}(\omega_n^k)^j=0$.

*Proof.* When $n\nmid k$, we have $\omega_n^k\neq 1$, so $\sum_{j=0}^{n-1}(\omega_n^k)^j=\frac{1-(\omega_n^k)^n}{1-\omega_n^k}=\frac{1-\omega_n^{nk}}{1-\omega_n^k}=\frac{1-1}{1-\omega_n^k}=0.\ \square$ (Question: why is $n\nmid k$ necessary?)

The above properties of roots of unity are the essence of FFT optimization.

Part 3: Discrete Fourier Transform

Recall that $n$ is a power of $2$. The DFT is just the evaluation of the coefficient-expressed $A(x)$ at the $n$-th roots of unity. We write the Vandermonde matrix as

$$V_n=V(\omega_n^0,\omega_n^1,\cdots,\omega_n^{n-1})=\begin{bmatrix}1&1&1&1&\cdots&1\\1&\omega_n&\omega_n^2&\omega_n^3&\cdots&\omega_n^{n-1}\\1&\omega_n^2&\omega_n^4&\omega_n^6&\cdots&\omega_n^{2(n-1)}\\1&\omega_n^3&\omega_n^6&\omega_n^9&\cdots&\omega_n^{3(n-1)}\\\vdots&\vdots&\vdots&\vdots&\ddots&\vdots\\1&\omega_n^{n-1}&\omega_n^{2(n-1)}&\omega_n^{3(n-1)}&\cdots&\omega_n^{(n-1)(n-1)}\end{bmatrix},$$

then the DFT $\boldsymbol{y}=\mathrm{DFT}_n(\boldsymbol a)$ is given by $V_n\boldsymbol{a}=\boldsymbol{y}$. Explicitly, $y_i=\sum_{j=0}^{n-1}[V_n]_{ij}a_j=\sum_{j=0}^{n-1}\omega_n^{ij}a_j$.

So, how can we compute it in $O(n\log n)$?

Part 4: FFT

As in divide-and-conquer multiplication, we split the polynomial by parity: $A(x)=A^{[0]}(x^2)+xA^{[1]}(x^2)$, where $A^{[0]}(x)=\sum_{i=0}^{n/2-1}a_{2i}x^i$ and $A^{[1]}(x)=\sum_{i=0}^{n/2-1}a_{2i+1}x^i$.

Then evaluating $A(x)$ at $\omega_n^0,\omega_n^1,\cdots,\omega_n^{n-1}$ becomes:

1. Divide and conquer: evaluate $A^{[0]}(x)$ and $A^{[1]}(x)$ at $(\omega_n^0)^2,(\omega_n^1)^2,\cdots,(\omega_n^{n-1})^2$.

By Lemma 2, the list $(\omega_n^0)^2,(\omega_n^1)^2,\cdots,(\omega_n^{n-1})^2$ is exactly the list of $(n/2)$-th roots of unity, repeated twice (why?).

So we can recursively compute $\mathrm{DFT}_{n/2}(\boldsymbol a^{[0]})=\boldsymbol y^{[0]}$ and $\mathrm{DFT}_{n/2}(\boldsymbol a^{[1]})=\boldsymbol y^{[1]}$. The second step is

2. Combining the answers.

Since $\omega_n^{n/2}=e^{2\pi i (n/2)/n}=e^{\pi i}=-1$ (the beautiful Euler's formula!),

we have $\omega_n^{k+n/2}=\omega_n^k\omega_n^{n/2}=-\omega_n^k$,

so $y_i=y^{[0]}_i+\omega_n^i y^{[1]}_i$ and $y_{i+n/2}=y^{[0]}_i-\omega_n^i y^{[1]}_i$ for all $i=0,1,\cdots,n/2-1$.

In the trivial case $n=1$, we simply have $y_0=\omega_1^0 a_0=a_0$.

Let's calculate the time complexity

$T(n)=2T(n/2)+f(n)$, where $f(n)=O(n)$ is the time used for combining.

Using the Master Theorem with $a=2$, $b=2$, $\log_ba=\log_2 2=1$, we have $T(n)=O(n^{\log_ba}\log n)=O(n\log n)$. Whooo!

Part 5: Inverse DFT

Don't celebrate too soon, there is still interpolation. Awww

Since $\boldsymbol{y}=\mathrm{DFT}_n(\boldsymbol{a})=V_n\boldsymbol{a}$, we have $\boldsymbol{a}=V_n^{-1}\boldsymbol{y}$, written $\boldsymbol{a}=\mathrm{DFT}_n^{-1}(\boldsymbol{y})$.

**Theorem.** For all $i,j=0,1,\cdots,n-1$, we have $[V_n^{-1}]_{ij}=\omega_n^{-ij}/n$.

*Proof.* We show that $V_n^{-1}V_n=I_n$, the identity matrix:

$$[V_n^{-1}V_n]_{ij}=\sum_{k=0}^{n-1}(\omega_n^{-ik}/n)\,\omega_n^{kj}=\frac{\sum_{k=0}^{n-1}\omega_n^{-ik}\omega_n^{kj}}{n}=\frac{\sum_{k=0}^{n-1}\omega_n^{(j-i)k}}{n}$$

If $i=j$, this equals $\frac{\sum_{k=0}^{n-1}\omega_n^0}{n}=n/n=1$. Otherwise $0<|j-i|<n$, so $n\nmid(j-i)$ and it equals $0/n=0$ by Lemma 3. Therefore the product is $I_n$. $\square$

Next, $\boldsymbol{a}=\mathrm{DFT}_n^{-1}(\boldsymbol{y})=V_n^{-1}\boldsymbol{y}$, in which $a_i=\sum_{j=0}^{n-1}[V_n^{-1}]_{ij}y_j=\sum_{j=0}^{n-1}(\omega_n^{-ij}/n)y_j=\frac{\sum_{j=0}^{n-1}\omega_n^{-ij}y_j}{n}$.

Compare this with the DFT: $y_i=\sum_{j=0}^{n-1}\omega_n^{ij}a_j$.

Therefore, we can turn the DFT into the IDFT by simply replacing $\omega_n^k$ with $\omega_n^{-k}$ and dividing the final results by $n$.

Part 6: Recursive Implementation

According to the previous text, we just need to modify the code of DC multiplication.

To save memory, we move the coefficients of $A^{[0]}$ to the left half of the array and those of $A^{[1]}$ to the right half.

In the code, `o` $=\omega_n$ and `w` $=\omega_n^i$.

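P.S. Don't forget to divide by $n$ for the IDFT. In the code, $0.5$ is added to each real part before it is converted to an integer. This rounds the slightly inaccurate floating-point results to the nearest integer (for non-negative results) instead of truncating them.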

//This program is written by Brian Peng.
#pragma GCC optimize("Ofast","inline","no-stack-protector")
#include<bits/stdc++.h>
using namespace std;
#define Rd(a) (a=read())
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
// Reads the next decimal integer (optionally preceded by '-') from stdin.
// Any other characters before it are skipped; if EOF is reached before a
// number starts, the whole program terminates with exit(0).
// `c` keeps getchar()'s int result. Storing it in a char made the EOF test
// never fire where char is unsigned, and passed negative values to isdigit()
// (undefined behavior) where char is signed. `register` is gone (ill-formed
// since C++17).
int read(){
	int c(getchar());
	while(!isdigit(c)&&c!='-')
		if((c=getchar())==EOF)exit(0);
	bool const neg(c=='-');
	int val(neg?0:c&15);	// c&15 maps '0'..'9' to 0..9
	while(isdigit(c=getchar()))val=val*10+(c&15);
	return neg?-val:val;
}
// Prints the decimal representation of `a` to stdout (no separator).
// The magnitude is computed as unsigned, so the most negative value prints
// correctly instead of overflowing on `-a`; `register` (ill-formed since
// C++17) is dropped.
void wr(int a){
	unsigned long long mag(a);
	if(a<0)putchar('-'),mag=0ull-mag;
	char digits[24];
	int len(0);
	do digits[len++]=char('0'+mag%10),mag/=10;while(mag);
	while(len)putchar(digits[--len]);
}
// Sentinel constants kept from the author's template (unused here).
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
// Output helpers: Ps prints a space, Pe a newline (the escape was lost, so
// it used to print the letter 'n').
#define Ps Pc(' ')
#define Pe Pc('\n')
// Loops: Frn0 over [a,b), Frn1 over [a,b], Frn_ from a down to b.
// `register` removed: it is ill-formed in C++17.
#define Frn0(i,a,b) for(int i(a);i<(b);++i)
#define Frn1(i,a,b) for(int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(int i(a);i>=(b);--i)
// Bytewise fill of a whole array; redirect stdio to a".in" / a".out".
#define Mst(a,b) memset(a,b,sizeof(a))
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
double const Pi(std::acos(-1.0));
typedef std::complex<double> Cpx;
#define N (2100000)
// Transform arrays, the scratch buffer used by the recursive fft(), and
// the butterfly temporaries o, w, x, y.
Cpx o,w,a[N],b[N],tmp[N],x,y;
int n,m,s;
// Direction flag read by fft(): false = forward, true = inverse.
bool iv;
void fft(Cpx*a,int n);
// Recursive-FFT driver: reads deg A = n, deg B = m and the coefficients,
// transforms both, multiplies pointwise, transforms back and prints the
// n+m+1 product coefficients.
signed main(){
	Rd(n),Rd(m);
	// s = smallest power of two greater than n+m = deg(A*B). Integer loop
	// instead of 1<<int(log2(n+m)+1): log2(0) is -inf, and converting that
	// to int is undefined behavior.
	s=1;
	while(s<=n+m)s<<=1;
	Frn1(i,0,n)Rd(a[i]);
	Frn1(i,0,m)Rd(b[i]);
	fft(a,s),fft(b,s);
	Frn0(i,0,s)a[i]*=b[i];
	iv=1,fft(a,s);
	// The inverse transform still needs the division by s. floor(x+0.5)
	// rounds to nearest for negative coefficients too; truncating x+0.5
	// would be off by one for them.
	Frn1(i,0,n+m)wr(int(floor(a[i].real()/s+0.5))),Ps;
	exit(0);
}
// Recursive radix-2 FFT, in place over a[0..n-1] (n a power of two).
// The global `iv` selects the inverse direction (conjugated roots); the
// division by n is left to the caller. Uses the globals tmp, o, w, x, y
// and Pi.
void fft(Cpx*a,int n){
	if(n==1)return;
	int const half(n/2);
	// Even-indexed coefficients to the left half, odd ones to the right.
	for(int i=0;i<half;++i){
		tmp[i]=a[2*i];
		tmp[half+i]=a[2*i+1];
	}
	copy(tmp,tmp+n,a);
	fft(a,half);
	fft(a+half,half);
	// o = e^{±2πi/n}, the principal n-th root of unity.
	o=Cpx(cos(Pi/half),(iv?-1:1)*sin(Pi/half));
	w=1;
	for(int i=0;i<half;++i){
		x=a[i];
		y=w*a[i+half];
		a[i]=x+y;
		a[i+half]=x-y;
		w*=o;
	}
}

Time complexity: (O(nlog n))

Memory complexity: (O(n))

Results:

Algorithm: Polynomial Multiplication -- Fast Fourier Transform / Number-Theoretic Transform (English version)

Not fully AC, as recursive implementation is not fast enough.

Part 7: Iterative Implementation

For $n=\deg A$ and $m=\deg B$ (the values the code reads), let $l=\lceil\log_2(n+m+1)\rceil$ and $s=2^l$; then $s$ is the "$n$" of the previous parts.

Similarly, we redistribute the coefficients of (A^{[0]}) to the left and (A^{[1]}) to the right.

Observe the pattern of redistribution in each layer of recursion. Take (s=8) as an example:

0-> 0 1 2 3 4 5 6 7
1-> 0 2 4 6|1 3 5 7
2-> 0 4|2 6|1 5|3 7
end 0|4|2|6|1|5|3|7

Still confused? Write them in base-2:

0-> 000 001 010 011 100 101 110 111
1-> 000 010 100 110|001 011 101 111
2-> 000 100|010 110|001 101|011 111
end 000|100|010|110|001|101|011|111

In the last layer, each position holds the index whose base-2 representation is the reverse of the position's own!

A hint for the proof: each split is based on parity, i.e. on the lowest binary digit of the index. Elements whose lowest digit is $0$ go to the left half, so that digit becomes the highest digit of the new position.

In the code, the array $r_{0..s-1}$ stores the bit-reversed indices.

Butterfly Operation

It is already written in the code of recursive implementation, but let's clarify that:

Still remember $y_i=y^{[0]}_i+\omega_n^i y^{[1]}_i$ and $y_{i+n/2}=y^{[0]}_i-\omega_n^i y^{[1]}_i$, for $i=0,1,\cdots,n/2-1$?

To save memory, we do not create the array $\boldsymbol y$; the combination is done in place, in the array $\boldsymbol a$.

After the rearrangement, we have $a^{[0]}_i=a_i$ and $a^{[1]}_i=a_{i+n/2}$.

Let $x=a^{[0]}_i=a_i$ and $y=\omega_n^i a^{[1]}_i=\omega_n^i a_{i+n/2}$;

then the results of the DFT are simply $a_i=x+y$ and $a_{i+n/2}=x-y$!

With the butterfly operation, we just need to rearrange the coefficients according to $r$ and then combine them iteratively to implement the FFT.

//This program is written by Brian Peng.
#pragma GCC optimize("Ofast","inline","no-stack-protector")
#include<bits/stdc++.h>
using namespace std;
#define Rd(a) (a=read())
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
// Reads the next decimal integer (optionally preceded by '-') from stdin.
// Any other characters before it are skipped; if EOF is reached before a
// number starts, the whole program terminates with exit(0).
// `c` keeps getchar()'s int result. Storing it in a char made the EOF test
// never fire where char is unsigned, and passed negative values to isdigit()
// (undefined behavior) where char is signed. `register` is gone (ill-formed
// since C++17).
int read(){
	int c(getchar());
	while(!isdigit(c)&&c!='-')
		if((c=getchar())==EOF)exit(0);
	bool const neg(c=='-');
	int val(neg?0:c&15);	// c&15 maps '0'..'9' to 0..9
	while(isdigit(c=getchar()))val=val*10+(c&15);
	return neg?-val:val;
}
// Prints the decimal representation of `a` to stdout (no separator).
// The magnitude is computed as unsigned, so the most negative value prints
// correctly instead of overflowing on `-a`; `register` (ill-formed since
// C++17) is dropped.
void wr(int a){
	unsigned long long mag(a);
	if(a<0)putchar('-'),mag=0ull-mag;
	char digits[24];
	int len(0);
	do digits[len++]=char('0'+mag%10),mag/=10;while(mag);
	while(len)putchar(digits[--len]);
}
// Sentinel constants kept from the author's template (unused here).
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
// Output helpers: Ps prints a space, Pe a newline (the escape was lost, so
// it used to print the letter 'n').
#define Ps Pc(' ')
#define Pe Pc('\n')
// Loops: Frn0 over [a,b), Frn1 over [a,b], Frn_ from a down to b.
// `register` removed: it is ill-formed in C++17.
#define Frn0(i,a,b) for(int i(a);i<(b);++i)
#define Frn1(i,a,b) for(int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(int i(a);i>=(b);--i)
// Bytewise fill of a whole array; redirect stdio to a".in" / a".out".
#define Mst(a,b) memset(a,b,sizeof(a))
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
double const Pi(std::acos(-1.0));
typedef std::complex<double> Cpx;
#define N (2100000)
// Transform arrays and the butterfly temporaries o, w, x, y.
Cpx a[N],b[N],o,w,x,y;
// l = number of index bits, s = 2^l = transform length,
// r = bit-reversal permutation of 0..s-1.
int n,m,l,s,r[N];
void fft(Cpx*a,bool iv);
// Iterative-FFT driver: reads deg A = n, deg B = m and the coefficients,
// multiplies via fft() and prints the n+m+1 product coefficients.
signed main(){
	Rd(n),Rd(m);
	// l = index bit count, s = 2^l = smallest power of two greater than n+m
	// (at least 2). Integer loop instead of log2(n+m)+1: log2(0) is -inf,
	// and l == 0 would make the shift by l-1 below negative.
	l=1;
	while((1<<l)<=n+m)++l;
	s=1<<l;
	Frn1(i,0,n)Rd(a[i]);
	Frn1(i,0,m)Rd(b[i]);
	// r[i] = i with its l low bits reversed, built from r[i>>1].
	Frn0(i,0,s)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
	fft(a,0),fft(b,0);
	Frn0(i,0,s)a[i]*=b[i];
	fft(a,1);	// the inverse transform already divides by s
	// floor(x+0.5) rounds to nearest for negative coefficients too.
	Frn1(i,0,n+m)wr(int(floor(a[i].real()+0.5))),Ps;
	exit(0);
}
// In-place iterative radix-2 FFT of length s on a[0..s-1]. iv = true
// gives the inverse transform, including the final division by s.
// Uses the globals s, r (bit-reversal table), o, w, x, y and Pi.
void fft(Cpx*a,bool iv){
	// Bit-reversal permutation; each pair is swapped exactly once.
	for(int i=0;i<s;++i)
		if(i<r[i])swap(a[i],a[r[i]]);
	// Merge adjacent blocks of length half into blocks of length len.
	for(int len=2;len<=s;len<<=1){
		int const half(len>>1);
		o={cos(Pi/half),(iv?-1:1)*sin(Pi/half)};	// primitive len-th root
		for(int st=0;st<s;st+=len){
			w=1;
			for(int k=0;k<half;++k){
				// Butterfly: (x, y) -> (x + y, x - y) with y scaled by w.
				x=a[st+k];
				y=w*a[st+k+half];
				a[st+k]=x+y;
				a[st+k+half]=x-y;
				w*=o;
			}
		}
	}
	if(iv)
		for(int i=0;i<s;++i)a[i]/=s;
}

Time complexity: (O(nlog n))

Memory complexity: (O(n))

Results:

Algorithm: Polynomial Multiplication -- Fast Fourier Transform / Number-Theoretic Transform (English version)

Celebrate


Extension: Number Theoretic Transform

Although FFT has excellent time complexity, inaccuracy will inevitably arise because of the use of complex numbers.

If the polynomial coefficients and results are non-negative integers in a certain range, NTT is a better choice on accuracy and speed.

Prerequisite knowledge:

  FFT absolutely

  Modular arithmetics basics

Primitive roots

Assume that the following calculations are done modulo $P$, where $P$ is a prime number.

For a positive integer $g$, if the powers of $g$ modulo $P$ cover every positive integer $<P$, then we call $g$ a primitive root modulo $P$. (Digression: in group theory, the residue class of $g$ in $\mathbb{Z}_P$ is a generator of the multiplicative group $\mathbb{Z}_P^*$.)

E.g. for $P=7$, we list the residues of the powers of each positive integer $<P$:

1-> {1}
2-> {1,2,4}
3-> {1,2,3,4,5,6}
4-> {1,2,4}
5-> {1,2,3,4,5,6}
6-> {1,6}

Therefore, $3$ and $5$ are the primitive roots modulo $7$.

In the code, we use the common choice $P=998244353$, $g=3$.

The special property of a primitive root $g$ is that its powers repeat with period exactly $P-1$, i.e. $g^0,g^1,\cdots,g^{P-2}$ are pairwise distinct.

E.g. let $P=7$, $g=3$; then the powers of $g$ (beginning with $g^0$) are $1,3,2,6,4,5,1,3,2,6,4,5,\cdots$.

This property is very similar to that of the roots of unity. If we take $n=P-1$ and $\omega_n=g$, then all three lemmas from the FFT part still hold.

However, to complete NTT, there is one last step.

The substitute for roots of unity

In FFT, we use $n$-th roots of unity, where $n$ is a power of $2$.

However, $P-1$ is not necessarily equal to $n$. Hence, we cannot directly replace $\omega_n$ with $g$.

Now, since the powers of $g$ have period $P-1$,

if we take a factor $k$ of $P-1$, then the powers of $g^k$ have period $\frac{P-1}{k}$ (why?).

This means that if we take $k=\frac{P-1}{n}$, then the powers of $g^k$ have period exactly $n$.

But how can we be sure that $n$ is always a factor of $P-1$?

This is why we choose $P=998244353$: $P-1=998244352=2^{23}\cdot 7\cdot 17$ contains a high power of $2$, so every power of two $n\leqslant 2^{23}$ divides $P-1$.

Therefore, $g^{\frac{P-1}{n}}$ is our substitute for $\omega_n$.

In the code, we use $g^{-1}=332748118$ and multiply by $s^{-1}$ when doing the IDFT. Make sure to reduce modulo $P$ after every operation.

//This program is written by Brian Peng.
#pragma GCC optimize("Ofast","inline","no-stack-protector")
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define Rd(a) (a=read())
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
// Reads the next decimal integer (optionally preceded by '-') from stdin.
// Any other characters before it are skipped; if EOF is reached before a
// number starts, the whole program terminates with exit(0).
// (`int` may be redefined as long long by the surrounding program.)
// `c` keeps getchar()'s result in an integer type. Storing it in a char
// made the EOF test never fire where char is unsigned, and passed negative
// values to isdigit() (undefined behavior) where char is signed.
// `register` is gone (ill-formed since C++17).
int read(){
	int c(getchar());
	while(!isdigit(c)&&c!='-')
		if((c=getchar())==EOF)exit(0);
	bool const neg(c=='-');
	int val(neg?0:c&15);	// c&15 maps '0'..'9' to 0..9
	while(isdigit(c=getchar()))val=val*10+(c&15);
	return neg?-val:val;
}
// Prints the decimal representation of `a` to stdout (no separator).
// The magnitude is computed as unsigned, so the most negative value prints
// correctly instead of overflowing on `-a`. This works whether `int` is int
// or the program's `long long`. `register` (ill-formed since C++17) is
// dropped.
void wr(int a){
	unsigned long long mag(a);
	if(a<0)putchar('-'),mag=0ull-mag;
	char digits[24];
	int len(0);
	do digits[len++]=char('0'+mag%10),mag/=10;while(mag);
	while(len)putchar(digits[--len]);
}
// Sentinel constants kept from the author's template (unused here).
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
// Output helpers: Ps prints a space, Pe a newline (the escape was lost, so
// it used to print the letter 'n').
#define Ps Pc(' ')
#define Pe Pc('\n')
// Loops: Frn0 over [a,b), Frn1 over [a,b], Frn_ from a down to b.
// `register` removed: it is ill-formed in C++17.
#define Frn0(i,a,b) for(int i(a);i<(b);++i)
#define Frn1(i,a,b) for(int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(int i(a);i>=(b);--i)
// Bytewise fill of a whole array; redirect stdio to a".in" / a".out".
#define Mst(a,b) memset(a,b,sizeof(a))
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
// NTT modulus P = 119*2^23 + 1, primitive root G = 3, Gi = G^{-1} mod P.
#define P (998244353)
#define G (3)
#define Gi (332748118)
#define N (2100000)
// l = index bit count, s = 2^l = transform length, r = bit-reversal table,
// o, w, x, y = butterfly temporaries, siv = s^{-1} mod P.
int n,m,l,s,r[N],a[N],b[N],o,w,x,y,siv;
// Fast modular exponentiation: returns a^p mod P for p >= 0 (a^0 == 1).
// (In this program `int` is long long via `#define int long long`.)
// `a` is reduced modulo P first, so a*a cannot overflow even when a caller
// passes a value >= P. Iterative squaring replaces the previous
// nested-ternary recursion.
int fpw(int a,int p){
	a%=P;
	int res(1);
	for(;p;p>>=1,a=a*a%P)
		if(p&1)res=res*a%P;
	return res;
}
void ntt(int*a,bool iv);
// NTT driver: reads deg A = n, deg B = m and the (non-negative)
// coefficients, multiplies modulo P via ntt() and prints the n+m+1 product
// coefficients.
signed main(){
	Rd(n),Rd(m);
	// l = index bit count, s = 2^l = smallest power of two greater than n+m
	// (at least 2). Integer loop instead of log2(n+m)+1: log2(0) is -inf,
	// and l == 0 would make the shift by l-1 below negative. s must divide
	// P-1, i.e. s <= 2^23.
	l=1;
	while((1<<l)<=n+m)++l;
	s=1<<l;
	siv=fpw(s,P-2);	// s^{-1} mod P by Fermat's little theorem
	Frn1(i,0,n)Rd(a[i]);
	Frn1(i,0,m)Rd(b[i]);
	// r[i] = i with its l low bits reversed, built from r[i>>1].
	Frn0(i,0,s)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
	ntt(a,0),ntt(b,0);
	Frn0(i,0,s)a[i]=a[i]*b[i]%P;
	ntt(a,1);
	Frn1(i,0,n+m)wr(a[i]),Ps;
	exit(0);
}
void ntt(int*a,bool iv){
	Frn0(i,0,s)if(i<r[i])swap(a[i],a[r[i]]);
	for(int i(2),i2(1);i<=s;i2=i,i<<=1){
		o=fpw(iv?Gi:G,(P-1)/i);
		for(int j(0);j<s;j+=i){
			w=1;
			Frn0(k,0,i2){
				x=a[j+k],y=w*a[j+k+i2]%P;
				a[j+k]=(x+y)%P,a[j+k+i2]=(x-y+P)%P,w=w*o%P;
			}
		}
	}
	if(iv)Frn0(i,0,s)a[i]=a[i]*siv%P;
}

Time complexity: (O(nlog n))

Memory complexity: (O(n))

Results

Algorithm: Polynomial Multiplication -- Fast Fourier Transform / Number-Theoretic Transform (English version)

No significant improvement in time, but the memory cost is roughly halved: each array element is a `long long` (8 bytes, via `#define int long long`) instead of a `complex<double>` (16 bytes).


The End:

Translating is sooooo time-consuming...

Another year with Cnblogs! Happy new year!

Thanks for your support! ありがとう!


Reference:

Introduction to Algorithms

自为风月马前卒:快速傅里叶变换(FFT)详解

自为风月马前卒:快速数论变换(NTT)小结

原文链接: https://www.cnblogs.com/BrianPeng/p/15761230.html

欢迎关注

微信关注下方公众号,第一时间获取干货硬货;公众号内回复【pdf】免费获取数百本计算机经典书籍

    Algorithm: Polynomial Multiplication -- Fast Fourier Transform / Number-Theoretic Transform (English version)

原创文章受到原创版权保护。转载请注明出处:https://www.ccppcoding.com/archives/183442

非原创文章文中已经注明原地址,如有侵权,联系删除

关注公众号【高性能架构探索】,第一时间获取最新文章

转载文章受原作者版权保护。转载请注明原作者出处!

(0)
上一篇 2023年2月12日 上午10:23
下一篇 2023年2月12日 上午10:24

相关推荐