From 022ff5ddcd569517d50942fd06d48c1621ce601f Mon Sep 17 00:00:00 2001 From: daxpedda Date: Sun, 23 Jan 2022 09:46:40 +0100 Subject: [PATCH 1/5] Apply Rust traits to all public types and other improvements --- src/group/ristretto.rs | 3 + src/lib.rs | 25 ++++--- src/tests/voprf_test_vectors.rs | 4 +- src/voprf.rs | 112 +++++++++++++++++++------------- 4 files changed, 87 insertions(+), 57 deletions(-) diff --git a/src/group/ristretto.rs b/src/group/ristretto.rs index e03ed04..2ec62cd 100644 --- a/src/group/ristretto.rs +++ b/src/group/ristretto.rs @@ -22,6 +22,9 @@ use crate::voprf::{self, Mode}; use crate::{CipherSuite, Error, InternalError, Result}; /// [`Group`] implementation for Ristretto255. +#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] +// `cfg` here is only needed because of a bug in Rust's crate feature documentation. See: https://github.com/rust-lang/rust/issues/83428 +#[cfg(feature = "ristretto255")] pub struct Ristretto255; #[cfg(feature = "ristretto255-ciphersuite")] diff --git a/src/lib.rs b/src/lib.rs index e63f248..44a00e0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -88,9 +88,8 @@ //! //! In the second step, the server takes as input the message from //! [NonVerifiableClient::blind] (a [BlindedElement]), and runs -//! [NonVerifiableServer::evaluate] to produce a -//! [NonVerifiableServerEvaluateResult], which consists of an -//! [EvaluationElement] to be sent to the client. +//! [NonVerifiableServer::evaluate] to produce [EvaluationElement] to be sent to +//! the client. //! //! ``` //! # #[cfg(feature = "ristretto255")] @@ -135,13 +134,13 @@ //! # use voprf::NonVerifiableServer; //! # let mut server_rng = OsRng; //! # let server = NonVerifiableServer::::new(&mut server_rng); -//! # let server_evaluate_result = server.evaluate( +//! # let message = server.evaluate( //! # &client_blind_result.message, //! # None, //! # ).expect("Unable to perform server evaluate"); //! let client_finalize_result = client_blind_result //! .state -//! .finalize(b"input", &server_evaluate_result.message, None) +//! .finalize(b"input", &message, None) //! .expect("Unable to perform client finalization"); //! //! println!("VOPRF output: {:?}", client_finalize_result.to_vec()); @@ -479,7 +478,12 @@ #![deny(unsafe_code)] #![no_std] -#![warn(clippy::cargo, clippy::missing_errors_doc, missing_docs)] +#![warn( + clippy::cargo, + clippy::missing_errors_doc, + missing_debug_implementations, + missing_docs +)] #![allow(clippy::multiple_crate_versions)] #[cfg(any(feature = "alloc", test))] @@ -513,8 +517,9 @@ pub use crate::serialization::{ pub use crate::voprf::VerifiableServerBatchEvaluateResult; pub use crate::voprf::{ BlindedElement, EvaluationElement, Mode, NonVerifiableClient, NonVerifiableClientBlindResult, - NonVerifiableServer, NonVerifiableServerEvaluateResult, PreparedEvaluationElement, - PreparedTscalar, Proof, VerifiableClient, VerifiableClientBatchFinalizeResult, - VerifiableClientBlindResult, VerifiableServer, VerifiableServerBatchEvaluateFinishResult, - VerifiableServerBatchEvaluatePrepareResult, VerifiableServerEvaluateResult, + NonVerifiableServer, PreparedEvaluationElement, PreparedTscalar, Proof, VerifiableClient, + VerifiableClientBatchFinalizeResult, VerifiableClientBlindResult, VerifiableServer, + VerifiableServerBatchEvaluateFinishResult, VerifiableServerBatchEvaluateFinishedMessages, + VerifiableServerBatchEvaluatePrepareResult, + VerifiableServerBatchEvaluatePreparedEvaluationElements, VerifiableServerEvaluateResult, }; diff --git a/src/tests/voprf_test_vectors.rs b/src/tests/voprf_test_vectors.rs index 7691811..b8d07fd 100644 --- a/src/tests/voprf_test_vectors.rs +++ b/src/tests/voprf_test_vectors.rs @@ -240,14 +240,14 @@ where for parameters in tvs { for i in 0..parameters.input.len() { let server = NonVerifiableServer::::new_with_key(¶meters.sksm)?; - let server_result = server.evaluate( + let message = server.evaluate( &BlindedElement::deserialize(¶meters.blinded_element[i])?, Some(¶meters.info), )?; assert_eq!( ¶meters.evaluation_element[i], - &server_result.message.serialize().as_slice() + &message.serialize().as_slice() ); } } diff --git a/src/voprf.rs b/src/voprf.rs index d30e119..80f4abe 100644 --- a/src/voprf.rs +++ b/src/voprf.rs @@ -37,7 +37,7 @@ const STR_VOPRF: [u8; 8] = *b"VOPRF08-"; /// Determines the mode of operation (either base mode or verifiable mode). This /// is only used for custom implementations for [`Group`]. -#[derive(Clone, Copy)] +#[derive(Clone, Copy, Debug)] pub enum Mode { /// Non-verifiable mode. Base, @@ -479,7 +479,7 @@ where &self, blinded_element: &BlindedElement, metadata: Option<&[u8]>, - ) -> Result> { + ) -> Result> { // https://www.ietf.org/archive/id/draft-irtf-cfrg-voprf-08.html#section-3.3.1.1-1 let context_string = get_context_string::(Mode::Base); @@ -499,9 +499,7 @@ where // Z = t^(-1) * R let z = blinded_element.0 * &CS::Group::invert_scalar(t); - Ok(NonVerifiableServerEvaluateResult { - message: EvaluationElement(z), - }) + Ok(EvaluationElement(z)) } } @@ -690,14 +688,14 @@ where u, evaluation_elements .into_iter() - .map(|element| element.0.copy()), - blinded_elements.map(BlindedElement::copy), + .map(|element| element.0.clone()), + blinded_elements.cloned(), )?; let messages = evaluation_elements .into_iter() .map() -> _>::from(|element| { - element.0.copy() + element.0.clone() })); Ok(VerifiableServerBatchEvaluateFinishResult { messages, proof }) @@ -715,6 +713,8 @@ where ///////////////////////// /// Contains the fields that are returned by a non-verifiable client blind +#[derive(DeriveWhere)] +#[derive_where(Debug; ::Scalar, ::Elem)] pub struct NonVerifiableClientBlindResult where ::OutputSize: @@ -726,17 +726,9 @@ where pub message: BlindedElement, } -/// Contains the fields that are returned by a non-verifiable server evaluate -pub struct NonVerifiableServerEvaluateResult -where - ::OutputSize: - IsLess + IsLessOrEqual<::BlockSize>, -{ - /// The message to send to the client - pub message: EvaluationElement, -} - /// Contains the fields that are returned by a verifiable client blind +#[derive(DeriveWhere)] +#[derive_where(Debug; ::Scalar, ::Elem)] pub struct VerifiableClientBlindResult where ::OutputSize: @@ -757,6 +749,8 @@ pub type VerifiableClientBatchFinalizeResult<'a, C, I, II, IC, IM> = FinalizeAft >; /// Contains the fields that are returned by a verifiable server evaluate +#[derive(DeriveWhere)] +#[derive_where(Debug; ::Scalar, ::Elem)] pub struct VerifiableServerEvaluateResult where ::OutputSize: @@ -770,6 +764,17 @@ where /// Contains prepared [`EvaluationElement`]s by a verifiable server batch /// evaluate preparation. +#[derive(DeriveWhere)] +#[derive_where(Clone, Zeroize(drop))] +#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; ::Elem)] +#[cfg_attr( + feature = "serde", + derive(serde::Deserialize, serde::Serialize), + serde(bound( + deserialize = "::Elem: serde::Deserialize<'de>", + serialize = "::Elem: serde::Serialize" + )) +)] pub struct PreparedEvaluationElement(EvaluationElement) where ::OutputSize: @@ -777,14 +782,37 @@ where /// Contains the prepared `t` by a verifiable server batch evaluate preparation. #[derive(DeriveWhere)] -#[derive_where(Zeroize(drop))] +#[derive_where(Clone, Zeroize(drop))] +#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; ::Scalar)] +#[cfg_attr( + feature = "serde", + derive(serde::Deserialize, serde::Serialize), + serde(bound( + deserialize = "::Scalar: serde::Deserialize<'de>", + serialize = "::Scalar: serde::Serialize" + )) +)] pub struct PreparedTscalar(::Scalar) where ::OutputSize: IsLess + IsLessOrEqual<::BlockSize>; +/// Concrete type of [`EvaluationElement`]s in +/// [`VerifiableServerBatchEvaluatePrepareResult`]. +pub type VerifiableServerBatchEvaluatePreparedEvaluationElements = Map< + Zip::Group as Group>::Scalar>>, + fn( + ( + &BlindedElement, + <::Group as Group>::Scalar, + ), + ) -> PreparedEvaluationElement, +>; + /// Contains the fields that are returned by a verifiable server batch evaluate /// preparation. +#[derive(DeriveWhere)] +#[derive_where(Debug; I, ::Scalar)] pub struct VerifiableServerBatchEvaluatePrepareResult< 'a, CS: 'a + CipherSuite, @@ -794,34 +822,38 @@ pub struct VerifiableServerBatchEvaluatePrepareResult< IsLess + IsLessOrEqual<::BlockSize>, { /// Prepared [`EvaluationElement`]s that will become messages. - #[allow(clippy::type_complexity)] - pub prepared_evaluation_elements: Map< - Zip::Scalar>>, - fn((&BlindedElement, ::Scalar)) -> PreparedEvaluationElement, - >, + pub prepared_evaluation_elements: + VerifiableServerBatchEvaluatePreparedEvaluationElements, /// Prepared `t` needed to finish the verifiable server batch evaluation. pub t: PreparedTscalar, } +/// Concrete type of [`EvaluationElement`]s in +/// [`VerifiableServerBatchEvaluateFinishResult`]. +pub type VerifiableServerBatchEvaluateFinishedMessages<'a, CS, I> = Map< + <&'a I as IntoIterator>::IntoIter, + fn(&PreparedEvaluationElement) -> EvaluationElement, +>; + /// Contains the fields that are returned by a verifiable server batch evaluate /// finish. +#[derive(DeriveWhere)] +#[derive_where(Debug; <&'a I as core::iter::IntoIterator>::IntoIter, ::Scalar)] pub struct VerifiableServerBatchEvaluateFinishResult<'a, CS: 'a + CipherSuite, I> where ::OutputSize: IsLess + IsLessOrEqual<::BlockSize>, &'a I: IntoIterator>, { - /// The messages to send to the client - #[allow(clippy::type_complexity)] - pub messages: Map< - <&'a I as IntoIterator>::IntoIter, - fn(&PreparedEvaluationElement) -> EvaluationElement, - >, + /// The [`EvaluationElement`]s to send to the client + pub messages: VerifiableServerBatchEvaluateFinishedMessages<'a, CS, I>, /// The proof for the client to verify pub proof: Proof, } /// Contains the fields that are returned by a verifiable server batch evaluate +#[derive(DeriveWhere)] +#[derive_where(Debug; ::Scalar, ::Elem)] #[cfg(feature = "alloc")] pub struct VerifiableServerBatchEvaluateResult where @@ -844,11 +876,6 @@ where ::OutputSize: IsLess + IsLessOrEqual<::BlockSize>, { - /// Only used to easier validate allocation - fn copy(&self) -> Self { - Self(self.0) - } - #[cfg(feature = "danger")] /// Creates a [BlindedElement] from a raw group element. /// @@ -872,11 +899,6 @@ where ::OutputSize: IsLess + IsLessOrEqual<::BlockSize>, { - /// Only used to easier validate allocation - fn copy(&self) -> Self { - Self(self.0) - } - #[cfg(feature = "danger")] /// Creates an [EvaluationElement] from a raw group element. /// @@ -991,7 +1013,7 @@ where .into_iter() // Convert to `fn` pointer to make a return type possible. .map() -> _>::from(|x| x.blind)); - let evaluation_elements = messages.into_iter().map(EvaluationElement::copy); + let evaluation_elements = messages.into_iter().cloned(); let blinded_elements = clients .into_iter() .map(|client| BlindedElement(client.blinded_element)); @@ -1338,12 +1360,12 @@ mod tests { let mut rng = OsRng; let client_blind_result = NonVerifiableClient::::blind(input, &mut rng).unwrap(); let server = NonVerifiableServer::::new(&mut rng); - let server_result = server + let message = server .evaluate(&client_blind_result.message, Some(info)) .unwrap(); let client_finalize_result = client_blind_result .state - .finalize(input, &server_result.message, Some(info)) + .finalize(input, &message, Some(info)) .unwrap(); let res2 = prf::(input, server.get_private_key(), info, Mode::Base); assert_eq!(client_finalize_result, res2); @@ -1589,7 +1611,7 @@ mod tests { let mut rng = OsRng; let client_blind_result = NonVerifiableClient::::blind(input, &mut rng).unwrap(); let server = NonVerifiableServer::::new(&mut rng); - let server_result = server + let message = server .evaluate(&client_blind_result.message, Some(info)) .unwrap(); @@ -1597,7 +1619,7 @@ mod tests { Zeroize::zeroize(&mut state); assert!(state.serialize().iter().all(|&x| x == 0)); - let mut message = server_result.message; + let mut message = message; Zeroize::zeroize(&mut message); assert!(message.serialize().iter().all(|&x| x == 0)); } From d0cd0449f0ea85f937fb5c468611d63c830fda07 Mon Sep 17 00:00:00 2001 From: daxpedda Date: Mon, 24 Jan 2022 12:13:49 +0100 Subject: [PATCH 2/5] Move methods into appropriate section --- src/voprf.rs | 100 +++++++++++++++++++++++++-------------------------- 1 file changed, 50 insertions(+), 50 deletions(-) diff --git a/src/voprf.rs b/src/voprf.rs index 80f4abe..8b2e924 100644 --- a/src/voprf.rs +++ b/src/voprf.rs @@ -707,6 +707,52 @@ where } } +impl BlindedElement +where + ::OutputSize: + IsLess + IsLessOrEqual<::BlockSize>, +{ + #[cfg(feature = "danger")] + /// Creates a [BlindedElement] from a raw group element. + /// + /// # Caution + /// + /// This should be used with caution, since it does not perform any checks + /// on the validity of the value itself! + pub fn from_value_unchecked(value: ::Elem) -> Self { + Self(value) + } + + #[cfg(feature = "danger")] + /// Exposes the internal value + pub fn value(&self) -> ::Elem { + self.0 + } +} + +impl EvaluationElement +where + ::OutputSize: + IsLess + IsLessOrEqual<::BlockSize>, +{ + #[cfg(feature = "danger")] + /// Creates an [EvaluationElement] from a raw group element. + /// + /// # Caution + /// + /// This should be used with caution, since it does not perform any checks + /// on the validity of the value itself! + pub fn from_value_unchecked(value: ::Elem) -> Self { + Self(value) + } + + #[cfg(feature = "danger")] + /// Exposes the internal value + pub fn value(&self) -> ::Elem { + self.0 + } +} + ///////////////////////// // Convenience Structs // //==================== // @@ -866,56 +912,10 @@ where pub proof: Proof, } -/////////////////////////////////////////////// -// Inner functions and Trait Implementations // -// ========================================= // -/////////////////////////////////////////////// - -impl BlindedElement -where - ::OutputSize: - IsLess + IsLessOrEqual<::BlockSize>, -{ - #[cfg(feature = "danger")] - /// Creates a [BlindedElement] from a raw group element. - /// - /// # Caution - /// - /// This should be used with caution, since it does not perform any checks - /// on the validity of the value itself! - pub fn from_value_unchecked(value: ::Elem) -> Self { - Self(value) - } - - #[cfg(feature = "danger")] - /// Exposes the internal value - pub fn value(&self) -> ::Elem { - self.0 - } -} - -impl EvaluationElement -where - ::OutputSize: - IsLess + IsLessOrEqual<::BlockSize>, -{ - #[cfg(feature = "danger")] - /// Creates an [EvaluationElement] from a raw group element. - /// - /// # Caution - /// - /// This should be used with caution, since it does not perform any checks - /// on the validity of the value itself! - pub fn from_value_unchecked(value: ::Elem) -> Self { - Self(value) - } - - #[cfg(feature = "danger")] - /// Exposes the internal value - pub fn value(&self) -> ::Elem { - self.0 - } -} +///////////////////// +// Inner functions // +// =============== // +///////////////////// type BlindResult = ( <::Group as Group>::Scalar, From 2760da9a6c79a34681e6625926528f482c47924c Mon Sep 17 00:00:00 2001 From: daxpedda Date: Mon, 24 Jan 2022 12:14:22 +0100 Subject: [PATCH 3/5] Check for zero scalars --- src/error.rs | 2 ++ src/group/elliptic_curve.rs | 4 ++++ src/group/mod.rs | 5 ++++- src/group/ristretto.rs | 5 +++++ src/voprf.rs | 26 ++++++++++++++++++++++---- 5 files changed, 37 insertions(+), 5 deletions(-) diff --git a/src/error.rs b/src/error.rs index 6a99da1..38a3240 100644 --- a/src/error.rs +++ b/src/error.rs @@ -27,6 +27,8 @@ pub enum Error { ProofVerification, /// Size of seed is longer then [`u16::MAX`]. Seed, + /// The protocol has failed and can't be completed. + Protocol, } /// Only used to implement [`Group`](crate::Group). diff --git a/src/group/elliptic_curve.rs b/src/group/elliptic_curve.rs index 89390fa..fa9ca3a 100644 --- a/src/group/elliptic_curve.rs +++ b/src/group/elliptic_curve.rs @@ -103,6 +103,10 @@ where Option::from(scalar.invert()).unwrap() } + fn is_zero_scalar(scalar: Self::Scalar) -> subtle::Choice { + scalar.is_zero() + } + #[cfg(test)] fn zero_scalar() -> Self::Scalar { Scalar::::zero() diff --git a/src/group/mod.rs b/src/group/mod.rs index 444b4b8..9306a71 100644 --- a/src/group/mod.rs +++ b/src/group/mod.rs @@ -20,7 +20,7 @@ use generic_array::{ArrayLength, GenericArray}; use rand_core::{CryptoRng, RngCore}; #[cfg(feature = "ristretto255")] pub use ristretto::Ristretto255; -use subtle::ConstantTimeEq; +use subtle::{Choice, ConstantTimeEq}; use zeroize::Zeroize; use crate::voprf::Mode; @@ -101,6 +101,9 @@ pub trait Group { /// The multiplicative inverse of this scalar fn invert_scalar(scalar: Self::Scalar) -> Self::Scalar; + /// Returns `true` if the scalar is zero. + fn is_zero_scalar(scalar: Self::Scalar) -> Choice; + /// Returns the scalar representing zero #[cfg(test)] fn zero_scalar() -> Self::Scalar; diff --git a/src/group/ristretto.rs b/src/group/ristretto.rs index 2ec62cd..6ac91c5 100644 --- a/src/group/ristretto.rs +++ b/src/group/ristretto.rs @@ -16,6 +16,7 @@ use generic_array::sequence::Concat; use generic_array::typenum::{IsLess, IsLessOrEqual, U256, U32, U64}; use generic_array::GenericArray; use rand_core::{CryptoRng, RngCore}; +use subtle::ConstantTimeEq; use super::{Group, STR_HASH_TO_GROUP, STR_HASH_TO_SCALAR}; use crate::voprf::{self, Mode}; @@ -127,6 +128,10 @@ impl Group for Ristretto255 { scalar.invert() } + fn is_zero_scalar(scalar: Self::Scalar) -> subtle::Choice { + scalar.ct_eq(&Scalar::zero()) + } + #[cfg(test)] fn zero_scalar() -> Self::Scalar { Scalar::zero() diff --git a/src/voprf.rs b/src/voprf.rs index 8b2e924..68b98f5 100644 --- a/src/voprf.rs +++ b/src/voprf.rs @@ -474,7 +474,8 @@ where /// to the client. /// /// # Errors - /// [`Error::Metadata`] if the `metadata` is longer then `u16::MAX - 21`. + /// - [`Error::Metadata`] if the `metadata` is longer then `u16::MAX - 21`. + /// - [`Error::Protocol`] if the protocol fails and can't be completed. pub fn evaluate( &self, blinded_element: &BlindedElement, @@ -496,6 +497,13 @@ where CS::Group::hash_to_scalar::(&context, Mode::Base).map_err(|_| Error::Metadata)?; // t = skS + m let t = self.sk + &m; + + // if t == 0: + if bool::from(CS::Group::is_zero_scalar(t)) { + // raise InverseError + return Err(Error::Protocol); + } + // Z = t^(-1) * R let z = blinded_element.0 * &CS::Group::invert_scalar(t); @@ -553,7 +561,8 @@ where /// to the client. /// /// # Errors - /// [`Error::Metadata`] if the `metadata` is longer then `u16::MAX - 21`. + /// - [`Error::Metadata`] if the `metadata` is longer then `u16::MAX - 21`. + /// - [`Error::Protocol`] if the protocol fails and can't be completed. pub fn evaluate( &self, rng: &mut R, @@ -584,7 +593,8 @@ where /// messages from a [VerifiableClient] /// /// # Errors - /// [`Error::Metadata`] if the `metadata` is longer then `u16::MAX - 21`. + /// - [`Error::Metadata`] if the `metadata` is longer then `u16::MAX - 21`. + /// - [`Error::Protocol`] if the protocol fails and can't be completed. #[cfg(feature = "alloc")] pub fn batch_evaluate<'a, R: RngCore + CryptoRng, I>( &self, @@ -626,7 +636,8 @@ where /// [`batch_evaluate_finish`](Self::batch_evaluate_finish). /// /// # Errors - /// [`Error::Metadata`] if the `metadata` is longer then `u16::MAX - 21`. + /// - [`Error::Metadata`] if the `metadata` is longer then `u16::MAX - 21`. + /// - [`Error::Protocol`] if the protocol fails and can't be completed. pub fn batch_evaluate_prepare<'a, I: Iterator>>( &self, blinded_elements: I, @@ -646,6 +657,13 @@ where let m = CS::Group::hash_to_scalar::(&context, Mode::Verifiable) .map_err(|_| Error::Metadata)?; let t = self.sk + &m; + + // if t == 0: + if bool::from(CS::Group::is_zero_scalar(t)) { + // raise InverseError + return Err(Error::Protocol); + } + let evaluation_elements = blinded_elements // To make a return type possible, we have to convert to a `fn` pointer, which isn't // possible if we `move` from context. From c9984198db004f71c277f69519aeff3783bb0925 Mon Sep 17 00:00:00 2001 From: daxpedda Date: Mon, 24 Jan 2022 13:20:12 +0100 Subject: [PATCH 4/5] Change element and scalar de/serialization from `GenericArray` to slice --- src/group/elliptic_curve.rs | 4 ++-- src/group/mod.rs | 4 ++-- src/group/ristretto.rs | 9 +++++--- src/serialization.rs | 39 +++++++++++++++++++-------------- src/tests/voprf_test_vectors.rs | 22 ++++++------------- src/voprf.rs | 4 ++-- 6 files changed, 42 insertions(+), 40 deletions(-) diff --git a/src/group/elliptic_curve.rs b/src/group/elliptic_curve.rs index fa9ca3a..3b2b354 100644 --- a/src/group/elliptic_curve.rs +++ b/src/group/elliptic_curve.rs @@ -89,7 +89,7 @@ where result } - fn deserialize_elem(element_bits: &GenericArray) -> Result { + fn deserialize_elem(element_bits: &[u8]) -> Result { PublicKey::::from_sec1_bytes(element_bits) .map(|public_key| public_key.to_projective()) .map_err(|_| Error::Deserialization) @@ -116,7 +116,7 @@ where scalar.into() } - fn deserialize_scalar(scalar_bits: &GenericArray) -> Result { + fn deserialize_scalar(scalar_bits: &[u8]) -> Result { SecretKey::::from_be_bytes(scalar_bits) .map(|secret_key| *secret_key.to_nonzero_scalar()) .map_err(|_| Error::Deserialization) diff --git a/src/group/mod.rs b/src/group/mod.rs index 9306a71..7b25d0c 100644 --- a/src/group/mod.rs +++ b/src/group/mod.rs @@ -93,7 +93,7 @@ pub trait Group { /// # Errors /// [`Error::Deserialization`](crate::Error::Deserialization) if the element /// is not a valid point on the group or the identity element. - fn deserialize_elem(element_bits: &GenericArray) -> Result; + fn deserialize_elem(element_bits: &[u8]) -> Result; /// picks a scalar at random fn random_scalar(rng: &mut R) -> Self::Scalar; @@ -117,7 +117,7 @@ pub trait Group { /// # Errors /// [`Error::Deserialization`](crate::Error::Deserialization) if the scalar /// is not a valid point on the group or zero. - fn deserialize_scalar(scalar_bits: &GenericArray) -> Result; + fn deserialize_scalar(scalar_bits: &[u8]) -> Result; } #[cfg(test)] diff --git a/src/group/ristretto.rs b/src/group/ristretto.rs index 6ac91c5..6e8bc15 100644 --- a/src/group/ristretto.rs +++ b/src/group/ristretto.rs @@ -103,7 +103,7 @@ impl Group for Ristretto255 { elem.compress().to_bytes().into() } - fn deserialize_elem(element_bits: &GenericArray) -> Result { + fn deserialize_elem(element_bits: &[u8]) -> Result { CompressedRistretto::from_slice(element_bits) .decompress() .filter(|point| point != &RistrettoPoint::identity()) @@ -141,8 +141,11 @@ impl Group for Ristretto255 { scalar.to_bytes().into() } - fn deserialize_scalar(scalar_bits: &GenericArray) -> Result { - Scalar::from_canonical_bytes((*scalar_bits).into()) + fn deserialize_scalar(scalar_bits: &[u8]) -> Result { + scalar_bits + .try_into() + .ok() + .and_then(Scalar::from_canonical_bytes) .filter(|scalar| scalar != &Scalar::zero()) .ok_or(Error::Deserialization) } diff --git a/src/serialization.rs b/src/serialization.rs index 6761f7d..8795882 100644 --- a/src/serialization.rs +++ b/src/serialization.rs @@ -13,7 +13,7 @@ use core::ops::Add; use digest::core_api::BlockSizeUser; use digest::OutputSizeUser; use generic_array::sequence::Concat; -use generic_array::typenum::{IsLess, IsLessOrEqual, Sum, U256}; +use generic_array::typenum::{IsLess, IsLessOrEqual, Sum, Unsigned, U256}; use generic_array::{ArrayLength, GenericArray}; use crate::{ @@ -46,7 +46,7 @@ where pub fn deserialize(input: &[u8]) -> Result { let mut input = input.iter().copied(); - let blind = CS::Group::deserialize_scalar(&deserialize(&mut input)?)?; + let blind = deserialize_scalar::(&mut input)?; Ok(Self { blind }) } @@ -80,8 +80,8 @@ where pub fn deserialize(input: &[u8]) -> Result { let mut input = input.iter().copied(); - let blind = CS::Group::deserialize_scalar(&deserialize(&mut input)?)?; - let blinded_element = CS::Group::deserialize_elem(&deserialize(&mut input)?)?; + let blind = deserialize_scalar::(&mut input)?; + let blinded_element = deserialize_elem::(&mut input)?; Ok(Self { blind, @@ -110,7 +110,7 @@ where pub fn deserialize(input: &[u8]) -> Result { let mut input = input.iter().copied(); - let sk = CS::Group::deserialize_scalar(&deserialize(&mut input)?)?; + let sk = deserialize_scalar::(&mut input)?; Ok(Self { sk }) } @@ -143,8 +143,8 @@ where pub fn deserialize(input: &[u8]) -> Result { let mut input = input.iter().copied(); - let sk = CS::Group::deserialize_scalar(&deserialize(&mut input)?)?; - let pk = CS::Group::deserialize_elem(&deserialize(&mut input)?)?; + let sk = deserialize_scalar::(&mut input)?; + let pk = deserialize_elem::(&mut input)?; Ok(Self { sk, pk }) } @@ -178,8 +178,8 @@ where pub fn deserialize(input: &[u8]) -> Result { let mut input = input.iter().copied(); - let c_scalar = CS::Group::deserialize_scalar(&deserialize(&mut input)?)?; - let s_scalar = CS::Group::deserialize_scalar(&deserialize(&mut input)?)?; + let c_scalar = deserialize_scalar::(&mut input)?; + let s_scalar = deserialize_scalar::(&mut input)?; Ok(Proof { c_scalar, s_scalar }) } @@ -205,7 +205,7 @@ where pub fn deserialize(input: &[u8]) -> Result { let mut input = input.iter().copied(); - let value = CS::Group::deserialize_elem(&deserialize(&mut input)?)?; + let value = deserialize_elem::(&mut input)?; Ok(Self(value)) } @@ -231,15 +231,22 @@ where pub fn deserialize(input: &[u8]) -> Result { let mut input = input.iter().copied(); - let value = CS::Group::deserialize_elem(&deserialize(&mut input)?)?; + let value = deserialize_elem::(&mut input)?; Ok(Self(value)) } } -fn deserialize>( - input: &mut impl Iterator, -) -> Result> { - let input = input.by_ref().take(L::USIZE); - GenericArray::from_exact_iter(input).ok_or(Error::Deserialization) +fn deserialize_elem>(input: &mut I) -> Result { + let input = input.by_ref().take(G::ElemLen::USIZE); + GenericArray::<_, G::ElemLen>::from_exact_iter(input) + .ok_or(Error::Deserialization) + .and_then(|bytes| G::deserialize_elem(&bytes)) +} + +fn deserialize_scalar>(input: &mut I) -> Result { + let input = input.by_ref().take(G::ScalarLen::USIZE); + GenericArray::<_, G::ScalarLen>::from_exact_iter(input) + .ok_or(Error::Deserialization) + .and_then(|bytes| G::deserialize_scalar(&bytes)) } diff --git a/src/tests/voprf_test_vectors.rs b/src/tests/voprf_test_vectors.rs index b8d07fd..b3fb3ea 100644 --- a/src/tests/voprf_test_vectors.rs +++ b/src/tests/voprf_test_vectors.rs @@ -13,7 +13,7 @@ use core::ops::Add; use digest::core_api::BlockSizeUser; use digest::OutputSizeUser; use generic_array::typenum::{IsLess, IsLessOrEqual, Sum, U256}; -use generic_array::{ArrayLength, GenericArray}; +use generic_array::ArrayLength; use json::JsonValue; use crate::tests::mock_rng::CycleRng; @@ -183,9 +183,7 @@ where { for parameters in tvs { for i in 0..parameters.input.len() { - let blind = CS::Group::deserialize_scalar(&GenericArray::clone_from_slice( - ¶meters.blind[i], - ))?; + let blind = CS::Group::deserialize_scalar(¶meters.blind[i])?; let client_result = NonVerifiableClient::::deterministic_blind_unchecked( ¶meters.input[i], blind, @@ -212,9 +210,7 @@ where { for parameters in tvs { for i in 0..parameters.input.len() { - let blind = CS::Group::deserialize_scalar(&GenericArray::clone_from_slice( - ¶meters.blind[i], - ))?; + let blind = CS::Group::deserialize_scalar(¶meters.blind[i])?; let client_blind_result = VerifiableClient::::deterministic_blind_unchecked(¶meters.input[i], blind)?; @@ -306,7 +302,7 @@ where for parameters in tvs { for i in 0..parameters.input.len() { let client = NonVerifiableClient::::from_blind(CS::Group::deserialize_scalar( - &GenericArray::clone_from_slice(¶meters.blind[i]), + ¶meters.blind[i], )?); let client_finalize_result = client.finalize( @@ -330,12 +326,8 @@ where let mut clients = vec![]; for i in 0..parameters.input.len() { let client = VerifiableClient::::from_blind_and_element( - CS::Group::deserialize_scalar(&GenericArray::clone_from_slice( - ¶meters.blind[i], - ))?, - CS::Group::deserialize_elem(&GenericArray::clone_from_slice( - ¶meters.blinded_element[i], - ))?, + CS::Group::deserialize_scalar(¶meters.blind[i])?, + CS::Group::deserialize_elem(¶meters.blinded_element[i])?, ); clients.push(client.clone()); } @@ -351,7 +343,7 @@ where &clients, &messages, &Proof::deserialize(¶meters.proof)?, - CS::Group::deserialize_elem(GenericArray::from_slice(¶meters.pksm))?, + CS::Group::deserialize_elem(¶meters.pksm)?, Some(¶meters.info), )?; diff --git a/src/voprf.rs b/src/voprf.rs index 68b98f5..fd8a196 100644 --- a/src/voprf.rs +++ b/src/voprf.rs @@ -447,7 +447,7 @@ where /// [`Error::Deserialization`] if the private key is not a valid point on /// the group or zero. pub fn new_with_key(private_key_bytes: &[u8]) -> Result { - let sk = CS::Group::deserialize_scalar(private_key_bytes.into())?; + let sk = CS::Group::deserialize_scalar(private_key_bytes)?; Ok(Self { sk }) } @@ -531,7 +531,7 @@ where /// [`Error::Deserialization`] if the private key is not a valid point on /// the group or zero. pub fn new_with_key(key: &[u8]) -> Result { - let sk = CS::Group::deserialize_scalar(key.into())?; + let sk = CS::Group::deserialize_scalar(key)?; let pk = CS::Group::base_elem() * &sk; Ok(Self { sk, pk }) } From 95fd30310be489379eef38ebdf74ddd60bb009f6 Mon Sep 17 00:00:00 2001 From: daxpedda Date: Mon, 24 Jan 2022 15:39:38 +0100 Subject: [PATCH 5/5] Customize `serde` serialization --- Cargo.toml | 3 +- src/lib.rs | 3 ++ src/serialization.rs | 50 ++++++++++++++++++++++++++++++ src/voprf.rs | 74 +++++++++++++++++++------------------------- 4 files changed, 86 insertions(+), 44 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1aa7b87..90ec0ee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ ristretto255-fiat-u64 = ["curve25519-dalek/fiat_u64_backend", "ristretto255"] ristretto255-simd = ["curve25519-dalek/simd_backend", "ristretto255"] ristretto255-u32 = ["curve25519-dalek/u32_backend", "ristretto255"] ristretto255-u64 = ["curve25519-dalek/u64_backend", "ristretto255"] +serde = ["generic-array/serde", "serde_"] std = ["alloc"] [dependencies] @@ -36,7 +37,7 @@ elliptic-curve = { version = "0.12.0-pre.1", features = [ ] } generic-array = "0.14" rand_core = { version = "0.6", default-features = false } -serde = { version = "1", default-features = false, features = [ +serde_ = { version = "1", package = "serde", default-features = false, features = [ "derive", ], optional = true } sha2 = { version = "0.10", default-features = false, optional = true } diff --git a/src/lib.rs b/src/lib.rs index 44a00e0..35522bd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -492,6 +492,9 @@ extern crate alloc; #[cfg(feature = "std")] extern crate std; +#[cfg(feature = "serde")] +extern crate serde_ as serde; + mod ciphersuite; mod error; mod group; diff --git a/src/serialization.rs b/src/serialization.rs index 8795882..55606f5 100644 --- a/src/serialization.rs +++ b/src/serialization.rs @@ -250,3 +250,53 @@ fn deserialize_scalar>(input: &mut I) -> Result .ok_or(Error::Deserialization) .and_then(|bytes| G::deserialize_scalar(&bytes)) } + +#[cfg(feature = "serde")] +pub(crate) mod serde { + use core::marker::PhantomData; + + use generic_array::GenericArray; + use serde::de::{Deserializer, Error}; + use serde::ser::Serializer; + use serde::{Deserialize, Serialize}; + + use crate::Group; + + pub(crate) struct Element(PhantomData); + + impl<'de, G: Group> Element { + pub(crate) fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + GenericArray::<_, G::ElemLen>::deserialize(deserializer) + .and_then(|bytes| G::deserialize_elem(&bytes).map_err(D::Error::custom)) + } + + pub(crate) fn serialize(self_: &G::Elem, serializer: S) -> Result + where + S: Serializer, + { + G::serialize_elem(*self_).serialize(serializer) + } + } + + pub(crate) struct Scalar(PhantomData); + + impl<'de, G: Group> Scalar { + pub(crate) fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + GenericArray::<_, G::ScalarLen>::deserialize(deserializer) + .and_then(|bytes| G::deserialize_scalar(&bytes).map_err(D::Error::custom)) + } + + pub(crate) fn serialize(self_: &G::Scalar, serializer: S) -> Result + where + S: Serializer, + { + G::serialize_scalar(*self_).serialize(serializer) + } + } +} diff --git a/src/voprf.rs b/src/voprf.rs index fd8a196..a750b0f 100644 --- a/src/voprf.rs +++ b/src/voprf.rs @@ -20,6 +20,8 @@ use generic_array::GenericArray; use rand_core::{CryptoRng, RngCore}; use subtle::ConstantTimeEq; +#[cfg(feature = "serde")] +use crate::serialization::serde::{Element, Scalar}; use crate::util::{i2osp_2, i2osp_2_array}; use crate::{CipherSuite, Error, Group, Result}; @@ -68,16 +70,14 @@ impl Mode { #[cfg_attr( feature = "serde", derive(serde::Deserialize, serde::Serialize), - serde(bound( - deserialize = "::Scalar: serde::Deserialize<'de>", - serialize = "::Scalar: serde::Serialize" - )) + serde(crate = "serde", bound = "") )] pub struct NonVerifiableClient where ::OutputSize: IsLess + IsLessOrEqual<::BlockSize>, { + #[cfg_attr(feature = "serde", serde(with = "Scalar::"))] pub(crate) blind: ::Scalar, } @@ -89,19 +89,16 @@ where #[cfg_attr( feature = "serde", derive(serde::Deserialize, serde::Serialize), - serde(bound( - deserialize = "::Scalar: serde::Deserialize<'de>, ::Elem: serde::Deserialize<'de>", - serialize = "::Scalar: serde::Serialize, ::Elem: \ - serde::Serialize" - )) + serde(crate = "serde", bound = "") )] pub struct VerifiableClient where ::OutputSize: IsLess + IsLessOrEqual<::BlockSize>, { + #[cfg_attr(feature = "serde", serde(with = "Scalar::"))] pub(crate) blind: ::Scalar, + #[cfg_attr(feature = "serde", serde(with = "Element::"))] pub(crate) blinded_element: ::Elem, } @@ -113,16 +110,14 @@ where #[cfg_attr( feature = "serde", derive(serde::Deserialize, serde::Serialize), - serde(bound( - deserialize = "::Scalar: serde::Deserialize<'de>", - serialize = "::Scalar: serde::Serialize" - )) + serde(crate = "serde", bound = "") )] pub struct NonVerifiableServer where ::OutputSize: IsLess + IsLessOrEqual<::BlockSize>, { + #[cfg_attr(feature = "serde", serde(with = "Scalar::"))] pub(crate) sk: ::Scalar, } @@ -134,19 +129,16 @@ where #[cfg_attr( feature = "serde", derive(serde::Deserialize, serde::Serialize), - serde(bound( - deserialize = "::Scalar: serde::Deserialize<'de>, ::Elem: serde::Deserialize<'de>", - serialize = "::Scalar: serde::Serialize, ::Elem: \ - serde::Serialize" - )) + serde(crate = "serde", bound = "") )] pub struct VerifiableServer where ::OutputSize: IsLess + IsLessOrEqual<::BlockSize>, { + #[cfg_attr(feature = "serde", serde(with = "Scalar::"))] pub(crate) sk: ::Scalar, + #[cfg_attr(feature = "serde", serde(with = "Element::"))] pub(crate) pk: ::Elem, } @@ -158,17 +150,16 @@ where #[cfg_attr( feature = "serde", derive(serde::Deserialize, serde::Serialize), - serde(bound( - deserialize = "::Scalar: serde::Deserialize<'de>", - serialize = "::Scalar: serde::Serialize" - )) + serde(crate = "serde", bound = "") )] pub struct Proof where ::OutputSize: IsLess + IsLessOrEqual<::BlockSize>, { + #[cfg_attr(feature = "serde", serde(with = "Scalar::"))] pub(crate) c_scalar: ::Scalar, + #[cfg_attr(feature = "serde", serde(with = "Scalar::"))] pub(crate) s_scalar: ::Scalar, } @@ -180,12 +171,12 @@ where #[cfg_attr( feature = "serde", derive(serde::Deserialize, serde::Serialize), - serde(bound( - deserialize = "::Elem: serde::Deserialize<'de>", - serialize = "::Elem: serde::Serialize" - )) + serde(crate = "serde", bound = "") )] -pub struct BlindedElement(pub(crate) ::Elem) +pub struct BlindedElement( + #[cfg_attr(feature = "serde", serde(with = "Element::"))] + pub(crate) ::Elem, +) where ::OutputSize: IsLess + IsLessOrEqual<::BlockSize>; @@ -198,12 +189,12 @@ where #[cfg_attr( feature = "serde", derive(serde::Deserialize, serde::Serialize), - serde(bound( - deserialize = "::Elem: serde::Deserialize<'de>", - serialize = "::Elem: serde::Serialize" - )) + serde(crate = "serde", bound = "") )] -pub struct EvaluationElement(pub(crate) ::Elem) +pub struct EvaluationElement( + #[cfg_attr(feature = "serde", serde(with = "Element::"))] + pub(crate) ::Elem, +) where ::OutputSize: IsLess + IsLessOrEqual<::BlockSize>; @@ -834,10 +825,7 @@ where #[cfg_attr( feature = "serde", derive(serde::Deserialize, serde::Serialize), - serde(bound( - deserialize = "::Elem: serde::Deserialize<'de>", - serialize = "::Elem: serde::Serialize" - )) + serde(crate = "serde", bound = "") )] pub struct PreparedEvaluationElement(EvaluationElement) where @@ -851,12 +839,12 @@ where #[cfg_attr( feature = "serde", derive(serde::Deserialize, serde::Serialize), - serde(bound( - deserialize = "::Scalar: serde::Deserialize<'de>", - serialize = "::Scalar: serde::Serialize" - )) + serde(crate = "serde", bound = "") )] -pub struct PreparedTscalar(::Scalar) +pub struct PreparedTscalar( + #[cfg_attr(feature = "serde", serde(with = "Scalar::"))] + ::Scalar, +) where ::OutputSize: IsLess + IsLessOrEqual<::BlockSize>;