Skip to content

Commit

Permalink
General improvements (#56)
Browse files Browse the repository at this point in the history
* Apply Rust traits to all public types and other improvements

* Move methods into appropriate section

* Check for zero scalars

* Change element and scalar de/serialization from `GenericArray` to slice

* Customize `serde` serialization
  • Loading branch information
daxpedda committed Jan 28, 2022
1 parent b01b8ed commit b59b359
Show file tree
Hide file tree
Showing 9 changed files with 294 additions and 188 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ ristretto255-fiat-u64 = ["curve25519-dalek/fiat_u64_backend", "ristretto255"]
ristretto255-simd = ["curve25519-dalek/simd_backend", "ristretto255"]
ristretto255-u32 = ["curve25519-dalek/u32_backend", "ristretto255"]
ristretto255-u64 = ["curve25519-dalek/u64_backend", "ristretto255"]
serde = ["generic-array/serde", "serde_"]
std = ["alloc"]

[dependencies]
Expand All @@ -36,7 +37,7 @@ elliptic-curve = { version = "0.12.0-pre.1", features = [
] }
generic-array = "0.14"
rand_core = { version = "0.6", default-features = false }
serde = { version = "1", default-features = false, features = [
serde_ = { version = "1", package = "serde", default-features = false, features = [
"derive",
], optional = true }
sha2 = { version = "0.10", default-features = false, optional = true }
Expand Down
2 changes: 2 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ pub enum Error {
ProofVerification,
/// Size of seed is longer then [`u16::MAX`].
Seed,
/// The protocol has failed and can't be completed.
Protocol,
}

/// Only used to implement [`Group`](crate::Group).
Expand Down
8 changes: 6 additions & 2 deletions src/group/elliptic_curve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ where
result
}

fn deserialize_elem(element_bits: &GenericArray<u8, Self::ElemLen>) -> Result<Self::Elem> {
fn deserialize_elem(element_bits: &[u8]) -> Result<Self::Elem> {
PublicKey::<Self>::from_sec1_bytes(element_bits)
.map(|public_key| public_key.to_projective())
.map_err(|_| Error::Deserialization)
Expand All @@ -103,6 +103,10 @@ where
Option::from(scalar.invert()).unwrap()
}

fn is_zero_scalar(scalar: Self::Scalar) -> subtle::Choice {
scalar.is_zero()
}

#[cfg(test)]
fn zero_scalar() -> Self::Scalar {
Scalar::<Self>::zero()
Expand All @@ -112,7 +116,7 @@ where
scalar.into()
}

fn deserialize_scalar(scalar_bits: &GenericArray<u8, Self::ScalarLen>) -> Result<Self::Scalar> {
fn deserialize_scalar(scalar_bits: &[u8]) -> Result<Self::Scalar> {
SecretKey::<Self>::from_be_bytes(scalar_bits)
.map(|secret_key| *secret_key.to_nonzero_scalar())
.map_err(|_| Error::Deserialization)
Expand Down
9 changes: 6 additions & 3 deletions src/group/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use generic_array::{ArrayLength, GenericArray};
use rand_core::{CryptoRng, RngCore};
#[cfg(feature = "ristretto255")]
pub use ristretto::Ristretto255;
use subtle::ConstantTimeEq;
use subtle::{Choice, ConstantTimeEq};
use zeroize::Zeroize;

use crate::voprf::Mode;
Expand Down Expand Up @@ -93,14 +93,17 @@ pub trait Group {
/// # Errors
/// [`Error::Deserialization`](crate::Error::Deserialization) if the element
/// is not a valid point on the group or the identity element.
fn deserialize_elem(element_bits: &GenericArray<u8, Self::ElemLen>) -> Result<Self::Elem>;
fn deserialize_elem(element_bits: &[u8]) -> Result<Self::Elem>;

/// picks a scalar at random
fn random_scalar<R: RngCore + CryptoRng>(rng: &mut R) -> Self::Scalar;

/// The multiplicative inverse of this scalar
fn invert_scalar(scalar: Self::Scalar) -> Self::Scalar;

/// Returns `true` if the scalar is zero.
fn is_zero_scalar(scalar: Self::Scalar) -> Choice;

/// Returns the scalar representing zero
#[cfg(test)]
fn zero_scalar() -> Self::Scalar;
Expand All @@ -114,7 +117,7 @@ pub trait Group {
/// # Errors
/// [`Error::Deserialization`](crate::Error::Deserialization) if the scalar
/// is not a valid point on the group or zero.
fn deserialize_scalar(scalar_bits: &GenericArray<u8, Self::ScalarLen>) -> Result<Self::Scalar>;
fn deserialize_scalar(scalar_bits: &[u8]) -> Result<Self::Scalar>;
}

#[cfg(test)]
Expand Down
17 changes: 14 additions & 3 deletions src/group/ristretto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@ use generic_array::sequence::Concat;
use generic_array::typenum::{IsLess, IsLessOrEqual, U256, U32, U64};
use generic_array::GenericArray;
use rand_core::{CryptoRng, RngCore};
use subtle::ConstantTimeEq;

use super::{Group, STR_HASH_TO_GROUP, STR_HASH_TO_SCALAR};
use crate::voprf::{self, Mode};
use crate::{CipherSuite, Error, InternalError, Result};

/// [`Group`] implementation for Ristretto255.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
// `cfg` here is only needed because of a bug in Rust's crate feature documentation. See: https://github.com/rust-lang/rust/issues/83428
#[cfg(feature = "ristretto255")]
pub struct Ristretto255;

#[cfg(feature = "ristretto255-ciphersuite")]
Expand Down Expand Up @@ -99,7 +103,7 @@ impl Group for Ristretto255 {
elem.compress().to_bytes().into()
}

fn deserialize_elem(element_bits: &GenericArray<u8, Self::ElemLen>) -> Result<Self::Elem> {
fn deserialize_elem(element_bits: &[u8]) -> Result<Self::Elem> {
CompressedRistretto::from_slice(element_bits)
.decompress()
.filter(|point| point != &RistrettoPoint::identity())
Expand All @@ -124,6 +128,10 @@ impl Group for Ristretto255 {
scalar.invert()
}

fn is_zero_scalar(scalar: Self::Scalar) -> subtle::Choice {
scalar.ct_eq(&Scalar::zero())
}

#[cfg(test)]
fn zero_scalar() -> Self::Scalar {
Scalar::zero()
Expand All @@ -133,8 +141,11 @@ impl Group for Ristretto255 {
scalar.to_bytes().into()
}

fn deserialize_scalar(scalar_bits: &GenericArray<u8, Self::ScalarLen>) -> Result<Self::Scalar> {
Scalar::from_canonical_bytes((*scalar_bits).into())
fn deserialize_scalar(scalar_bits: &[u8]) -> Result<Self::Scalar> {
scalar_bits
.try_into()
.ok()
.and_then(Scalar::from_canonical_bytes)
.filter(|scalar| scalar != &Scalar::zero())
.ok_or(Error::Deserialization)
}
Expand Down
28 changes: 18 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,8 @@
//!
//! In the second step, the server takes as input the message from
//! [NonVerifiableClient::blind] (a [BlindedElement]), and runs
//! [NonVerifiableServer::evaluate] to produce a
//! [NonVerifiableServerEvaluateResult], which consists of an
//! [EvaluationElement] to be sent to the client.
//! [NonVerifiableServer::evaluate] to produce [EvaluationElement] to be sent to
//! the client.
//!
//! ```
//! # #[cfg(feature = "ristretto255")]
Expand Down Expand Up @@ -135,13 +134,13 @@
//! # use voprf::NonVerifiableServer;
//! # let mut server_rng = OsRng;
//! # let server = NonVerifiableServer::<CipherSuite>::new(&mut server_rng);
//! # let server_evaluate_result = server.evaluate(
//! # let message = server.evaluate(
//! # &client_blind_result.message,
//! # None,
//! # ).expect("Unable to perform server evaluate");
//! let client_finalize_result = client_blind_result
//! .state
//! .finalize(b"input", &server_evaluate_result.message, None)
//! .finalize(b"input", &message, None)
//! .expect("Unable to perform client finalization");
//!
//! println!("VOPRF output: {:?}", client_finalize_result.to_vec());
Expand Down Expand Up @@ -479,7 +478,12 @@

#![deny(unsafe_code)]
#![no_std]
#![warn(clippy::cargo, clippy::missing_errors_doc, missing_docs)]
#![warn(
clippy::cargo,
clippy::missing_errors_doc,
missing_debug_implementations,
missing_docs
)]
#![allow(clippy::multiple_crate_versions)]

#[cfg(any(feature = "alloc", test))]
Expand All @@ -488,6 +492,9 @@ extern crate alloc;
#[cfg(feature = "std")]
extern crate std;

#[cfg(feature = "serde")]
extern crate serde_ as serde;

mod ciphersuite;
mod error;
mod group;
Expand All @@ -513,8 +520,9 @@ pub use crate::serialization::{
pub use crate::voprf::VerifiableServerBatchEvaluateResult;
pub use crate::voprf::{
BlindedElement, EvaluationElement, Mode, NonVerifiableClient, NonVerifiableClientBlindResult,
NonVerifiableServer, NonVerifiableServerEvaluateResult, PreparedEvaluationElement,
PreparedTscalar, Proof, VerifiableClient, VerifiableClientBatchFinalizeResult,
VerifiableClientBlindResult, VerifiableServer, VerifiableServerBatchEvaluateFinishResult,
VerifiableServerBatchEvaluatePrepareResult, VerifiableServerEvaluateResult,
NonVerifiableServer, PreparedEvaluationElement, PreparedTscalar, Proof, VerifiableClient,
VerifiableClientBatchFinalizeResult, VerifiableClientBlindResult, VerifiableServer,
VerifiableServerBatchEvaluateFinishResult, VerifiableServerBatchEvaluateFinishedMessages,
VerifiableServerBatchEvaluatePrepareResult,
VerifiableServerBatchEvaluatePreparedEvaluationElements, VerifiableServerEvaluateResult,
};
89 changes: 73 additions & 16 deletions src/serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use core::ops::Add;
use digest::core_api::BlockSizeUser;
use digest::OutputSizeUser;
use generic_array::sequence::Concat;
use generic_array::typenum::{IsLess, IsLessOrEqual, Sum, U256};
use generic_array::typenum::{IsLess, IsLessOrEqual, Sum, Unsigned, U256};
use generic_array::{ArrayLength, GenericArray};

use crate::{
Expand Down Expand Up @@ -46,7 +46,7 @@ where
pub fn deserialize(input: &[u8]) -> Result<Self> {
let mut input = input.iter().copied();

let blind = CS::Group::deserialize_scalar(&deserialize(&mut input)?)?;
let blind = deserialize_scalar::<CS::Group, _>(&mut input)?;

Ok(Self { blind })
}
Expand Down Expand Up @@ -80,8 +80,8 @@ where
pub fn deserialize(input: &[u8]) -> Result<Self> {
let mut input = input.iter().copied();

let blind = CS::Group::deserialize_scalar(&deserialize(&mut input)?)?;
let blinded_element = CS::Group::deserialize_elem(&deserialize(&mut input)?)?;
let blind = deserialize_scalar::<CS::Group, _>(&mut input)?;
let blinded_element = deserialize_elem::<CS::Group, _>(&mut input)?;

Ok(Self {
blind,
Expand Down Expand Up @@ -110,7 +110,7 @@ where
pub fn deserialize(input: &[u8]) -> Result<Self> {
let mut input = input.iter().copied();

let sk = CS::Group::deserialize_scalar(&deserialize(&mut input)?)?;
let sk = deserialize_scalar::<CS::Group, _>(&mut input)?;

Ok(Self { sk })
}
Expand Down Expand Up @@ -143,8 +143,8 @@ where
pub fn deserialize(input: &[u8]) -> Result<Self> {
let mut input = input.iter().copied();

let sk = CS::Group::deserialize_scalar(&deserialize(&mut input)?)?;
let pk = CS::Group::deserialize_elem(&deserialize(&mut input)?)?;
let sk = deserialize_scalar::<CS::Group, _>(&mut input)?;
let pk = deserialize_elem::<CS::Group, _>(&mut input)?;

Ok(Self { sk, pk })
}
Expand Down Expand Up @@ -178,8 +178,8 @@ where
pub fn deserialize(input: &[u8]) -> Result<Self> {
let mut input = input.iter().copied();

let c_scalar = CS::Group::deserialize_scalar(&deserialize(&mut input)?)?;
let s_scalar = CS::Group::deserialize_scalar(&deserialize(&mut input)?)?;
let c_scalar = deserialize_scalar::<CS::Group, _>(&mut input)?;
let s_scalar = deserialize_scalar::<CS::Group, _>(&mut input)?;

Ok(Proof { c_scalar, s_scalar })
}
Expand All @@ -205,7 +205,7 @@ where
pub fn deserialize(input: &[u8]) -> Result<Self> {
let mut input = input.iter().copied();

let value = CS::Group::deserialize_elem(&deserialize(&mut input)?)?;
let value = deserialize_elem::<CS::Group, _>(&mut input)?;

Ok(Self(value))
}
Expand All @@ -231,15 +231,72 @@ where
pub fn deserialize(input: &[u8]) -> Result<Self> {
let mut input = input.iter().copied();

let value = CS::Group::deserialize_elem(&deserialize(&mut input)?)?;
let value = deserialize_elem::<CS::Group, _>(&mut input)?;

Ok(Self(value))
}
}

fn deserialize<L: ArrayLength<u8>>(
input: &mut impl Iterator<Item = u8>,
) -> Result<GenericArray<u8, L>> {
let input = input.by_ref().take(L::USIZE);
GenericArray::from_exact_iter(input).ok_or(Error::Deserialization)
fn deserialize_elem<G: Group, I: Iterator<Item = u8>>(input: &mut I) -> Result<G::Elem> {
let input = input.by_ref().take(G::ElemLen::USIZE);
GenericArray::<_, G::ElemLen>::from_exact_iter(input)
.ok_or(Error::Deserialization)
.and_then(|bytes| G::deserialize_elem(&bytes))
}

fn deserialize_scalar<G: Group, I: Iterator<Item = u8>>(input: &mut I) -> Result<G::Scalar> {
let input = input.by_ref().take(G::ScalarLen::USIZE);
GenericArray::<_, G::ScalarLen>::from_exact_iter(input)
.ok_or(Error::Deserialization)
.and_then(|bytes| G::deserialize_scalar(&bytes))
}

#[cfg(feature = "serde")]
pub(crate) mod serde {
use core::marker::PhantomData;

use generic_array::GenericArray;
use serde::de::{Deserializer, Error};
use serde::ser::Serializer;
use serde::{Deserialize, Serialize};

use crate::Group;

pub(crate) struct Element<G: Group>(PhantomData<G>);

impl<'de, G: Group> Element<G> {
pub(crate) fn deserialize<D>(deserializer: D) -> Result<G::Elem, D::Error>
where
D: Deserializer<'de>,
{
GenericArray::<_, G::ElemLen>::deserialize(deserializer)
.and_then(|bytes| G::deserialize_elem(&bytes).map_err(D::Error::custom))
}

pub(crate) fn serialize<S>(self_: &G::Elem, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
G::serialize_elem(*self_).serialize(serializer)
}
}

pub(crate) struct Scalar<G: Group>(PhantomData<G>);

impl<'de, G: Group> Scalar<G> {
pub(crate) fn deserialize<D>(deserializer: D) -> Result<G::Scalar, D::Error>
where
D: Deserializer<'de>,
{
GenericArray::<_, G::ScalarLen>::deserialize(deserializer)
.and_then(|bytes| G::deserialize_scalar(&bytes).map_err(D::Error::custom))
}

pub(crate) fn serialize<S>(self_: &G::Scalar, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
G::serialize_scalar(*self_).serialize(serializer)
}
}
}
Loading

0 comments on commit b59b359

Please sign in to comment.