Skip to content

Commit

Permalink
unified back to one implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
tonyfloatersu committed Jan 21, 2025
1 parent ca87fb1 commit 8f81ccc
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 139 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ jobs:
run: cargo build --all-features --release

- name: Run unit tests
run: cargo test --all-features --release --workspace -- --nocapture
run: cargo test --all-features --release --workspace

- name: Run E2E tests
run: ./scripts/test_recursion.py
Expand Down
42 changes: 9 additions & 33 deletions poly_commit/src/orion/base_field_impl.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::iter;

use arith::{ExtensionField, Field, SimdField};
use gf2::GF2;
use itertools::izip;
use polynomials::{EqPolynomial, MultilinearExtension, RefMultiLinearPoly};
use transcript::Transcript;
Expand Down Expand Up @@ -72,24 +71,14 @@ where
let proximity_test_num = pk.proximity_repetitions::<EvalF>(PCS_SOUNDNESS_BITS);
let mut proximity_rows = vec![vec![EvalF::ZERO; msg_size]; proximity_test_num];

match F::NAME {
GF2::NAME => lut_open_linear_combine(
row_num,
&packed_evals,
&eq_col_coeffs,
&mut eval_row,
&mut proximity_rows,
transcript,
),
_ => simd_open_linear_combine(
row_num,
&packed_evals,
&eq_col_coeffs,
&mut eval_row,
&mut proximity_rows,
transcript,
),
}
simd_open_linear_combine(
row_num,
&packed_evals,
&eq_col_coeffs,
&mut eval_row,
&mut proximity_rows,
transcript,
);

// NOTE: working on evaluation on top of evaluation response
let mut scratch = vec![EvalF::ZERO; msg_size];
Expand Down Expand Up @@ -174,19 +163,6 @@ where
_ => return false,
};

match F::NAME {
GF2::NAME => lut_verify_alphabet_check(
&codeword,
rl,
&query_indices,
&packed_interleaved_alphabets,
),
_ => simd_verify_alphabet_check(
&codeword,
rl,
&query_indices,
&packed_interleaved_alphabets,
),
}
simd_verify_alphabet_check(&codeword, rl, &query_indices, &packed_interleaved_alphabets)
})
}
16 changes: 1 addition & 15 deletions poly_commit/src/orion/simd_field_agg_impl.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::iter;

use arith::{Field, SimdField};
use gf2::GF2;
use gkr_field_config::GKRFieldConfig;
use itertools::izip;
use polynomials::{EqPolynomial, MultilinearExtension, RefMultiLinearPoly};
Expand Down Expand Up @@ -142,19 +141,6 @@ where
_ => return false,
};

match C::CircuitField::NAME {
GF2::NAME => lut_verify_alphabet_check(
&codeword,
rl,
&query_indices,
&packed_interleaved_alphabets,
),
_ => simd_verify_alphabet_check(
&codeword,
rl,
&query_indices,
&packed_interleaved_alphabets,
),
}
simd_verify_alphabet_check(&codeword, rl, &query_indices, &packed_interleaved_alphabets)
})
}
22 changes: 12 additions & 10 deletions poly_commit/src/orion/simd_field_agg_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use std::marker::PhantomData;

use arith::{ExtensionField, Field, SimdField};
use ark_std::test_rng;
use gkr_field_config::{GKRFieldConfig, M31ExtConfig};
use gf2::GF2x128;
use gf2_128::GF2_128;
use gkr_field_config::{GF2ExtConfig, GKRFieldConfig, M31ExtConfig};
use itertools::izip;
use mersenne31::{M31Ext3, M31x16};
use polynomials::{EqPolynomial, MultiLinearPoly};
Expand Down Expand Up @@ -184,15 +186,15 @@ where
#[test]
fn test_orion_simd_aggregate_verify() {
let parties = 16;
/*
(16..18).for_each(|num_var| {
test_orion_simd_aggregate_verify_helper::<
GF2ExtConfig,
GF2x128,
BytesHashTranscript<GF2_128, Keccak256hasher>,
>(parties, num_var)
});
*/

(16..18).for_each(|num_var| {
test_orion_simd_aggregate_verify_helper::<
GF2ExtConfig,
GF2x128,
BytesHashTranscript<GF2_128, Keccak256hasher>,
>(parties, num_var)
});

(12..15).for_each(|num_var| {
test_orion_simd_aggregate_verify_helper::<
M31ExtConfig,
Expand Down
42 changes: 9 additions & 33 deletions poly_commit/src/orion/simd_field_impl.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::iter;

use arith::{ExtensionField, Field, SimdField};
use gf2::GF2;
use itertools::izip;
use polynomials::{EqPolynomial, MultilinearExtension, RefMultiLinearPoly};
use transcript::Transcript;
Expand Down Expand Up @@ -99,24 +98,14 @@ where
let proximity_test_num = pk.proximity_repetitions::<EvalF>(PCS_SOUNDNESS_BITS);
let mut proximity_rows = vec![vec![EvalF::ZERO; msg_size]; proximity_test_num];

match F::NAME {
GF2::NAME => lut_open_linear_combine(
row_num * SimdF::PACK_SIZE,
&packed_evals,
&eq_col_coeffs,
&mut eval_row,
&mut proximity_rows,
transcript,
),
_ => simd_open_linear_combine(
row_num * SimdF::PACK_SIZE,
&packed_evals,
&eq_col_coeffs,
&mut eval_row,
&mut proximity_rows,
transcript,
),
}
simd_open_linear_combine(
row_num * SimdF::PACK_SIZE,
&packed_evals,
&eq_col_coeffs,
&mut eval_row,
&mut proximity_rows,
transcript,
);

// NOTE: MT opening for point queries
let query_openings = orion_mt_openings(pk, transcript, scratch_pad);
Expand Down Expand Up @@ -203,19 +192,6 @@ where
_ => return false,
};

match F::NAME {
GF2::NAME => lut_verify_alphabet_check(
&codeword,
rl,
&query_indices,
&packed_interleaved_alphabets,
),
_ => simd_verify_alphabet_check(
&codeword,
rl,
&query_indices,
&packed_interleaved_alphabets,
),
}
simd_verify_alphabet_check(&codeword, rl, &query_indices, &packed_interleaved_alphabets)
})
}
43 changes: 4 additions & 39 deletions poly_commit/src/orion/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ where
}

/*
* LINEAR OPERATIONS FOR GF2 (LOOKUP TABLE BASED)
* LINEAR OPERATIONS (LOOKUP TABLE BASED)
*/

pub struct SubsetSumLUTs<F: Field> {
Expand Down Expand Up @@ -311,6 +311,7 @@ impl<F: Field> SubsetSumLUTs<F> {
}
}

#[allow(unused)]
#[inline(always)]
pub(crate) fn lut_open_linear_combine<F, EvalF, SimdF, T>(
row_num: usize,
Expand Down Expand Up @@ -348,6 +349,7 @@ pub(crate) fn lut_open_linear_combine<F, EvalF, SimdF, T>(
drop(luts);
}

#[allow(unused)]
#[inline(always)]
pub(crate) fn lut_verify_alphabet_check<F, SimdF, ExtF>(
codeword: &[ExtF],
Expand Down Expand Up @@ -375,7 +377,7 @@ where
}

/*
* LINEAR OPERATIONS FOR MERSENNE31 (SIMD BASED)
* LINEAR OPERATIONS (SIMD BASED)
*/

#[inline(always)]
Expand Down Expand Up @@ -472,40 +474,3 @@ where
alphabet == codeword[index]
})
}

#[cfg(test)]
mod tests {
use arith::{Field, SimdField};
use ark_std::test_rng;
use gf2::{GF2x8, GF2};
use gf2_128::{GF2_128x8, GF2_128};
use itertools::Itertools;

use super::SubsetSumLUTs;

#[test]
fn test_lut_simd_inner_prod_consistency() {
if cfg!(target_endian = "big") {
println!("This is a BigEndian system.")
} else {
println!("This is a LittleEndian system.")
}

let mut rng = test_rng();

let weights: Vec<_> = (0..8).map(|_| GF2_128::random_unsafe(&mut rng)).collect();
let bases: Vec<_> = (0..8).map(|_| GF2::random_unsafe(&mut rng)).collect_vec();

let simd_weights = GF2_128x8::pack(&weights);
let simd_bases = GF2x8::pack(&bases);

let expected_simd_inner_prod: GF2_128 = (simd_weights * simd_bases).unpack().iter().sum();

let mut table = SubsetSumLUTs::new(8, 1);
table.build(&weights);

let actual_lut_inner_prod = table.lookup_and_sum(&vec![simd_bases]);

assert_eq!(expected_simd_inner_prod, actual_lut_inner_prod)
}
}
14 changes: 6 additions & 8 deletions poly_commit/tests/test_orion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use arith::{ExtensionField, Field, SimdField};
use ark_std::test_rng;
use gf2::{GF2x128, GF2x64, GF2x8, GF2};
use gf2_128::GF2_128;
use gkr_field_config::{GKRFieldConfig, M31ExtConfig};
use gkr_field_config::{GF2ExtConfig, GKRFieldConfig, M31ExtConfig};
use mersenne31::{M31Ext3, M31x16, M31};
use mpi_config::MPIConfig;
use poly_commit::*;
Expand Down Expand Up @@ -162,13 +162,11 @@ fn test_orion_for_expander_gkr_generics<C, ComPackF, T>(
fn test_orion_for_expander_gkr() {
let mpi_config = MPIConfig::new();

/*
test_orion_for_expander_gkr_generics::<
GF2ExtConfig,
GF2x128,
BytesHashTranscript<_, Keccak256hasher>,
>(&mpi_config, 16);
*/
test_orion_for_expander_gkr_generics::<
GF2ExtConfig,
GF2x128,
BytesHashTranscript<_, Keccak256hasher>,
>(&mpi_config, 16);

test_orion_for_expander_gkr_generics::<
M31ExtConfig,
Expand Down

0 comments on commit 8f81ccc

Please sign in to comment.