Skip to content

Commit

Permalink
fft&ifft must be public, i think
Browse files Browse the repository at this point in the history
  • Loading branch information
ngtkana committed Sep 24, 2023
1 parent 80f58c1 commit ca42673
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 39 deletions.
138 changes: 99 additions & 39 deletions libs/fp2/src/fourier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ type F2 = Fp<P2>;
type F3 = Fp<P3>;

/// Multiplies two polynomials.
///
/// # Examples
/// ```
/// use fp2::fp;
Expand All @@ -34,12 +35,12 @@ where
let len = n.next_power_of_two();
a.resize(len, Fp::new(0));
b.resize(len, Fp::new(0));
fft::<P>(&mut a);
fft::<P>(&mut b);
fft(&mut a);
fft(&mut b);
for (a, b) in a.iter_mut().zip(b.iter()) {
*a *= *b;
}
ifft::<P>(&mut a);
ifft(&mut a);
a.truncate(n);
a
}
Expand All @@ -65,18 +66,48 @@ pub fn any_mod_fps_mul<const P: u64>(a: &[Fp<P>], b: &[Fp<P>]) -> Vec<Fp<P>> {
.collect::<Vec<_>>()
}

fn fft<const P: u64>(a: &mut [Fp<P>])
/// Fast Fourier transform.
/// # Requirements
///
/// - The length $n$ of $f$ is a power of two.
/// - $n | (p - 1)$
///
/// Especially, if $p = 998244353$, $n \leq 2^{23}$ must hold.
///
/// # Replaced by
///
/// $f(1), f(-1), f(i), f(-i), f(e^{\pi/2}), f(e^{5\pi/2}), \dots$
///
/// # Examples
///
/// ```
/// use fp2::fft;
/// use fp2::fp;
/// use fp2::Fp;
/// type F = Fp<998244353>;
/// let mut f: Vec<F> = vec![fp!(1000), fp!(100), fp!(10), fp!(1)];
/// fft(&mut f);
/// let i = fp!(3).pow(998244352 / 4);
/// assert_eq!(f, vec![
/// fp!(1111),
/// fp!(909),
/// fp!(990) + i * fp!(99),
/// fp!(990) - i * fp!(99)
/// ]);
/// ```
pub fn fft<const P: u64>(f: &mut [Fp<P>])
where
(): PrimitiveRoot<P>,
{
let n = a.len();
let n = f.len();
assert!(n.is_power_of_two());
let mut root = <() as PrimitiveRoot<P>>::VALUE.pow((P - 1) / a.len() as u64);
assert!((P - 1) % n as u64 == 0);
let mut root = <() as PrimitiveRoot<P>>::VALUE.pow((P - 1) / f.len() as u64);
let fourth = <() as PrimitiveRoot<P>>::VALUE.pow((P - 1) / 4);
let mut fft_len = n;
while 4 <= fft_len {
let quarter = fft_len / 4;
for a in a.chunks_mut(fft_len) {
for f in f.chunks_mut(fft_len) {
let mut c = Fp::new(1);
for (((i, j), k), l) in (0..)
.zip(quarter..)
Expand All @@ -85,14 +116,14 @@ where
.take(quarter)
{
let c2 = c * c;
let x = a[i] + a[k];
let y = a[j] + a[l];
let z = a[i] - a[k];
let w = fourth * (a[j] - a[l]);
a[i] = x + y;
a[j] = c2 * (x - y);
a[k] = c * (z + w);
a[l] = c2 * c * (z - w);
let x = f[i] + f[k];
let y = f[j] + f[l];
let z = f[i] - f[k];
let w = fourth * (f[j] - f[l]);
f[i] = x + y;
f[j] = c2 * (x - y);
f[k] = c * (z + w);
f[l] = c2 * c * (z - w);
c *= root;
}
}
Expand All @@ -101,40 +132,69 @@ where
fft_len = quarter;
}
if fft_len == 2 {
for a in a.chunks_mut(2) {
let x = a[0];
let y = a[1];
a[0] = x + y;
a[1] = x - y;
for f in f.chunks_mut(2) {
let x = f[0];
let y = f[1];
f[0] = x + y;
f[1] = x - y;
}
}
}
fn ifft<const P: u64>(a: &mut [Fp<P>])
/// Inverse fast Fourier transform.
///
/// # Requirements
///
/// - The length $n$ of $f$ is a power of two.
/// - $n | (p - 1)$
///
/// Especially, if $p = 998244353$, $n \leq 2^{23}$ must hold.
///
/// # Replaced by
///
/// Exacly the inverse of [`fft`].
///
/// # Examples
/// ```
/// use fp2::fp;
/// use fp2::ifft;
/// use fp2::Fp;
/// type F = Fp<998244353>;
/// let i = fp!(3).pow(998244352 / 4);
/// let mut f: Vec<F> = vec![
/// fp!(1111),
/// fp!(909),
/// fp!(990) + i * fp!(99),
/// fp!(990) - i * fp!(99),
/// ];
/// ifft(&mut f);
/// assert_eq!(f, vec![fp!(1000), fp!(100), fp!(10), fp!(1)]);
/// ```
pub fn ifft<const P: u64>(f: &mut [Fp<P>])
where
(): PrimitiveRoot<P>,
{
let n = a.len();
let n = f.len();
assert!(n.is_power_of_two());
let root = <() as PrimitiveRoot<P>>::VALUE.pow((P - 1) / a.len() as u64);
let root = <() as PrimitiveRoot<P>>::VALUE.pow((P - 1) / f.len() as u64);
let mut roots = std::iter::successors(Some(root.inv()), |x| Some(x * x))
.take(n.trailing_zeros() as usize + 1)
.collect::<Vec<_>>();
roots.reverse();
let fourth = <() as PrimitiveRoot<P>>::VALUE.pow((P - 1) / 4).inv();
let mut quarter = 1_usize;
if n.trailing_zeros() % 2 == 1 {
for a in a.chunks_mut(2) {
let x = a[0];
let y = a[1];
a[0] = x + y;
a[1] = x - y;
for f in f.chunks_mut(2) {
let x = f[0];
let y = f[1];
f[0] = x + y;
f[1] = x - y;
}
quarter = 2;
}
while quarter != n {
let fft_len = quarter * 4;
let root = roots[fft_len.trailing_zeros() as usize];
for a in a.chunks_mut(fft_len) {
for f in f.chunks_mut(fft_len) {
let mut c = Fp::new(1);
for (((i, j), k), l) in (0..)
.zip(quarter..)
Expand All @@ -143,21 +203,21 @@ where
.take(quarter)
{
let c2 = c * c;
let x = a[i] + c2 * a[j];
let y = a[i] - c2 * a[j];
let z = c * (a[k] + c2 * a[l]);
let w = fourth * c * (a[k] - c2 * a[l]);
a[i] = x + z;
a[j] = y + w;
a[k] = x - z;
a[l] = y - w;
let x = f[i] + c2 * f[j];
let y = f[i] - c2 * f[j];
let z = c * (f[k] + c2 * f[l]);
let w = fourth * c * (f[k] - c2 * f[l]);
f[i] = x + z;
f[j] = y + w;
f[k] = x - z;
f[l] = y - w;
c *= root;
}
}
quarter = fft_len;
}
let d = Fp::from(a.len()).inv();
a.iter_mut().for_each(|x| *x *= d);
let d = Fp::from(f.len()).inv();
f.iter_mut().for_each(|x| *x *= d);
}

/// Restore the original value from the remainder of the division by `P1`, `P2`, and `P3`.
Expand Down
2 changes: 2 additions & 0 deletions libs/fp2/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ mod fourier;
use ext_gcd::mod_inv;
pub use factorial::Factorial;
pub use fourier::any_mod_fps_mul;
pub use fourier::fft;
pub use fourier::fps_mul;
pub use fourier::ifft;
use std::iter::Product;
use std::iter::Sum;
use std::ops::Add;
Expand Down

0 comments on commit ca42673

Please sign in to comment.