Skip to content

Commit

Permalink
Minor: Traits for multilinear polynomials, PCS minor changes (#158)
Browse files Browse the repository at this point in the history
* reference (read-only/write) multilinear polynomial prototype

* full set of testing for both ref and mut_ref mle

* minor, prototype (mut)-multilinear-extension traits for ref-mle-polys

* continue with (mutable)multilinear-extension trait implementation

* pcs interface change to box dyn multilinear-extension, collateral changes included

* minor, add one more trait method of ref to hypercube basis

* minor, use impl to get around lifetime specification in place
  • Loading branch information
tonyfloatersu authored Dec 17, 2024
1 parent 13fd492 commit f73cc68
Show file tree
Hide file tree
Showing 9 changed files with 312 additions and 64 deletions.
3 changes: 3 additions & 0 deletions arith/polynomials/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
mod mle;
pub use mle::*;

mod ref_mle;
pub use ref_mle::*;

mod eq;
pub use eq::*;

Expand Down
101 changes: 69 additions & 32 deletions arith/polynomials/src/mle.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use std::ops::{Index, IndexMut, Mul};

use arith::Field;
use ark_std::{log2, rand::RngCore};

use crate::EqPolynomial;
use crate::{EqPolynomial, MultilinearExtension, MutableMultilinearExtension};

#[derive(Debug, Clone, Default)]
pub struct MultiLinearPoly<F: Field> {
Expand All @@ -23,33 +25,6 @@ impl<F: Field> MultiLinearPoly<F> {
Self { coeffs: coeff }
}

#[inline]
/// # Safety
/// The returned MultiLinearPoly should not be mutable in order not to
/// mess up the original vector
///
/// PCS may take MultiLinearPoly as input
/// However, it is inefficient to copy the entire vector to create a new MultiLinearPoly
/// Here we introduce a wrap function to reuse the memory space assigned to the original vector
/// This is unsafe, and it is recommended to destroy the wrapper immediately after use
/// Example Usage:
///
/// let vs = vec![F::ONE; 999999];
/// let mle_wrapper = MultiLinearPoly::<F>::wrap_around(&vs); // no extensive memory copy here
///
/// // do something to mle
///
/// mle_wrapper.wrapper_self_destroy() // please do not use drop here, it's incorrect
pub unsafe fn wrap_around(coeffs: &Vec<F>) -> Self {
Self {
coeffs: Vec::from_raw_parts(coeffs.as_ptr() as *mut F, coeffs.len(), coeffs.capacity()),
}
}

pub fn wrapper_self_detroy(self) {
self.coeffs.leak();
}

#[inline]
pub fn get_num_vars(&self) -> usize {
log2(self.coeffs.len()) as usize
Expand Down Expand Up @@ -102,12 +77,12 @@ impl<F: Field> MultiLinearPoly<F> {

/// Evaluate the polynomial at the top variable
#[inline]
pub fn fix_top_variable(&mut self, r: &F) {
pub fn fix_top_variable<AF: Field + Mul<F, Output = F>>(&mut self, r: AF) {
let n = self.coeffs.len() / 2;
let (left, right) = self.coeffs.split_at_mut(n);

left.iter_mut().zip(right.iter()).for_each(|(a, b)| {
*a += *r * (*b - *a);
*a += r * (*b - *a);
});
self.coeffs.truncate(n);
}
Expand All @@ -116,12 +91,12 @@ impl<F: Field> MultiLinearPoly<F> {
/// Evaluate the polynomial at a set of variables, from bottom to top
/// This is equivalent to `evaluate` when partial_point.len() = nv
#[inline]
pub fn fix_variables(&mut self, partial_point: &[F]) {
pub fn fix_variables<AF: Field + Mul<F, Output = F>>(&mut self, partial_point: &[AF]) {
// evaluate single variable of partial point from left to right
partial_point
.iter()
.rev() // need to reverse the order of the point
.for_each(|point| self.fix_top_variable(point));
.for_each(|point| self.fix_top_variable(*point));
}

/// Jolt's implementation
Expand Down Expand Up @@ -168,3 +143,65 @@ impl<F: Field> MultiLinearPoly<F> {
}
}
}

impl<F: Field> Index<usize> for MultiLinearPoly<F> {
type Output = F;

fn index(&self, index: usize) -> &Self::Output {
&self.coeffs[index]
}
}

impl<F: Field> MultilinearExtension<F> for MultiLinearPoly<F> {
fn num_vars(&self) -> usize {
self.get_num_vars()
}

fn hypercube_basis(&self) -> Vec<F> {
self.coeffs.clone()
}

fn hypercube_basis_ref(&self) -> &Vec<F> {
&self.coeffs
}

fn evaluate_with_buffer(&self, point: &[F], scratch: &mut [F]) -> F {
Self::evaluate_with_buffer(&self.coeffs, point, scratch)
}

fn interpolate_over_hypercube(&self) -> Vec<F> {
self.interpolate_over_hypercube()
}
}

impl<F: Field> IndexMut<usize> for MultiLinearPoly<F> {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
&mut self.coeffs[index]
}
}

impl<F: Field> MutableMultilinearExtension<F> for MultiLinearPoly<F> {
fn fix_top_variable<AF: Field + std::ops::Mul<F, Output = F>>(&mut self, r: AF) {
self.fix_top_variable(r)
}

fn fix_variables<AF: Field + std::ops::Mul<F, Output = F>>(&mut self, vars: &[AF]) {
self.fix_variables(vars)
}

fn interpolate_over_hypercube_in_place(&mut self) {
let num_vars = self.num_vars();
for i in 1..=num_vars {
let chunk_size = 1 << i;

self.coeffs.chunks_mut(chunk_size).for_each(|chunk| {
let half_chunk = chunk_size >> 1;
let (left, right) = chunk.split_at_mut(half_chunk);
right
.iter_mut()
.zip(left.iter())
.for_each(|(a, b)| *a -= *b);
})
}
}
}
160 changes: 160 additions & 0 deletions arith/polynomials/src/ref_mle.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
use std::ops::{Index, IndexMut, Mul};

use arith::Field;

use crate::MultiLinearPoly;

pub trait MultilinearExtension<F: Field>: Index<usize, Output = F> {
fn evaluate_with_buffer(&self, point: &[F], scratch: &mut [F]) -> F;

fn num_vars(&self) -> usize;

fn hypercube_size(&self) -> usize {
1 << self.num_vars()
}

fn hypercube_basis(&self) -> Vec<F>;

fn hypercube_basis_ref(&self) -> &Vec<F>;

fn interpolate_over_hypercube(&self) -> Vec<F>;
}

#[derive(Debug, Clone)]
pub struct RefMultiLinearPoly<'ref_life, F: Field> {
pub coeffs: &'ref_life Vec<F>,
}

impl<'ref_life, 'outer: 'ref_life, F: Field> RefMultiLinearPoly<'ref_life, F> {
#[inline]
pub fn from_ref(evals: &'outer Vec<F>) -> Self {
Self { coeffs: evals }
}
}

impl<'a, F: Field> Index<usize> for RefMultiLinearPoly<'a, F> {
type Output = F;

fn index(&self, index: usize) -> &Self::Output {
assert!(index < self.hypercube_size());
&self.coeffs[index]
}
}

impl<'a, F: Field> MultilinearExtension<F> for RefMultiLinearPoly<'a, F> {
fn num_vars(&self) -> usize {
assert!(self.coeffs.len().is_power_of_two());
self.coeffs.len().ilog2() as usize
}

fn hypercube_basis(&self) -> Vec<F> {
self.coeffs.clone()
}

fn hypercube_basis_ref(&self) -> &Vec<F> {
self.coeffs
}

fn interpolate_over_hypercube(&self) -> Vec<F> {
MultiLinearPoly::interpolate_over_hypercube_impl(self.coeffs)
}

fn evaluate_with_buffer(&self, point: &[F], scratch: &mut [F]) -> F {
MultiLinearPoly::evaluate_with_buffer(self.coeffs, point, scratch)
}
}

pub trait MutableMultilinearExtension<F: Field>:
MultilinearExtension<F> + IndexMut<usize, Output = F>
{
fn fix_top_variable<AF: Field + Mul<F, Output = F>>(&mut self, r: AF);

fn fix_variables<AF: Field + Mul<F, Output = F>>(&mut self, vars: &[AF]);

fn interpolate_over_hypercube_in_place(&mut self);
}

#[derive(Debug)]
pub struct MutRefMultiLinearPoly<'ref_life, F: Field> {
pub coeffs: &'ref_life mut Vec<F>,
}

impl<'ref_life, 'outer_mut: 'ref_life, F: Field> MutRefMultiLinearPoly<'ref_life, F> {
#[inline]
pub fn from_ref(evals: &'outer_mut mut Vec<F>) -> Self {
Self { coeffs: evals }
}
}

impl<'a, F: Field> Index<usize> for MutRefMultiLinearPoly<'a, F> {
type Output = F;

fn index(&self, index: usize) -> &Self::Output {
assert!(index < self.hypercube_size());
&self.coeffs[index]
}
}

impl<'a, F: Field> IndexMut<usize> for MutRefMultiLinearPoly<'a, F> {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
assert!(index < self.hypercube_size());
&mut self.coeffs[index]
}
}

impl<'a, F: Field> MultilinearExtension<F> for MutRefMultiLinearPoly<'a, F> {
fn num_vars(&self) -> usize {
assert!(self.coeffs.len().is_power_of_two());
self.coeffs.len().ilog2() as usize
}

fn hypercube_basis(&self) -> Vec<F> {
self.coeffs.clone()
}

fn hypercube_basis_ref(&self) -> &Vec<F> {
self.coeffs
}

fn interpolate_over_hypercube(&self) -> Vec<F> {
MultiLinearPoly::interpolate_over_hypercube_impl(self.coeffs)
}

fn evaluate_with_buffer(&self, point: &[F], scratch: &mut [F]) -> F {
MultiLinearPoly::evaluate_with_buffer(self.coeffs, point, scratch)
}
}

impl<'a, F: Field> MutableMultilinearExtension<F> for MutRefMultiLinearPoly<'a, F> {
fn fix_top_variable<AF: Field + Mul<F, Output = F>>(&mut self, r: AF) {
let n = self.hypercube_size() / 2;
let (left, right) = self.coeffs.split_at_mut(n);

left.iter_mut().zip(right.iter()).for_each(|(a, b)| {
*a += r * (*b - *a);
});
self.coeffs.truncate(n)
}

fn fix_variables<AF: Field + Mul<F, Output = F>>(&mut self, vars: &[AF]) {
// evaluate single variable of partial point from left to right
// need to reverse the order of the point
vars.iter().rev().for_each(|p| self.fix_top_variable(*p))
}

fn interpolate_over_hypercube_in_place(&mut self) {
let num_vars = self.num_vars();
for i in 1..=num_vars {
let chunk_size = 1 << i;

self.coeffs.chunks_mut(chunk_size).for_each(|chunk| {
let half_chunk = chunk_size >> 1;
let (left, right) = chunk.split_at_mut(half_chunk);
right
.iter_mut()
.zip(left.iter())
.for_each(|(a, b)| *a -= *b);
})
}
}
}
45 changes: 45 additions & 0 deletions arith/polynomials/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,51 @@ fn test_eq_xr() {
}
}

#[test]
fn test_ref_multilinear_poly() {
let mut rng = test_rng();
for nv in 4..=10 {
let es_len = 1 << nv;
let es: Vec<Fr> = (0..es_len).map(|_| Fr::random_unsafe(&mut rng)).collect();
let point: Vec<Fr> = (0..nv).map(|_| Fr::random_unsafe(&mut rng)).collect();
let mut scratch = vec![Fr::ZERO; es_len];

let mle_from_ref = RefMultiLinearPoly::<Fr>::from_ref(&es);

let actual_eval = mle_from_ref.evaluate_with_buffer(&point, &mut scratch);
let expect_eval = MultiLinearPoly::evaluate_with_buffer(&es, &point, &mut scratch);

drop(mle_from_ref);

assert_eq!(actual_eval, expect_eval);

drop(es);
}
}

#[test]
fn test_mut_ref_multilinear_poly() {
let mut rng = test_rng();
for nv in 4..=10 {
let es_len = 1 << nv;
let mut es: Vec<Fr> = (0..es_len).map(|_| Fr::random_unsafe(&mut rng)).collect();
let es_cloned = es.clone();
let point: Vec<Fr> = (0..nv).map(|_| Fr::random_unsafe(&mut rng)).collect();
let mut scratch = vec![Fr::ZERO; es_len];

let mut mle_from_mut_ref = MutRefMultiLinearPoly::<Fr>::from_ref(&mut es);

mle_from_mut_ref.fix_variables(&point);
let expect_eval = MultiLinearPoly::evaluate_with_buffer(&es_cloned, &point, &mut scratch);

drop(mle_from_mut_ref);

assert_eq!(es[0], expect_eval);

drop(es);
}
}

/// Naive method to build eq(x, r).
/// Only used for testing purpose.
// Evaluate
Expand Down
Loading

0 comments on commit f73cc68

Please sign in to comment.