Skip to content

Commit

Permalink
Merge pull request #2492 from o1-labs/dw/implement-from-expr
Browse files Browse the repository at this point in the history
MVPoly: implement from Expr
  • Loading branch information
dannywillems authored Aug 28, 2024
2 parents 73439f2 + 112db17 commit 99ad37d
Show file tree
Hide file tree
Showing 4 changed files with 273 additions and 0 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions mvpoly/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ path = "src/lib.rs"

[dependencies]
ark-ff.workspace = true
kimchi.workspace = true
log.workspace = true
num-integer.workspace = true
o1-utils.workspace = true
Expand Down
172 changes: 172 additions & 0 deletions mvpoly/src/prime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,13 @@ use std::{
};

use ark_ff::{One, PrimeField, Zero};
use kimchi::circuits::{
expr::{
ChallengeTerm, ConstantExpr, ConstantExprInner, ConstantTerm, Expr, ExprInner, Operations,
Variable,
},
gate::CurrOrNext,
};
use num_integer::binomial;
use o1_utils::FieldHelpers;
use rand::RngCore;
Expand Down Expand Up @@ -331,6 +338,37 @@ impl<F: PrimeField, const N: usize, const D: usize> Dense<F, N, D> {
let coeffs = self.coeff.iter().map(|coef| *coef * c).collect();
Self::from_coeffs(coeffs)
}

/// Evaluate the polynomial at the vector point `x`.
///
/// This is a dummy implementation. A cache can be used for the monomials to
/// speed up the computation.
pub fn eval(&self, x: &[F; N]) -> F {
let mut prime_gen = PrimeNumberGenerator::new();
let primes = prime_gen.get_first_nth_primes(N);
self.coeff
.iter()
.enumerate()
.fold(F::zero(), |acc, (i, c)| {
if i == 0 {
acc + c
} else {
let normalized_index = self.normalized_indices[i];
// IMPROVEME: we should keep the prime decomposition somewhere.
// It can be precomputed for a few multi-variate polynomials
// vector space
let prime_decomposition = naive_prime_factors(normalized_index, &mut prime_gen);
let mut monomial = F::one();
prime_decomposition.iter().for_each(|(p, d)| {
// IMPROVEME: we should keep the inverse indices
let inv_p = primes.iter().position(|&x| x == *p).unwrap();
let x_p = x[inv_p].pow([*d as u64]);
monomial *= x_p;
});
acc + *c * monomial
}
})
}
}

impl<F: PrimeField, const N: usize, const D: usize> Default for Dense<F, N, D> {
Expand Down Expand Up @@ -554,4 +592,138 @@ impl<F: PrimeField, const N: usize, const D: usize> From<F> for Dense<F, N, D> {
}
}

impl<F: PrimeField, const N: usize, const D: usize> From<ConstantExprInner<F>> for Dense<F, N, D> {
fn from(expr: ConstantExprInner<F>) -> Self {
match expr {
// The unimplemented methods might be implemented in the future if
// we move to the challenge into the type of the constant
// terms/expressions
// Unrolling for visibility
ConstantExprInner::Challenge(ChallengeTerm::Alpha) => {
unimplemented!("The challenge alpha is not supposed to be used in this context")
}
ConstantExprInner::Challenge(ChallengeTerm::Beta) => {
unimplemented!("The challenge beta is not supposed to be used in this context")
}
ConstantExprInner::Challenge(ChallengeTerm::Gamma) => {
unimplemented!("The challenge gamma is not supposed to be used in this context")
}
ConstantExprInner::Challenge(ChallengeTerm::JointCombiner) => {
unimplemented!(
"The challenge joint combiner is not supposed to be used in this context"
)
}
ConstantExprInner::Constant(ConstantTerm::EndoCoefficient) => {
unimplemented!(
"The constant EndoCoefficient is not supposed to be used in this context"
)
}
ConstantExprInner::Constant(ConstantTerm::Mds {
row: _row,
col: _col,
}) => {
unimplemented!("The constant Mds is not supposed to be used in this context")
}
ConstantExprInner::Constant(ConstantTerm::Literal(c)) => Dense::from(c),
}
}
}

impl<F: PrimeField, const N: usize, const D: usize> From<Operations<ConstantExprInner<F>>>
for Dense<F, N, D>
{
fn from(op: Operations<ConstantExprInner<F>>) -> Self {
use kimchi::circuits::expr::Operations::*;
match op {
Atom(op_const) => Self::from(op_const),
Add(c1, c2) => Self::from(*c1) + Self::from(*c2),
Sub(c1, c2) => Self::from(*c1) - Self::from(*c2),
Mul(c1, c2) => Self::from(*c1) * Self::from(*c2),
Square(c) => Self::from(*c.clone()) * Self::from(*c),
Double(c1) => Self::from(*c1).double(),
Pow(c, e) => {
// FIXME: dummy implementation
let p = Dense::from(*c);
let mut result = p.clone();
for _ in 0..e {
result = result.clone() * p.clone();
}
result
}
Cache(_c, _) => {
unimplemented!("The module prime is supposed to be used for generic multivariate expressions, not tied to a specific use case like Kimchi with this constructor")
}
IfFeature(_c, _t, _f) => {
unimplemented!("The module prime is supposed to be used for generic multivariate expressions, not tied to a specific use case like Kimchi with this constructor")
}
}
}
}

impl<Column: Into<usize>, F: PrimeField, const N: usize, const D: usize>
From<Expr<ConstantExpr<F>, Column>> for Dense<F, N, D>
{
fn from(expr: Expr<ConstantExpr<F>, Column>) -> Self {
// This is a dummy implementation
// TODO: Implement the actual conversion logic
use kimchi::circuits::expr::Operations::*;

match expr {
Atom(op_const) => {
match op_const {
ExprInner::UnnormalizedLagrangeBasis(_) => {
unimplemented!("Not used in this context")
}
ExprInner::VanishesOnZeroKnowledgeAndPreviousRows => {
unimplemented!("Not used in this context")
}
ExprInner::Constant(c) => Self::from(c),
ExprInner::Cell(Variable { col, row }) => {
assert_eq!(row, CurrOrNext::Curr, "Only current row is supported for now. You cannot reference the next row");
Self::from_variable(col)
}
}
}
Add(e1, e2) => {
let p1 = Dense::from(*e1);
let p2 = Dense::from(*e2);
p1 + p2
}
Sub(e1, e2) => {
let p1 = Dense::from(*e1);
let p2 = Dense::from(*e2);
p1 - p2
}
Mul(e1, e2) => {
let p1 = Dense::from(*e1);
let p2 = Dense::from(*e2);
p1 * p2
}
Double(p) => {
let p = Dense::from(*p);
p.double()
}
Square(p) => {
let p = Dense::from(*p);
p.clone() * p.clone()
}
Pow(c, e) => {
// FIXME: dummy implementation
let p = Dense::from(*c);
let mut result = p.clone();
for _ in 0..e {
result = result.clone() * p.clone();
}
result
}
Cache(_c, _) => {
unimplemented!("The module prime is supposed to be used for generic multivariate expressions, not tied to a specific use case like Kimchi with this constructor")
}
IfFeature(_c, _t, _f) => {
unimplemented!("The module prime is supposed to be used for generic multivariate expressions, not tied to a specific use case like Kimchi with this constructor")
}
}
}
}

// TODO: implement From/To Expr<F, Column>
99 changes: 99 additions & 0 deletions mvpoly/tests/prime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,3 +354,102 @@ fn test_from_variable_column() {
assert_eq!(p[3], Fp::zero());
assert_eq!(p[4], Fp::one());
}

#[test]
fn test_evaluation_zero_polynomial() {
let mut rng = o1_utils::tests::make_test_rng(None);

let random_evaluation: [Fp; 4] = std::array::from_fn(|_| Fp::rand(&mut rng));
let zero = Dense::<Fp, 4, 5>::zero();
let evaluation = zero.eval(&random_evaluation);
assert_eq!(evaluation, Fp::zero());
}

#[test]
fn test_evaluation_constant_polynomial() {
let mut rng = o1_utils::tests::make_test_rng(None);

let random_evaluation: [Fp; 4] = std::array::from_fn(|_| Fp::rand(&mut rng));
let cst = Fp::rand(&mut rng);
let zero = Dense::<Fp, 4, 5>::from(cst);
let evaluation = zero.eval(&random_evaluation);
assert_eq!(evaluation, cst);
}

#[test]
fn test_evaluation_predefined_polynomial() {
// Evaluating at random points
let mut rng = o1_utils::tests::make_test_rng(None);

let random_evaluation: [Fp; 2] = std::array::from_fn(|_| Fp::rand(&mut rng));
// P(X1, X2) = 2 + 3X1 + 4X2 + 5X1^2 + 6X1 X2 + 7 X2^2
let p = Dense::<Fp, 2, 2>::from_coeffs(vec![
Fp::from(2_u32),
Fp::from(3_u32),
Fp::from(4_u32),
Fp::from(5_u32),
Fp::from(6_u32),
Fp::from(7_u32),
]);
let exp_eval = Fp::from(2_u32)
+ Fp::from(3_u32) * random_evaluation[0]
+ Fp::from(4_u32) * random_evaluation[1]
+ Fp::from(5_u32) * random_evaluation[0] * random_evaluation[0]
+ Fp::from(6_u32) * random_evaluation[0] * random_evaluation[1]
+ Fp::from(7_u32) * random_evaluation[1] * random_evaluation[1];
let evaluation = p.eval(&random_evaluation);
assert_eq!(evaluation, exp_eval);
}

#[test]
fn test_eval_pbt_add() {
let mut rng = o1_utils::tests::make_test_rng(None);

let random_evaluation: [Fp; 6] = std::array::from_fn(|_| Fp::rand(&mut rng));
let p1 = unsafe { Dense::<Fp, 6, 4>::random(&mut rng) };
let p2 = unsafe { Dense::<Fp, 6, 4>::random(&mut rng) };
let p3 = p1.clone() + p2.clone();
let eval_p1 = p1.eval(&random_evaluation);
let eval_p2 = p2.eval(&random_evaluation);
let eval_p3 = p3.eval(&random_evaluation);
assert_eq!(eval_p3, eval_p1 + eval_p2);
}

#[test]
fn test_eval_pbt_sub() {
let mut rng = o1_utils::tests::make_test_rng(None);

let random_evaluation: [Fp; 6] = std::array::from_fn(|_| Fp::rand(&mut rng));
let p1 = unsafe { Dense::<Fp, 6, 4>::random(&mut rng) };
let p2 = unsafe { Dense::<Fp, 6, 4>::random(&mut rng) };
let p3 = p1.clone() - p2.clone();
let eval_p1 = p1.eval(&random_evaluation);
let eval_p2 = p2.eval(&random_evaluation);
let eval_p3 = p3.eval(&random_evaluation);
assert_eq!(eval_p3, eval_p1 - eval_p2);
}

#[test]
fn test_eval_pbt_mul_by_scalar() {
let mut rng = o1_utils::tests::make_test_rng(None);

let random_evaluation: [Fp; 6] = std::array::from_fn(|_| Fp::rand(&mut rng));
let p1 = unsafe { Dense::<Fp, 6, 4>::random(&mut rng) };
let c = Fp::rand(&mut rng);
let p2 = p1.clone() * Dense::<Fp, 6, 4>::from(c);
let eval_p1 = p1.eval(&random_evaluation);
let eval_p2 = p2.eval(&random_evaluation);
assert_eq!(eval_p2, eval_p1 * c);
}

#[test]
fn test_eval_pbt_neg() {
let mut rng = o1_utils::tests::make_test_rng(None);

let random_evaluation: [Fp; 6] = std::array::from_fn(|_| Fp::rand(&mut rng));
let p1 = unsafe { Dense::<Fp, 6, 4>::random(&mut rng) };
let p2 = -p1.clone();
let eval_p1 = p1.eval(&random_evaluation);
let eval_p2 = p2.eval(&random_evaluation);
assert_eq!(eval_p2, -eval_p1);
}

0 comments on commit 99ad37d

Please sign in to comment.