-
Notifications
You must be signed in to change notification settings - Fork 39
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Minor: Traits for multilinear polynomials, PCS minor changes (#158)
* reference (read-only/write) multilinear polynomial prototype * full set of testing for both ref and mut_ref mle * minor, prototype (mut)-multilinear-extension traits for ref-mle-polys * continue with (mutable)multilinear-extension trait implementation * pcs interface change to box dyn multilinear-extension, collateral changes included * minor, add one more trait method of ref to hypercube basis * minor, use impl to get around lifetime specification in place
- Loading branch information
1 parent
13fd492
commit f73cc68
Showing
9 changed files
with
312 additions
and
64 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,9 @@ | ||
mod mle; | ||
pub use mle::*; | ||
|
||
mod ref_mle; | ||
pub use ref_mle::*; | ||
|
||
mod eq; | ||
pub use eq::*; | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
use std::ops::{Index, IndexMut, Mul}; | ||
|
||
use arith::Field; | ||
|
||
use crate::MultiLinearPoly; | ||
|
||
pub trait MultilinearExtension<F: Field>: Index<usize, Output = F> { | ||
fn evaluate_with_buffer(&self, point: &[F], scratch: &mut [F]) -> F; | ||
|
||
fn num_vars(&self) -> usize; | ||
|
||
fn hypercube_size(&self) -> usize { | ||
1 << self.num_vars() | ||
} | ||
|
||
fn hypercube_basis(&self) -> Vec<F>; | ||
|
||
fn hypercube_basis_ref(&self) -> &Vec<F>; | ||
|
||
fn interpolate_over_hypercube(&self) -> Vec<F>; | ||
} | ||
|
||
#[derive(Debug, Clone)] | ||
pub struct RefMultiLinearPoly<'ref_life, F: Field> { | ||
pub coeffs: &'ref_life Vec<F>, | ||
} | ||
|
||
impl<'ref_life, 'outer: 'ref_life, F: Field> RefMultiLinearPoly<'ref_life, F> { | ||
#[inline] | ||
pub fn from_ref(evals: &'outer Vec<F>) -> Self { | ||
Self { coeffs: evals } | ||
} | ||
} | ||
|
||
impl<'a, F: Field> Index<usize> for RefMultiLinearPoly<'a, F> { | ||
type Output = F; | ||
|
||
fn index(&self, index: usize) -> &Self::Output { | ||
assert!(index < self.hypercube_size()); | ||
&self.coeffs[index] | ||
} | ||
} | ||
|
||
impl<'a, F: Field> MultilinearExtension<F> for RefMultiLinearPoly<'a, F> { | ||
fn num_vars(&self) -> usize { | ||
assert!(self.coeffs.len().is_power_of_two()); | ||
self.coeffs.len().ilog2() as usize | ||
} | ||
|
||
fn hypercube_basis(&self) -> Vec<F> { | ||
self.coeffs.clone() | ||
} | ||
|
||
fn hypercube_basis_ref(&self) -> &Vec<F> { | ||
self.coeffs | ||
} | ||
|
||
fn interpolate_over_hypercube(&self) -> Vec<F> { | ||
MultiLinearPoly::interpolate_over_hypercube_impl(self.coeffs) | ||
} | ||
|
||
fn evaluate_with_buffer(&self, point: &[F], scratch: &mut [F]) -> F { | ||
MultiLinearPoly::evaluate_with_buffer(self.coeffs, point, scratch) | ||
} | ||
} | ||
|
||
pub trait MutableMultilinearExtension<F: Field>: | ||
MultilinearExtension<F> + IndexMut<usize, Output = F> | ||
{ | ||
fn fix_top_variable<AF: Field + Mul<F, Output = F>>(&mut self, r: AF); | ||
|
||
fn fix_variables<AF: Field + Mul<F, Output = F>>(&mut self, vars: &[AF]); | ||
|
||
fn interpolate_over_hypercube_in_place(&mut self); | ||
} | ||
|
||
#[derive(Debug)] | ||
pub struct MutRefMultiLinearPoly<'ref_life, F: Field> { | ||
pub coeffs: &'ref_life mut Vec<F>, | ||
} | ||
|
||
impl<'ref_life, 'outer_mut: 'ref_life, F: Field> MutRefMultiLinearPoly<'ref_life, F> { | ||
#[inline] | ||
pub fn from_ref(evals: &'outer_mut mut Vec<F>) -> Self { | ||
Self { coeffs: evals } | ||
} | ||
} | ||
|
||
impl<'a, F: Field> Index<usize> for MutRefMultiLinearPoly<'a, F> { | ||
type Output = F; | ||
|
||
fn index(&self, index: usize) -> &Self::Output { | ||
assert!(index < self.hypercube_size()); | ||
&self.coeffs[index] | ||
} | ||
} | ||
|
||
impl<'a, F: Field> IndexMut<usize> for MutRefMultiLinearPoly<'a, F> { | ||
fn index_mut(&mut self, index: usize) -> &mut Self::Output { | ||
assert!(index < self.hypercube_size()); | ||
&mut self.coeffs[index] | ||
} | ||
} | ||
|
||
impl<'a, F: Field> MultilinearExtension<F> for MutRefMultiLinearPoly<'a, F> { | ||
fn num_vars(&self) -> usize { | ||
assert!(self.coeffs.len().is_power_of_two()); | ||
self.coeffs.len().ilog2() as usize | ||
} | ||
|
||
fn hypercube_basis(&self) -> Vec<F> { | ||
self.coeffs.clone() | ||
} | ||
|
||
fn hypercube_basis_ref(&self) -> &Vec<F> { | ||
self.coeffs | ||
} | ||
|
||
fn interpolate_over_hypercube(&self) -> Vec<F> { | ||
MultiLinearPoly::interpolate_over_hypercube_impl(self.coeffs) | ||
} | ||
|
||
fn evaluate_with_buffer(&self, point: &[F], scratch: &mut [F]) -> F { | ||
MultiLinearPoly::evaluate_with_buffer(self.coeffs, point, scratch) | ||
} | ||
} | ||
|
||
impl<'a, F: Field> MutableMultilinearExtension<F> for MutRefMultiLinearPoly<'a, F> { | ||
fn fix_top_variable<AF: Field + Mul<F, Output = F>>(&mut self, r: AF) { | ||
let n = self.hypercube_size() / 2; | ||
let (left, right) = self.coeffs.split_at_mut(n); | ||
|
||
left.iter_mut().zip(right.iter()).for_each(|(a, b)| { | ||
*a += r * (*b - *a); | ||
}); | ||
self.coeffs.truncate(n) | ||
} | ||
|
||
fn fix_variables<AF: Field + Mul<F, Output = F>>(&mut self, vars: &[AF]) { | ||
// evaluate single variable of partial point from left to right | ||
// need to reverse the order of the point | ||
vars.iter().rev().for_each(|p| self.fix_top_variable(*p)) | ||
} | ||
|
||
fn interpolate_over_hypercube_in_place(&mut self) { | ||
let num_vars = self.num_vars(); | ||
for i in 1..=num_vars { | ||
let chunk_size = 1 << i; | ||
|
||
self.coeffs.chunks_mut(chunk_size).for_each(|chunk| { | ||
let half_chunk = chunk_size >> 1; | ||
let (left, right) = chunk.split_at_mut(half_chunk); | ||
right | ||
.iter_mut() | ||
.zip(left.iter()) | ||
.for_each(|(a, b)| *a -= *b); | ||
}) | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.