Skip to content

Commit

Permalink
[Feat] Add Poseidon Hasher Chip (#110)
Browse files Browse the repository at this point in the history
* Add Poseidon chip

* chore: minor fixes

* test(poseidon): add compatbility tests

Cherry-picked from #98

Co-authored-by: Antonio Mejías Gil <anmegi.95@gmail.com>

* chore: minor refactor to more closely match snark-verifier

https://github.com/axiom-crypto/snark-verifier/blob/main/snark-verifier/src/util/hash/poseidon.rs

---------

Co-authored-by: Xinding Wei <xinding@intrinsictech.xyz>
Co-authored-by: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com>
Co-authored-by: Antonio Mejías Gil <anmegi.95@gmail.com>
  • Loading branch information
4 people authored Aug 17, 2023
1 parent 49aeedd commit a7b5433
Show file tree
Hide file tree
Showing 9 changed files with 788 additions and 2 deletions.
5 changes: 5 additions & 0 deletions halo2-base/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ halo2_proofs_axiom = { git = "https://github.com/axiom-crypto/halo2.git", packag
# Use PSE halo2 and halo2curves for compatibility when feature = "halo2-pse" is on
halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.git", rev = "f348757", optional = true }

# This is Scroll's audited poseidon circuit. We only use it for the Native Poseidon spec. We do not use the halo2 circuit at all (and it wouldn't even work because the halo2_proofs tag is not compatbile).
# We forked it to upgrade to ff v0.13 and removed the circuit module
poseidon-rs = { git = "https://github.com/axiom-crypto/poseidon-circuit.git", rev = "1aee4a1" }
# plotting circuit layout
plotters = { version = "0.3.0", optional = true }
tabbycat = { version = "0.1", features = ["attributes"], optional = true }
Expand All @@ -35,6 +38,8 @@ criterion = "0.4"
criterion-macro = "0.4"
test-case = "3.1.0"
proptest = "1.1.0"
# native poseidon for testing
pse-poseidon = { git = "https://github.com/axiom-crypto/pse-poseidon.git" }

# memory allocation
[target.'cfg(not(target_env = "msvc"))'.dependencies]
Expand Down
2 changes: 2 additions & 0 deletions halo2-base/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ use utils::ScalarField;

/// Module that contains the main API for creating and working with circuits.
pub mod gates;
/// Module for the Poseidon hash function.
pub mod poseidon;
/// Module for SafeType which enforce value range and realted functions.
pub mod safe_types;
/// Utility functions for converting between different types of field elements.
Expand Down
154 changes: 154 additions & 0 deletions halo2-base/src/poseidon/mds.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
#![allow(clippy::needless_range_loop)]
use crate::utils::ScalarField;

/// The type used to hold the MDS matrix
pub(crate) type Mds<F, const T: usize> = [[F; T]; T];

/// `MDSMatrices` holds the MDS matrix as well as transition matrix which is
/// also called `pre_sparse_mds` and sparse matrices that enables us to reduce
/// number of multiplications in apply MDS step
#[derive(Debug, Clone)]
pub struct MDSMatrices<F: ScalarField, const T: usize, const RATE: usize> {
pub(crate) mds: MDSMatrix<F, T, RATE>,
pub(crate) pre_sparse_mds: MDSMatrix<F, T, RATE>,
pub(crate) sparse_matrices: Vec<SparseMDSMatrix<F, T, RATE>>,
}

/// `SparseMDSMatrix` are in `[row], [hat | identity]` form and used in linear
/// layer of partial rounds instead of the original MDS
#[derive(Debug, Clone)]
pub struct SparseMDSMatrix<F: ScalarField, const T: usize, const RATE: usize> {
pub(crate) row: [F; T],
pub(crate) col_hat: [F; RATE],
}

/// `MDSMatrix` is applied to `State` to achive linear layer of Poseidon
#[derive(Clone, Debug)]
pub struct MDSMatrix<F: ScalarField, const T: usize, const RATE: usize>(pub(crate) Mds<F, T>);

impl<F: ScalarField, const T: usize, const RATE: usize> MDSMatrix<F, T, RATE> {
pub(crate) fn mul_vector(&self, v: &[F; T]) -> [F; T] {
let mut res = [F::ZERO; T];
for i in 0..T {
for j in 0..T {
res[i] += self.0[i][j] * v[j];
}
}
res
}

pub(crate) fn identity() -> Mds<F, T> {
let mut mds = [[F::ZERO; T]; T];
for i in 0..T {
mds[i][i] = F::ONE;
}
mds
}

/// Multiplies two MDS matrices. Used in sparse matrix calculations
pub(crate) fn mul(&self, other: &Self) -> Self {
let mut res = [[F::ZERO; T]; T];
for i in 0..T {
for j in 0..T {
for k in 0..T {
res[i][j] += self.0[i][k] * other.0[k][j];
}
}
}
Self(res)
}

pub(crate) fn transpose(&self) -> Self {
let mut res = [[F::ZERO; T]; T];
for i in 0..T {
for j in 0..T {
res[i][j] = self.0[j][i];
}
}
Self(res)
}

pub(crate) fn determinant<const N: usize>(m: [[F; N]; N]) -> F {
let mut res = F::ONE;
let mut m = m;
for i in 0..N {
let mut pivot = i;
while m[pivot][i] == F::ZERO {
pivot += 1;
assert!(pivot < N, "matrix is not invertible");
}
if pivot != i {
res = -res;
m.swap(pivot, i);
}
res *= m[i][i];
let inv = m[i][i].invert().unwrap();
for j in i + 1..N {
let factor = m[j][i] * inv;
for k in i + 1..N {
m[j][k] -= m[i][k] * factor;
}
}
}
res
}

/// See Section B in Supplementary Material https://eprint.iacr.org/2019/458.pdf
/// Factorises an MDS matrix `M` into `M'` and `M''` where `M = M' * M''`.
/// Resulted `M''` matrices are the sparse ones while `M'` will contribute
/// to the accumulator of the process
pub(crate) fn factorise(&self) -> (Self, SparseMDSMatrix<F, T, RATE>) {
assert_eq!(RATE + 1, T);
// Given `(t-1 * t-1)` MDS matrix called `hat` constructs the `t * t` matrix in
// form `[[1 | 0], [0 | m]]`, ie `hat` is the right bottom sub-matrix
let prime = |hat: Mds<F, RATE>| -> Self {
let mut prime = Self::identity();
for (prime_row, hat_row) in prime.iter_mut().skip(1).zip(hat.iter()) {
for (el_prime, el_hat) in prime_row.iter_mut().skip(1).zip(hat_row.iter()) {
*el_prime = *el_hat;
}
}
Self(prime)
};

// Given `(t-1)` sized `w_hat` vector constructs the matrix in form
// `[[m_0_0 | m_0_i], [w_hat | identity]]`
let prime_prime = |w_hat: [F; RATE]| -> Mds<F, T> {
let mut prime_prime = Self::identity();
prime_prime[0] = self.0[0];
for (row, w) in prime_prime.iter_mut().skip(1).zip(w_hat.iter()) {
row[0] = *w
}
prime_prime
};

let w = self.0.iter().skip(1).map(|row| row[0]).collect::<Vec<_>>();
// m_hat is the `(t-1 * t-1)` right bottom sub-matrix of m := self.0
let mut m_hat = [[F::ZERO; RATE]; RATE];
for i in 0..RATE {
for j in 0..RATE {
m_hat[i][j] = self.0[i + 1][j + 1];
}
}
// w_hat = m_hat^{-1} * w, where m_hat^{-1} is matrix inverse and * is matrix mult
// we avoid computing m_hat^{-1} explicitly by using Cramer's rule: https://en.wikipedia.org/wiki/Cramer%27s_rule
let mut w_hat = [F::ZERO; RATE];
let det = Self::determinant(m_hat);
let det_inv = Option::<F>::from(det.invert()).expect("matrix is not invertible");
for j in 0..RATE {
let mut m_hat_j = m_hat;
for i in 0..RATE {
m_hat_j[i][j] = w[i];
}
w_hat[j] = Self::determinant(m_hat_j) * det_inv;
}
let m_prime = prime(m_hat);
let m_prime_prime = prime_prime(w_hat);
// row = first row of m_prime_prime.transpose() = first column of m_prime_prime
let row: [F; T] =
m_prime_prime.iter().map(|row| row[0]).collect::<Vec<_>>().try_into().unwrap();
// col_hat = first column of m_prime_prime.transpose() without first element = first row of m_prime_prime without first element
let col_hat: [F; RATE] = m_prime_prime[0][1..].try_into().unwrap();
(m_prime, SparseMDSMatrix { row, col_hat })
}
}
116 changes: 116 additions & 0 deletions halo2-base/src/poseidon/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
use std::mem;

use crate::{
gates::GateInstructions,
poseidon::{spec::OptimizedPoseidonSpec, state::PoseidonState},
AssignedValue, Context, ScalarField,
};

#[cfg(test)]
mod tests;

/// Module for maximum distance separable matrix operations.
pub mod mds;
/// Module for poseidon specification.
pub mod spec;
/// Module for poseidon states.
pub mod state;

/// Chip for Poseidon hasher. The chip is stateful.
pub struct PoseidonHasherChip<F: ScalarField, const T: usize, const RATE: usize> {
init_state: PoseidonState<F, T, RATE>,
state: PoseidonState<F, T, RATE>,
spec: OptimizedPoseidonSpec<F, T, RATE>,
absorbing: Vec<AssignedValue<F>>,
}

impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonHasherChip<F, T, RATE> {
/// Create new Poseidon hasher chip.
pub fn new<const R_F: usize, const R_P: usize, const SECURE_MDS: usize>(
ctx: &mut Context<F>,
) -> Self {
let init_state = PoseidonState::default(ctx);
let state = init_state.clone();
Self {
init_state,
state,
spec: OptimizedPoseidonSpec::new::<R_F, R_P, SECURE_MDS>(),
absorbing: Vec::new(),
}
}

/// Initialize a poseidon hasher from an existing spec.
pub fn from_spec(ctx: &mut Context<F>, spec: OptimizedPoseidonSpec<F, T, RATE>) -> Self {
let init_state = PoseidonState::default(ctx);
Self { spec, state: init_state.clone(), init_state, absorbing: Vec::new() }
}

/// Reset state to default and clear the buffer.
pub fn clear(&mut self) {
self.state = self.init_state.clone();
self.absorbing.clear();
}

/// Store given `elements` into buffer.
pub fn update(&mut self, elements: &[AssignedValue<F>]) {
self.absorbing.extend_from_slice(elements);
}

/// Consume buffer and perform permutation, then output second element of
/// state.
pub fn squeeze(
&mut self,
ctx: &mut Context<F>,
gate: &impl GateInstructions<F>,
) -> AssignedValue<F> {
let input_elements = mem::take(&mut self.absorbing);
let exact = input_elements.len() % RATE == 0;

for chunk in input_elements.chunks(RATE) {
self.permutation(ctx, gate, chunk.to_vec());
}
if exact {
self.permutation(ctx, gate, vec![]);
}

self.state.s[1]
}

fn permutation(
&mut self,
ctx: &mut Context<F>,
gate: &impl GateInstructions<F>,
inputs: Vec<AssignedValue<F>>,
) {
let r_f = self.spec.r_f / 2;
let mds = &self.spec.mds_matrices.mds.0;
let pre_sparse_mds = &self.spec.mds_matrices.pre_sparse_mds.0;
let sparse_matrices = &self.spec.mds_matrices.sparse_matrices;

// First half of the full round
let constants = &self.spec.constants.start;
self.state.absorb_with_pre_constants(ctx, gate, inputs, &constants[0]);
for constants in constants.iter().skip(1).take(r_f - 1) {
self.state.sbox_full(ctx, gate, constants);
self.state.apply_mds(ctx, gate, mds);
}
self.state.sbox_full(ctx, gate, constants.last().unwrap());
self.state.apply_mds(ctx, gate, pre_sparse_mds);

// Partial rounds
let constants = &self.spec.constants.partial;
for (constant, sparse_mds) in constants.iter().zip(sparse_matrices.iter()) {
self.state.sbox_part(ctx, gate, constant);
self.state.apply_sparse_mds(ctx, gate, sparse_mds);
}

// Second half of the full rounds
let constants = &self.spec.constants.end;
for constants in constants.iter() {
self.state.sbox_full(ctx, gate, constants);
self.state.apply_mds(ctx, gate, mds);
}
self.state.sbox_full(ctx, gate, &[F::ZERO; T]);
self.state.apply_mds(ctx, gate, mds);
}
}
Loading

0 comments on commit a7b5433

Please sign in to comment.