diff --git a/libs/fp2/src/fourier.rs b/libs/fp2/src/fourier.rs index 262f37d9..c72dee2f 100644 --- a/libs/fp2/src/fourier.rs +++ b/libs/fp2/src/fourier.rs @@ -10,6 +10,7 @@ type F2 = Fp; type F3 = Fp; /// Multiplies two polynomials. +/// /// # Examples /// ``` /// use fp2::fp; @@ -34,12 +35,12 @@ where let len = n.next_power_of_two(); a.resize(len, Fp::new(0)); b.resize(len, Fp::new(0)); - fft::

(&mut a); - fft::

(&mut b); + fft(&mut a); + fft(&mut b); for (a, b) in a.iter_mut().zip(b.iter()) { *a *= *b; } - ifft::

(&mut a); + ifft(&mut a); a.truncate(n); a } @@ -65,18 +66,48 @@ pub fn any_mod_fps_mul(a: &[Fp

], b: &[Fp

]) -> Vec> { .collect::>() } -fn fft(a: &mut [Fp

]) +/// Fast Fourier transform. +/// # Requirements +/// +/// - The length $n$ of $f$ is a power of two. +/// - $n | (p - 1)$ +/// +/// In particular, if $p = 998244353$, $n \leq 2^{23}$ must hold. +/// +/// # Replaced by +/// +/// $f(1), f(-1), f(i), f(-i), f(e^{i\pi/4}), f(e^{5i\pi/4}), \dots$ +/// +/// # Examples +/// +/// ``` +/// use fp2::fft; +/// use fp2::fp; +/// use fp2::Fp; +/// type F = Fp<998244353>; +/// let mut f: Vec<F> = vec![fp!(1000), fp!(100), fp!(10), fp!(1)]; +/// fft(&mut f); +/// let i = fp!(3).pow(998244352 / 4); +/// assert_eq!(f, vec![ +/// fp!(1111), +/// fp!(909), +/// fp!(990) + i * fp!(99), +/// fp!(990) - i * fp!(99) +/// ]); +/// ``` +pub fn fft(f: &mut [Fp

]) where (): PrimitiveRoot

, { - let n = a.len(); + let n = f.len(); assert!(n.is_power_of_two()); - let mut root = <() as PrimitiveRoot

>::VALUE.pow((P - 1) / a.len() as u64); + assert!((P - 1) % n as u64 == 0); + let mut root = <() as PrimitiveRoot

>::VALUE.pow((P - 1) / f.len() as u64); let fourth = <() as PrimitiveRoot

>::VALUE.pow((P - 1) / 4); let mut fft_len = n; while 4 <= fft_len { let quarter = fft_len / 4; - for a in a.chunks_mut(fft_len) { + for f in f.chunks_mut(fft_len) { let mut c = Fp::new(1); for (((i, j), k), l) in (0..) .zip(quarter..) @@ -85,14 +116,14 @@ where .take(quarter) { let c2 = c * c; - let x = a[i] + a[k]; - let y = a[j] + a[l]; - let z = a[i] - a[k]; - let w = fourth * (a[j] - a[l]); - a[i] = x + y; - a[j] = c2 * (x - y); - a[k] = c * (z + w); - a[l] = c2 * c * (z - w); + let x = f[i] + f[k]; + let y = f[j] + f[l]; + let z = f[i] - f[k]; + let w = fourth * (f[j] - f[l]); + f[i] = x + y; + f[j] = c2 * (x - y); + f[k] = c * (z + w); + f[l] = c2 * c * (z - w); c *= root; } } @@ -101,21 +132,50 @@ where fft_len = quarter; } if fft_len == 2 { - for a in a.chunks_mut(2) { - let x = a[0]; - let y = a[1]; - a[0] = x + y; - a[1] = x - y; + for f in f.chunks_mut(2) { + let x = f[0]; + let y = f[1]; + f[0] = x + y; + f[1] = x - y; } } } -fn ifft(a: &mut [Fp

]) +/// Inverse fast Fourier transform. +/// +/// # Requirements +/// +/// - The length $n$ of $f$ is a power of two. +/// - $n | (p - 1)$ +/// +/// In particular, if $p = 998244353$, $n \leq 2^{23}$ must hold. +/// +/// # Replaced by +/// +/// Exactly the inverse of [`fft`]. +/// +/// # Examples +/// ``` +/// use fp2::fp; +/// use fp2::ifft; +/// use fp2::Fp; +/// type F = Fp<998244353>; +/// let i = fp!(3).pow(998244352 / 4); +/// let mut f: Vec<F> = vec![ +/// fp!(1111), +/// fp!(909), +/// fp!(990) + i * fp!(99), +/// fp!(990) - i * fp!(99), +/// ]; +/// ifft(&mut f); +/// assert_eq!(f, vec![fp!(1000), fp!(100), fp!(10), fp!(1)]); +/// ``` +pub fn ifft(f: &mut [Fp

]) where (): PrimitiveRoot

, { - let n = a.len(); + let n = f.len(); assert!(n.is_power_of_two()); - let root = <() as PrimitiveRoot

>::VALUE.pow((P - 1) / a.len() as u64); + let root = <() as PrimitiveRoot

>::VALUE.pow((P - 1) / f.len() as u64); let mut roots = std::iter::successors(Some(root.inv()), |x| Some(x * x)) .take(n.trailing_zeros() as usize + 1) .collect::>(); @@ -123,18 +183,18 @@ where let fourth = <() as PrimitiveRoot

>::VALUE.pow((P - 1) / 4).inv(); let mut quarter = 1_usize; if n.trailing_zeros() % 2 == 1 { - for a in a.chunks_mut(2) { - let x = a[0]; - let y = a[1]; - a[0] = x + y; - a[1] = x - y; + for f in f.chunks_mut(2) { + let x = f[0]; + let y = f[1]; + f[0] = x + y; + f[1] = x - y; } quarter = 2; } while quarter != n { let fft_len = quarter * 4; let root = roots[fft_len.trailing_zeros() as usize]; - for a in a.chunks_mut(fft_len) { + for f in f.chunks_mut(fft_len) { let mut c = Fp::new(1); for (((i, j), k), l) in (0..) .zip(quarter..) @@ -143,21 +203,21 @@ where .take(quarter) { let c2 = c * c; - let x = a[i] + c2 * a[j]; - let y = a[i] - c2 * a[j]; - let z = c * (a[k] + c2 * a[l]); - let w = fourth * c * (a[k] - c2 * a[l]); - a[i] = x + z; - a[j] = y + w; - a[k] = x - z; - a[l] = y - w; + let x = f[i] + c2 * f[j]; + let y = f[i] - c2 * f[j]; + let z = c * (f[k] + c2 * f[l]); + let w = fourth * c * (f[k] - c2 * f[l]); + f[i] = x + z; + f[j] = y + w; + f[k] = x - z; + f[l] = y - w; c *= root; } } quarter = fft_len; } - let d = Fp::from(a.len()).inv(); - a.iter_mut().for_each(|x| *x *= d); + let d = Fp::from(f.len()).inv(); + f.iter_mut().for_each(|x| *x *= d); } /// Restore the original value from the remainder of the division by `P1`, `P2`, and `P3`. diff --git a/libs/fp2/src/lib.rs b/libs/fp2/src/lib.rs index ed1925c2..9654cfd4 100644 --- a/libs/fp2/src/lib.rs +++ b/libs/fp2/src/lib.rs @@ -42,7 +42,9 @@ mod fourier; use ext_gcd::mod_inv; pub use factorial::Factorial; pub use fourier::any_mod_fps_mul; +pub use fourier::fft; pub use fourier::fps_mul; +pub use fourier::ifft; use std::iter::Product; use std::iter::Sum; use std::ops::Add;