-
Notifications
You must be signed in to change notification settings - Fork 165
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feat] Add Poseidon Hasher Chip (#110)
* Add Poseidon chip * chore: minor fixes * test(poseidon): add compatbility tests Cherry-picked from #98 Co-authored-by: Antonio Mejías Gil <anmegi.95@gmail.com> * chore: minor refactor to more closely match snark-verifier https://github.com/axiom-crypto/snark-verifier/blob/main/snark-verifier/src/util/hash/poseidon.rs --------- Co-authored-by: Xinding Wei <xinding@intrinsictech.xyz> Co-authored-by: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Co-authored-by: Antonio Mejías Gil <anmegi.95@gmail.com>
- Loading branch information
1 parent
49aeedd
commit a7b5433
Showing
9 changed files
with
788 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
#![allow(clippy::needless_range_loop)] | ||
use crate::utils::ScalarField; | ||
|
||
/// The type used to hold the MDS matrix | ||
pub(crate) type Mds<F, const T: usize> = [[F; T]; T]; | ||
|
||
/// `MDSMatrices` holds the MDS matrix as well as transition matrix which is | ||
/// also called `pre_sparse_mds` and sparse matrices that enables us to reduce | ||
/// number of multiplications in apply MDS step | ||
#[derive(Debug, Clone)] | ||
pub struct MDSMatrices<F: ScalarField, const T: usize, const RATE: usize> { | ||
pub(crate) mds: MDSMatrix<F, T, RATE>, | ||
pub(crate) pre_sparse_mds: MDSMatrix<F, T, RATE>, | ||
pub(crate) sparse_matrices: Vec<SparseMDSMatrix<F, T, RATE>>, | ||
} | ||
|
||
/// `SparseMDSMatrix` are in `[row], [hat | identity]` form and used in linear | ||
/// layer of partial rounds instead of the original MDS | ||
#[derive(Debug, Clone)] | ||
pub struct SparseMDSMatrix<F: ScalarField, const T: usize, const RATE: usize> { | ||
pub(crate) row: [F; T], | ||
pub(crate) col_hat: [F; RATE], | ||
} | ||
|
||
/// `MDSMatrix` is applied to `State` to achive linear layer of Poseidon | ||
#[derive(Clone, Debug)] | ||
pub struct MDSMatrix<F: ScalarField, const T: usize, const RATE: usize>(pub(crate) Mds<F, T>); | ||
|
||
impl<F: ScalarField, const T: usize, const RATE: usize> MDSMatrix<F, T, RATE> { | ||
pub(crate) fn mul_vector(&self, v: &[F; T]) -> [F; T] { | ||
let mut res = [F::ZERO; T]; | ||
for i in 0..T { | ||
for j in 0..T { | ||
res[i] += self.0[i][j] * v[j]; | ||
} | ||
} | ||
res | ||
} | ||
|
||
pub(crate) fn identity() -> Mds<F, T> { | ||
let mut mds = [[F::ZERO; T]; T]; | ||
for i in 0..T { | ||
mds[i][i] = F::ONE; | ||
} | ||
mds | ||
} | ||
|
||
/// Multiplies two MDS matrices. Used in sparse matrix calculations | ||
pub(crate) fn mul(&self, other: &Self) -> Self { | ||
let mut res = [[F::ZERO; T]; T]; | ||
for i in 0..T { | ||
for j in 0..T { | ||
for k in 0..T { | ||
res[i][j] += self.0[i][k] * other.0[k][j]; | ||
} | ||
} | ||
} | ||
Self(res) | ||
} | ||
|
||
pub(crate) fn transpose(&self) -> Self { | ||
let mut res = [[F::ZERO; T]; T]; | ||
for i in 0..T { | ||
for j in 0..T { | ||
res[i][j] = self.0[j][i]; | ||
} | ||
} | ||
Self(res) | ||
} | ||
|
||
pub(crate) fn determinant<const N: usize>(m: [[F; N]; N]) -> F { | ||
let mut res = F::ONE; | ||
let mut m = m; | ||
for i in 0..N { | ||
let mut pivot = i; | ||
while m[pivot][i] == F::ZERO { | ||
pivot += 1; | ||
assert!(pivot < N, "matrix is not invertible"); | ||
} | ||
if pivot != i { | ||
res = -res; | ||
m.swap(pivot, i); | ||
} | ||
res *= m[i][i]; | ||
let inv = m[i][i].invert().unwrap(); | ||
for j in i + 1..N { | ||
let factor = m[j][i] * inv; | ||
for k in i + 1..N { | ||
m[j][k] -= m[i][k] * factor; | ||
} | ||
} | ||
} | ||
res | ||
} | ||
|
||
/// See Section B in Supplementary Material https://eprint.iacr.org/2019/458.pdf | ||
/// Factorises an MDS matrix `M` into `M'` and `M''` where `M = M' * M''`. | ||
/// Resulted `M''` matrices are the sparse ones while `M'` will contribute | ||
/// to the accumulator of the process | ||
pub(crate) fn factorise(&self) -> (Self, SparseMDSMatrix<F, T, RATE>) { | ||
assert_eq!(RATE + 1, T); | ||
// Given `(t-1 * t-1)` MDS matrix called `hat` constructs the `t * t` matrix in | ||
// form `[[1 | 0], [0 | m]]`, ie `hat` is the right bottom sub-matrix | ||
let prime = |hat: Mds<F, RATE>| -> Self { | ||
let mut prime = Self::identity(); | ||
for (prime_row, hat_row) in prime.iter_mut().skip(1).zip(hat.iter()) { | ||
for (el_prime, el_hat) in prime_row.iter_mut().skip(1).zip(hat_row.iter()) { | ||
*el_prime = *el_hat; | ||
} | ||
} | ||
Self(prime) | ||
}; | ||
|
||
// Given `(t-1)` sized `w_hat` vector constructs the matrix in form | ||
// `[[m_0_0 | m_0_i], [w_hat | identity]]` | ||
let prime_prime = |w_hat: [F; RATE]| -> Mds<F, T> { | ||
let mut prime_prime = Self::identity(); | ||
prime_prime[0] = self.0[0]; | ||
for (row, w) in prime_prime.iter_mut().skip(1).zip(w_hat.iter()) { | ||
row[0] = *w | ||
} | ||
prime_prime | ||
}; | ||
|
||
let w = self.0.iter().skip(1).map(|row| row[0]).collect::<Vec<_>>(); | ||
// m_hat is the `(t-1 * t-1)` right bottom sub-matrix of m := self.0 | ||
let mut m_hat = [[F::ZERO; RATE]; RATE]; | ||
for i in 0..RATE { | ||
for j in 0..RATE { | ||
m_hat[i][j] = self.0[i + 1][j + 1]; | ||
} | ||
} | ||
// w_hat = m_hat^{-1} * w, where m_hat^{-1} is matrix inverse and * is matrix mult | ||
// we avoid computing m_hat^{-1} explicitly by using Cramer's rule: https://en.wikipedia.org/wiki/Cramer%27s_rule | ||
let mut w_hat = [F::ZERO; RATE]; | ||
let det = Self::determinant(m_hat); | ||
let det_inv = Option::<F>::from(det.invert()).expect("matrix is not invertible"); | ||
for j in 0..RATE { | ||
let mut m_hat_j = m_hat; | ||
for i in 0..RATE { | ||
m_hat_j[i][j] = w[i]; | ||
} | ||
w_hat[j] = Self::determinant(m_hat_j) * det_inv; | ||
} | ||
let m_prime = prime(m_hat); | ||
let m_prime_prime = prime_prime(w_hat); | ||
// row = first row of m_prime_prime.transpose() = first column of m_prime_prime | ||
let row: [F; T] = | ||
m_prime_prime.iter().map(|row| row[0]).collect::<Vec<_>>().try_into().unwrap(); | ||
// col_hat = first column of m_prime_prime.transpose() without first element = first row of m_prime_prime without first element | ||
let col_hat: [F; RATE] = m_prime_prime[0][1..].try_into().unwrap(); | ||
(m_prime, SparseMDSMatrix { row, col_hat }) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
use std::mem; | ||
|
||
use crate::{ | ||
gates::GateInstructions, | ||
poseidon::{spec::OptimizedPoseidonSpec, state::PoseidonState}, | ||
AssignedValue, Context, ScalarField, | ||
}; | ||
|
||
#[cfg(test)] | ||
mod tests; | ||
|
||
/// Module for maximum distance separable matrix operations. | ||
pub mod mds; | ||
/// Module for poseidon specification. | ||
pub mod spec; | ||
/// Module for poseidon states. | ||
pub mod state; | ||
|
||
/// Chip for Poseidon hasher. The chip is stateful. | ||
pub struct PoseidonHasherChip<F: ScalarField, const T: usize, const RATE: usize> { | ||
init_state: PoseidonState<F, T, RATE>, | ||
state: PoseidonState<F, T, RATE>, | ||
spec: OptimizedPoseidonSpec<F, T, RATE>, | ||
absorbing: Vec<AssignedValue<F>>, | ||
} | ||
|
||
impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonHasherChip<F, T, RATE> { | ||
/// Create new Poseidon hasher chip. | ||
pub fn new<const R_F: usize, const R_P: usize, const SECURE_MDS: usize>( | ||
ctx: &mut Context<F>, | ||
) -> Self { | ||
let init_state = PoseidonState::default(ctx); | ||
let state = init_state.clone(); | ||
Self { | ||
init_state, | ||
state, | ||
spec: OptimizedPoseidonSpec::new::<R_F, R_P, SECURE_MDS>(), | ||
absorbing: Vec::new(), | ||
} | ||
} | ||
|
||
/// Initialize a poseidon hasher from an existing spec. | ||
pub fn from_spec(ctx: &mut Context<F>, spec: OptimizedPoseidonSpec<F, T, RATE>) -> Self { | ||
let init_state = PoseidonState::default(ctx); | ||
Self { spec, state: init_state.clone(), init_state, absorbing: Vec::new() } | ||
} | ||
|
||
/// Reset state to default and clear the buffer. | ||
pub fn clear(&mut self) { | ||
self.state = self.init_state.clone(); | ||
self.absorbing.clear(); | ||
} | ||
|
||
/// Store given `elements` into buffer. | ||
pub fn update(&mut self, elements: &[AssignedValue<F>]) { | ||
self.absorbing.extend_from_slice(elements); | ||
} | ||
|
||
/// Consume buffer and perform permutation, then output second element of | ||
/// state. | ||
pub fn squeeze( | ||
&mut self, | ||
ctx: &mut Context<F>, | ||
gate: &impl GateInstructions<F>, | ||
) -> AssignedValue<F> { | ||
let input_elements = mem::take(&mut self.absorbing); | ||
let exact = input_elements.len() % RATE == 0; | ||
|
||
for chunk in input_elements.chunks(RATE) { | ||
self.permutation(ctx, gate, chunk.to_vec()); | ||
} | ||
if exact { | ||
self.permutation(ctx, gate, vec![]); | ||
} | ||
|
||
self.state.s[1] | ||
} | ||
|
||
fn permutation( | ||
&mut self, | ||
ctx: &mut Context<F>, | ||
gate: &impl GateInstructions<F>, | ||
inputs: Vec<AssignedValue<F>>, | ||
) { | ||
let r_f = self.spec.r_f / 2; | ||
let mds = &self.spec.mds_matrices.mds.0; | ||
let pre_sparse_mds = &self.spec.mds_matrices.pre_sparse_mds.0; | ||
let sparse_matrices = &self.spec.mds_matrices.sparse_matrices; | ||
|
||
// First half of the full round | ||
let constants = &self.spec.constants.start; | ||
self.state.absorb_with_pre_constants(ctx, gate, inputs, &constants[0]); | ||
for constants in constants.iter().skip(1).take(r_f - 1) { | ||
self.state.sbox_full(ctx, gate, constants); | ||
self.state.apply_mds(ctx, gate, mds); | ||
} | ||
self.state.sbox_full(ctx, gate, constants.last().unwrap()); | ||
self.state.apply_mds(ctx, gate, pre_sparse_mds); | ||
|
||
// Partial rounds | ||
let constants = &self.spec.constants.partial; | ||
for (constant, sparse_mds) in constants.iter().zip(sparse_matrices.iter()) { | ||
self.state.sbox_part(ctx, gate, constant); | ||
self.state.apply_sparse_mds(ctx, gate, sparse_mds); | ||
} | ||
|
||
// Second half of the full rounds | ||
let constants = &self.spec.constants.end; | ||
for constants in constants.iter() { | ||
self.state.sbox_full(ctx, gate, constants); | ||
self.state.apply_mds(ctx, gate, mds); | ||
} | ||
self.state.sbox_full(ctx, gate, &[F::ZERO; T]); | ||
self.state.apply_mds(ctx, gate, mds); | ||
} | ||
} |
Oops, something went wrong.