Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: fixed_base::msm_par handles identity point #48

Merged
merged 1 commit into from
May 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions halo2-ecc/src/bn254/tests/fixed_base_msm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use halo2_base::{
halo2_proofs::halo2curves::bn256::G1,
utils::fs::gen_srs,
};
use itertools::Itertools;
use rand_core::OsRng;

#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
Expand Down Expand Up @@ -68,6 +69,7 @@ fn fixed_base_msm_test(

fn random_fixed_base_msm_circuit(
params: MSMCircuitParams,
bases: Vec<G1Affine>, // bases are fixed in vkey so don't randomly generate
stage: CircuitBuilderStage,
break_points: Option<MultiPhaseThreadBreakPoints>,
) -> RangeCircuitBuilder<Fr> {
Expand All @@ -78,8 +80,7 @@ fn random_fixed_base_msm_circuit(
CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(),
};

let (bases, scalars): (Vec<_>, Vec<_>) =
(0..params.batch_size).map(|_| (G1Affine::random(OsRng), Fr::random(OsRng))).unzip();
let scalars = (0..params.batch_size).map(|_| Fr::random(OsRng)).collect_vec();
let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage"));
fixed_base_msm_test(&mut builder, params, bases, scalars);

Expand All @@ -106,7 +107,8 @@ fn test_fixed_base_msm() {
)
.unwrap();

let circuit = random_fixed_base_msm_circuit(params, CircuitBuilderStage::Mock, None);
let bases = (0..params.batch_size).map(|_| G1Affine::random(OsRng)).collect_vec();
let circuit = random_fixed_base_msm_circuit(params, bases, CircuitBuilderStage::Mock, None);
MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied();
}

Expand All @@ -132,8 +134,13 @@ fn bench_fixed_base_msm() -> Result<(), Box<dyn std::error::Error>> {
let params = gen_srs(k);
println!("{bench_params:?}");

let circuit =
random_fixed_base_msm_circuit(bench_params, CircuitBuilderStage::Keygen, None);
let bases = (0..bench_params.batch_size).map(|_| G1Affine::random(OsRng)).collect_vec();
let circuit = random_fixed_base_msm_circuit(
bench_params,
bases.clone(),
CircuitBuilderStage::Keygen,
None,
);

let vk_time = start_timer!(|| "Generating vkey");
let vk = keygen_vk(&params, &circuit)?;
Expand All @@ -149,6 +156,7 @@ fn bench_fixed_base_msm() -> Result<(), Box<dyn std::error::Error>> {
let proof_time = start_timer!(|| "Proving time");
let circuit = random_fixed_base_msm_circuit(
bench_params,
bases,
CircuitBuilderStage::Prover,
Some(break_points),
);
Expand Down
1 change: 1 addition & 0 deletions halo2-ecc/src/ecc/ecdsa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ where
u1.limbs().to_vec(),
base_chip.limb_bits,
fixed_window_bits,
true, // we can call it with scalar_is_safe = true because of the u1_small check below
);
let u2_mul = scalar_multiply(
base_chip,
Expand Down
152 changes: 34 additions & 118 deletions halo2-ecc/src/ecc/fixed_base.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#![allow(non_snake_case)]
use super::{ec_add_unequal, ec_select, ec_select_from_bits, EcPoint, EccChip};
use crate::ecc::ec_sub_strict;
use crate::fields::{FieldChip, PrimeField, Selectable};
use group::Curve;
use halo2_base::gates::builder::{parallelize_in, GateThreadBuilder};
Expand All @@ -8,21 +9,25 @@ use itertools::Itertools;
use rayon::prelude::*;
use std::cmp::min;

// computes `[scalar] * P` on y^2 = x^3 + b where `P` is fixed (constant)
// - `scalar` is represented as a non-empty reference array of `AssignedValue`s
// - `scalar = sum_i scalar_i * 2^{max_bits * i}`
// - an array of length > 1 is needed when `scalar` exceeds the modulus of scalar field `F`
// assumes:
// - `scalar_i < 2^{max_bits} for all i` (constrained by num_to_bits)
// - `max_bits <= modulus::<F>.bits()`

/// Computes `[scalar] * P` on y^2 = x^3 + b where `P` is fixed (constant)
/// - `scalar` is represented as a non-empty reference array of `AssignedValue`s
/// - `scalar = sum_i scalar_i * 2^{max_bits * i}`
/// - an array of length > 1 is needed when `scalar` exceeds the modulus of scalar field `F`
///
/// # Assumptions
/// - `scalar_i < 2^{max_bits} for all i` (constrained by num_to_bits)
/// - `scalar > 0`
/// - If `scalar_is_safe == true`, then we assume the integer `scalar` is in range [1, order of `P`)
/// - Even if `scalar_is_safe == false`, some constraints may still fail if `scalar` is not in range [1, order of `P`)
/// - `max_bits <= modulus::<F>.bits()`
pub fn scalar_multiply<F, FC, C>(
chip: &FC,
ctx: &mut Context<F>,
point: &C,
scalar: Vec<AssignedValue<F>>,
max_bits: usize,
window_bits: usize,
scalar_is_safe: bool,
) -> EcPoint<F, FC::FieldPoint>
where
F: PrimeField,
Expand All @@ -33,8 +38,8 @@ where
let zero = chip.load_constant(ctx, C::Base::zero());
return EcPoint::new(zero.clone(), zero);
}
debug_assert!(!scalar.is_empty());
debug_assert!((max_bits as u32) <= F::NUM_BITS);
assert!(!scalar.is_empty());
assert!((max_bits as u32) <= F::NUM_BITS);

let total_bits = max_bits * scalar.len();
let num_windows = (total_bits + window_bits - 1) / window_bits;
Expand Down Expand Up @@ -91,7 +96,7 @@ where
let is_zero_window = chip.gate().is_zero(ctx, bit_sum);
let add_point = ec_select_from_bits(chip, ctx, cached_point_window, bit_window);
curr_point = if let Some(curr_point) = curr_point {
let sum = ec_add_unequal(chip, ctx, &curr_point, &add_point, false);
let sum = ec_add_unequal(chip, ctx, &curr_point, &add_point, !scalar_is_safe);
let zero_sum = ec_select(chip, ctx, curr_point, sum, is_zero_window);
Some(ec_select(chip, ctx, zero_sum, add_point, is_started))
} else {
Expand All @@ -107,117 +112,16 @@ where
curr_point.unwrap()
}

/* To reduce total amount of code, just always use msm_par below.
// basically just adding up individual fixed_base::scalar_multiply except that we do all batched normalization of cached points at once to further save inversion time during witness generation
// we also use the random accumulator for some extra efficiency (which also works in scalar multiply case but that is TODO)
pub fn msm<F, FC, C>(
chip: &EccChip<F, FC>,
ctx: &mut Context<F>,
points: &[C],
scalars: Vec<Vec<AssignedValue<F>>>,
max_scalar_bits_per_cell: usize,
window_bits: usize,
) -> EcPoint<F, FC::FieldPoint>
where
F: PrimeField,
C: CurveAffineExt,
FC: FieldChip<F, FieldType = C::Base> + Selectable<F, FC::FieldPoint>,
{
assert!((max_scalar_bits_per_cell as u32) <= F::NUM_BITS);
let scalar_len = scalars[0].len();
let total_bits = max_scalar_bits_per_cell * scalar_len;
let num_windows = (total_bits + window_bits - 1) / window_bits;

// `cached_points` is a flattened 2d vector
// first we compute all cached points in Jacobian coordinates since it's fastest
let cached_points_jacobian = points
.iter()
.flat_map(|point| {
let base_pt = point.to_curve();
// cached_points[idx][i * 2^w + j] holds `[j * 2^(i * w)] * points[idx]` for j in {0, ..., 2^w - 1}
let mut increment = base_pt;
(0..num_windows)
.flat_map(|i| {
let mut curr = increment;
let cache_vec = std::iter::once(increment)
.chain((1..(1usize << min(window_bits, total_bits - i * window_bits))).map(
|_| {
let prev = curr;
curr += increment;
prev
},
))
.collect_vec();
increment = curr;
cache_vec
})
.collect_vec()
})
.collect_vec();
// for use in circuits we need affine coordinates, so we do a batch normalize: this is much more efficient than calling `to_affine` one by one since field inversion is very expensive
// initialize to all 0s
let mut cached_points_affine = vec![C::default(); cached_points_jacobian.len()];
C::Curve::batch_normalize(&cached_points_jacobian, &mut cached_points_affine);

let field_chip = chip.field_chip();
let cached_points = cached_points_affine
.into_iter()
.map(|point| chip.assign_constant_point(ctx, point))
.collect_vec();

let bits = scalars
.into_iter()
.flat_map(|scalar| {
assert_eq!(scalar.len(), scalar_len);
scalar
.into_iter()
.flat_map(|scalar_chunk| {
field_chip.gate().num_to_bits(ctx, scalar_chunk, max_scalar_bits_per_cell)
})
.collect_vec()
})
.collect_vec();

let scalar_mults = cached_points
.chunks(cached_points.len() / points.len())
.zip(bits.chunks(total_bits))
.map(|(cached_points, bits)| {
let cached_point_window_rev = cached_points.chunks(1usize << window_bits).rev();
let bit_window_rev = bits.chunks(window_bits).rev();
let mut curr_point = None;
// `is_started` is just a way to deal with if `curr_point` is actually identity
let mut is_started = ctx.load_zero();
for (cached_point_window, bit_window) in cached_point_window_rev.zip(bit_window_rev) {
let is_zero_window = {
let sum = field_chip.gate().sum(ctx, bit_window.iter().copied());
field_chip.gate().is_zero(ctx, sum)
};
let add_point =
ec_select_from_bits(field_chip, ctx, cached_point_window, bit_window);
curr_point = if let Some(curr_point) = curr_point {
let sum = ec_add_unequal(field_chip, ctx, &curr_point, &add_point, false);
let zero_sum = ec_select(field_chip, ctx, curr_point, sum, is_zero_window);
Some(ec_select(field_chip, ctx, zero_sum, add_point, is_started))
} else {
Some(add_point)
};
is_started = {
// is_started || !is_zero_window
// (a || !b) = (1-b) + a*b
let not_zero_window = field_chip.gate().not(ctx, is_zero_window);
field_chip.gate().mul_add(ctx, is_started, is_zero_window, not_zero_window)
};
}
curr_point.unwrap()
})
.collect_vec();
chip.sum::<C>(ctx, scalar_mults)
}
*/

/// # Assumptions
/// * `points.len() = scalars.len()`
/// * `scalars[i].len() = scalars[j].len()` for all `i,j`
/// * `points` are all on the curve
/// * `points[i]` is not point at infinity (0, 0); these should be filtered out beforehand
/// * The integer value of `scalars[i]` is less than the order of `points[i]` (some constraints may fail otherwise)
/// * Output may be point at infinity, in which case (0, 0) is returned
pub fn msm_par<F, FC, C>(
chip: &EccChip<F, FC>,
builder: &mut GateThreadBuilder<F>,
Expand All @@ -232,6 +136,9 @@ where
C: CurveAffineExt,
FC: FieldChip<F, FieldType = C::Base> + Selectable<F, FC::FieldPoint>,
{
if points.is_empty() {
return chip.assign_constant_point(builder.main(phase), C::identity());
}
assert!((max_scalar_bits_per_cell as u32) <= F::NUM_BITS);
assert_eq!(points.len(), scalars.len());
assert!(!points.is_empty(), "fixed_base::msm_par requires at least one point");
Expand Down Expand Up @@ -306,6 +213,7 @@ where
let add_point =
ec_select_from_bits(field_chip, ctx, cached_point_window, bit_window);
curr_point = if let Some(curr_point) = curr_point {
// We don't need strict mode because we assume scalars[i] is less than the order of points[i]
let sum = ec_add_unequal(field_chip, ctx, &curr_point, &add_point, false);
let zero_sum = ec_select(field_chip, ctx, curr_point, sum, is_zero_window);
Some(ec_select(field_chip, ctx, zero_sum, add_point, is_started))
Expand All @@ -319,8 +227,16 @@ where
field_chip.gate().mul_add(ctx, is_started, is_zero_window, not_zero_window)
};
}
curr_point.unwrap()
(curr_point.unwrap(), is_started)
},
);
chip.sum::<C>(builder.main(phase), scalar_mults)
let ctx = builder.main(phase);
// sum `scalar_mults` but take into account possiblity of identity points
let any_point = chip.load_random_point::<C>(ctx);
let mut acc = any_point.clone();
for (point, is_not_identity) in scalar_mults {
let new_acc = chip.add_unequal(ctx, &acc, point, true);
acc = chip.select(ctx, new_acc, acc, is_not_identity);
}
ec_sub_strict(field_chip, ctx, acc, any_point)
}
14 changes: 9 additions & 5 deletions halo2-ecc/src/ecc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -469,11 +469,12 @@ where
/// - an array of length > 1 is needed when `scalar` exceeds the modulus of scalar field `F`
///
/// # Assumptions
/// * `P` is not the point at infinity
/// * `scalar > 0`
/// * If `scalar_is_safe == true`, then we assume the integer `scalar` is in range [1, order of `P`)
/// * `scalar_i < 2^{max_bits} for all i`
/// * `max_bits <= modulus::<F>.bits()`, and equality only allowed when the order of `P` equals the modulus of `F`
/// - `P` is not the point at infinity
/// - `scalar > 0`
/// - If `scalar_is_safe == true`, then we assume the integer `scalar` is in range [1, order of `P`)
/// - Even if `scalar_is_safe == false`, some constraints may still fail if `scalar` is not in range [1, order of `P`)
/// - `scalar_i < 2^{max_bits} for all i`
/// - `max_bits <= modulus::<F>.bits()`, and equality only allowed when the order of `P` equals the modulus of `F`
pub fn scalar_multiply<F: PrimeField, FC>(
chip: &FC,
ctx: &mut Context<F>,
Expand Down Expand Up @@ -1094,6 +1095,7 @@ where
}

impl<'chip, F: PrimeField, FC: FieldChip<F>> EccChip<'chip, F, FC> {
/// See [`fixed_base::scalar_multiply`] for more details.
// TODO: put a check in place that scalar is < modulus of C::Scalar
pub fn fixed_base_scalar_mult<C>(
&self,
Expand All @@ -1102,6 +1104,7 @@ impl<'chip, F: PrimeField, FC: FieldChip<F>> EccChip<'chip, F, FC> {
scalar: Vec<AssignedValue<F>>,
max_bits: usize,
window_bits: usize,
scalar_is_safe: bool,
) -> EcPoint<F, FC::FieldPoint>
where
C: CurveAffineExt,
Expand All @@ -1114,6 +1117,7 @@ impl<'chip, F: PrimeField, FC: FieldChip<F>> EccChip<'chip, F, FC> {
scalar,
max_bits,
window_bits,
scalar_is_safe,
)
}

Expand Down