Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

General Improvements #47

Merged
merged 6 commits into from
Dec 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ jobs:
toolchain:
- stable
- 1.51.0
exclude:
- backend_feature: p256
toolchain: 1.51.0
- backend_feature: ristretto255_u64,p256
toolchain: 1.51.0
name: test
steps:
- name: Checkout sources
Expand Down
3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,8 @@ num-bigint = { version = "0.4", default-features = false, optional = true }
num-integer = { version = "0.1", default-features = false, optional = true }
num-traits = { version = "0.2", default-features = false, optional = true }
once_cell = { version = "1", default-features = false, optional = true }
p256_ = { package = "p256", version = "0.9", default-features = false, features = [
p256_ = { package = "p256", version = "0.10", default-features = false, features = [
"arithmetic",
"zeroize",
], optional = true }
rand_core = { version = "0.6", default-features = false }
serde = { version = "1", default-features = false, features = [
Expand Down
11 changes: 6 additions & 5 deletions src/errors.rs → src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@
// License, Version 2.0 found in the LICENSE-APACHE file in the root directory
// of this source tree.

//! A list of error types which are produced during an execution of the protocol
#[cfg(feature = "std")]
use std::error::Error;
//! Errors which are produced during an execution of the protocol

use displaydoc::Display;

/// [`Result`](core::result::Result) shorthand that uses [`Error`].
pub type Result<T> = core::result::Result<T, Error>;

/// Represents an error in the manipulation of internal cryptographic data
#[derive(Clone, Copy, Debug, Display, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum InternalError {
pub enum Error {
/// Could not parse byte sequence for key
InvalidByteSequence,
/// Could not deserialize element, or deserialized to the identity element
Expand All @@ -38,4 +39,4 @@ pub enum InternalError {
}

#[cfg(feature = "std")]
impl Error for InternalError {}
impl std::error::Error for Error {}
6 changes: 3 additions & 3 deletions src/group/expand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ use generic_array::sequence::Concat;
use generic_array::typenum::{Unsigned, U1, U2};
use generic_array::{ArrayLength, GenericArray};

use crate::errors::InternalError;
use crate::util::i2osp;
use crate::{Error, Result};

// Computes ceil(x / y)
fn div_ceil(x: usize, y: usize) -> usize {
Expand All @@ -37,14 +37,14 @@ pub fn expand_message_xmd<
>(
msg: M,
dst: GenericArray<u8, D>,
) -> Result<GenericArray<u8, L>, InternalError>
) -> Result<GenericArray<u8, L>>
where
<D as Add<U1>>::Output: ArrayLength<u8>,
{
let digest_len = H::OutputSize::USIZE;
let ell = div_ceil(L::USIZE, digest_len);
if ell > 255 {
return Err(InternalError::HashToCurveError);
return Err(Error::HashToCurveError);
}
let dst_prime = dst.concat(i2osp::<U1>(D::USIZE)?);
let z_pad = i2osp::<H::BlockSize>(0)?;
Expand Down
21 changes: 10 additions & 11 deletions src/group/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use rand_core::{CryptoRng, RngCore};
use subtle::ConstantTimeEq;
use zeroize::Zeroize;

use crate::errors::InternalError;
use crate::{Error, Result};

/// A prime-order subgroup of a base field (EC, prime-order field ...). This
/// subgroup is noted additively — as in the draft RFC — in this trait.
Expand All @@ -43,7 +43,7 @@ pub trait Group:
fn hash_to_curve<H: BlockSizeUser + Digest + FixedOutputReset, D: ArrayLength<u8> + Add<U1>>(
msg: &[u8],
dst: GenericArray<u8, D>,
) -> Result<Self, InternalError>
) -> Result<Self>
where
<D as Add<U1>>::Output: ArrayLength<u8>;

Expand All @@ -56,7 +56,7 @@ pub trait Group:
>(
input: I,
dst: GenericArray<u8, D>,
) -> Result<Self::Scalar, InternalError>
) -> Result<Self::Scalar>
where
<D as Add<U1>>::Output: ArrayLength<u8>;

Expand All @@ -74,16 +74,16 @@ pub trait Group:
/// checking if the scalar is zero.
fn from_scalar_slice_unchecked(
scalar_bits: &GenericArray<u8, Self::ScalarLen>,
) -> Result<Self::Scalar, InternalError>;
) -> Result<Self::Scalar>;

/// Return a scalar from its fixed-length bytes representation. If the
/// scalar is zero, then return an error.
fn from_scalar_slice<'a>(
scalar_bits: impl Into<&'a GenericArray<u8, Self::ScalarLen>>,
) -> Result<Self::Scalar, InternalError> {
) -> Result<Self::Scalar> {
let scalar = Self::from_scalar_slice_unchecked(scalar_bits.into())?;
if scalar.ct_eq(&Self::scalar_zero()).into() {
return Err(InternalError::ZeroScalarError);
return Err(Error::ZeroScalarError);
}
Ok(scalar)
}
Expand All @@ -101,20 +101,19 @@ pub trait Group:
/// Return an element from its fixed-length bytes representation. This is
/// the unchecked version, which does not check for deserializing the
/// identity element
fn from_element_slice_unchecked(
element_bits: &GenericArray<u8, Self::ElemLen>,
) -> Result<Self, InternalError>;
fn from_element_slice_unchecked(element_bits: &GenericArray<u8, Self::ElemLen>)
-> Result<Self>;

/// Return an element from its fixed-length bytes representation. If the
/// element is the identity element, return an error.
fn from_element_slice<'a>(
element_bits: impl Into<&'a GenericArray<u8, Self::ElemLen>>,
) -> Result<Self, InternalError> {
) -> Result<Self> {
let elem = Self::from_element_slice_unchecked(element_bits.into())?;

if Self::ct_eq(&elem, &<Self as Group>::identity()).into() {
// found the identity element
return Err(InternalError::PointError);
return Err(Error::PointError);
}

Ok(elem)
Expand Down
29 changes: 15 additions & 14 deletions src/group/p256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@ use num_traits::{One, ToPrimitive, Zero};
use once_cell::unsync::Lazy;
use p256_::elliptic_curve::group::prime::PrimeCurveAffine;
use p256_::elliptic_curve::group::GroupEncoding;
use p256_::elliptic_curve::ops::Reduce;
use p256_::elliptic_curve::sec1::{FromEncodedPoint, ToEncodedPoint};
use p256_::elliptic_curve::Field;
use p256_::{AffinePoint, EncodedPoint, ProjectivePoint};
use rand_core::{CryptoRng, RngCore};
use subtle::{Choice, ConditionallySelectable};

use super::Group;
use crate::errors::InternalError;
use crate::{Error, Result};

// https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-hash-to-curve-11#section-8.2
// `L: 48`
Expand All @@ -48,7 +49,7 @@ impl Group for ProjectivePoint {
fn hash_to_curve<H: BlockSizeUser + Digest + FixedOutputReset, D: ArrayLength<u8> + Add<U1>>(
msg: &[u8],
dst: GenericArray<u8, D>,
) -> Result<Self, InternalError>
) -> Result<Self>
where
<D as Add<U1>>::Output: ArrayLength<u8>,
{
Expand Down Expand Up @@ -85,15 +86,15 @@ impl Group for ProjectivePoint {
let (q1x, q1y) = hash_to_curve_simple_swu(&uniform_bytes[L::USIZE..], &A, &B, &P, &Z);

// convert to `p256` types
let p0 = AffinePoint::from_encoded_point(&EncodedPoint::from_affine_coordinates(
&q0x, &q0y, false,
let p0 = Option::<AffinePoint>::from(AffinePoint::from_encoded_point(
&EncodedPoint::from_affine_coordinates(&q0x, &q0y, false),
))
.ok_or(InternalError::PointError)?
.ok_or(Error::PointError)?
.to_curve();
let p1 = AffinePoint::from_encoded_point(&EncodedPoint::from_affine_coordinates(
&q1x, &q1y, false,
let p1 = Option::<AffinePoint>::from(AffinePoint::from_encoded_point(
&EncodedPoint::from_affine_coordinates(&q1x, &q1y, false),
))
.ok_or(InternalError::PointError)?;
.ok_or(Error::PointError)?;

Ok(p0 + p1)
}
Expand All @@ -107,7 +108,7 @@ impl Group for ProjectivePoint {
>(
input: I,
dst: GenericArray<u8, D>,
) -> Result<Self::Scalar, InternalError>
) -> Result<Self::Scalar>
where
<D as Add<U1>>::Output: ArrayLength<u8>,
{
Expand All @@ -132,7 +133,7 @@ impl Group for ProjectivePoint {
let mut result = GenericArray::default();
result[..bytes.len()].copy_from_slice(&bytes);

Ok(p256_::Scalar::from_bytes_reduced(&result))
Ok(p256_::Scalar::from_be_bytes_reduced(result))
}

type ElemLen = U33;
Expand All @@ -141,8 +142,8 @@ impl Group for ProjectivePoint {

fn from_scalar_slice_unchecked(
scalar_bits: &GenericArray<u8, Self::ScalarLen>,
) -> Result<Self::Scalar, InternalError> {
Ok(Self::Scalar::from_bytes_reduced(scalar_bits))
) -> Result<Self::Scalar> {
Ok(Self::Scalar::from_be_bytes_reduced(*scalar_bits))
}

fn random_nonzero_scalar<R: RngCore + CryptoRng>(rng: &mut R) -> Self::Scalar {
Expand All @@ -159,8 +160,8 @@ impl Group for ProjectivePoint {

fn from_element_slice_unchecked(
element_bits: &GenericArray<u8, Self::ElemLen>,
) -> Result<Self, InternalError> {
Option::from(Self::from_bytes(element_bits)).ok_or(InternalError::PointError)
) -> Result<Self> {
Option::from(Self::from_bytes(element_bits)).ok_or(Error::PointError)
}

fn to_arr(&self) -> GenericArray<u8, Self::ElemLen> {
Expand Down
16 changes: 8 additions & 8 deletions src/group/ristretto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use generic_array::{ArrayLength, GenericArray};
use rand_core::{CryptoRng, RngCore};

use super::Group;
use crate::errors::InternalError;
use crate::{Error, Result};

// `cfg` here is only needed because of a bug in Rust's crate feature documentation. See: https://github.com/rust-lang/rust/issues/83428
#[cfg(feature = "ristretto255")]
Expand All @@ -32,7 +32,7 @@ impl Group for RistrettoPoint {
fn hash_to_curve<H: BlockSizeUser + Digest + FixedOutputReset, D: ArrayLength<u8> + Add<U1>>(
msg: &[u8],
dst: GenericArray<u8, D>,
) -> Result<Self, InternalError>
) -> Result<Self>
where
<D as Add<U1>>::Output: ArrayLength<u8>,
{
Expand All @@ -42,7 +42,7 @@ impl Group for RistrettoPoint {
uniform_bytes
.as_slice()
.try_into()
.map_err(|_| InternalError::HashToCurveError)?,
.map_err(|_| Error::HashToCurveError)?,
))
}

Expand All @@ -56,7 +56,7 @@ impl Group for RistrettoPoint {
>(
input: I,
dst: GenericArray<u8, D>,
) -> Result<Self::Scalar, InternalError>
) -> Result<Self::Scalar>
where
<D as Add<U1>>::Output: ArrayLength<u8>,
{
Expand All @@ -66,15 +66,15 @@ impl Group for RistrettoPoint {
uniform_bytes
.as_slice()
.try_into()
.map_err(|_| InternalError::HashToCurveError)?,
.map_err(|_| Error::HashToCurveError)?,
))
}

type Scalar = Scalar;
type ScalarLen = U32;
fn from_scalar_slice_unchecked(
scalar_bits: &GenericArray<u8, Self::ScalarLen>,
) -> Result<Self::Scalar, InternalError> {
) -> Result<Self::Scalar> {
Ok(Scalar::from_bytes_mod_order(*scalar_bits.as_ref()))
}

Expand Down Expand Up @@ -104,10 +104,10 @@ impl Group for RistrettoPoint {
type ElemLen = U32;
fn from_element_slice_unchecked(
element_bits: &GenericArray<u8, Self::ElemLen>,
) -> Result<Self, InternalError> {
) -> Result<Self> {
CompressedRistretto::from_slice(element_bits)
.decompress()
.ok_or(InternalError::PointError)
.ok_or(Error::PointError)
}
// serialization of a group element
fn to_arr(&self) -> GenericArray<u8, Self::ElemLen> {
Expand Down
13 changes: 6 additions & 7 deletions src/group/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@

//! Includes a series of tests for the group implementations

use crate::errors::InternalError;
use crate::group::Group;
use crate::{Error, Group, Result};

// Test that the deserialization of a group element should throw an error if the
// identity element can be deserialized properly

#[test]
fn test_group_properties() -> Result<(), InternalError> {
fn test_group_properties() -> Result<()> {
#[cfg(feature = "ristretto255")]
{
use curve25519_dalek::ristretto::RistrettoPoint;
Expand All @@ -35,19 +34,19 @@ fn test_group_properties() -> Result<(), InternalError> {
}

// Checks that the identity element cannot be deserialized
fn test_identity_element_error<G: Group>() -> Result<(), InternalError> {
fn test_identity_element_error<G: Group>() -> Result<()> {
let identity = G::identity();
let result = G::from_element_slice(&identity.to_arr());
assert!(matches!(result, Err(InternalError::PointError)));
assert!(matches!(result, Err(Error::PointError)));

Ok(())
}

// Checks that the zero scalar cannot be deserialized
fn test_zero_scalar_error<G: Group>() -> Result<(), InternalError> {
fn test_zero_scalar_error<G: Group>() -> Result<()> {
let zero_scalar = G::scalar_zero();
let result = G::from_scalar_slice(&G::scalar_as_bytes(zero_scalar));
assert!(matches!(result, Err(InternalError::ZeroScalarError)));
assert!(matches!(result, Err(Error::ZeroScalarError)));

Ok(())
}
17 changes: 11 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -472,8 +472,8 @@
//! VOPRF evaluations.
//!
//! - The `p256` feature enables using p256 as the underlying group for the
//! [Group](group::Group) choice. Note that this is currently an experimental
//! feature ⚠️, and is not yet ready for production use.
//! [Group] choice and increases the MSRV to 1.56. Note that this is currently
//! an experimental feature ⚠️, and is not yet ready for production use.
//!
//! - The `serde` feature, enabled by default, provides convenience functions
//! for serializing and deserializing with [serde](https://serde.rs/).
Expand Down Expand Up @@ -512,17 +512,22 @@ extern crate std;
mod util;
#[macro_use]
mod serialization;
pub mod errors;
pub mod group;
mod error;
mod group;
mod voprf;

#[cfg(test)]
mod tests;

// Exports

pub use crate::error::{Error, Result};
pub use crate::group::Group;
#[cfg(feature = "alloc")]
pub use crate::voprf::VerifiableServerBatchEvaluateResult;
pub use crate::voprf::{
BlindedElement, EvaluationElement, NonVerifiableClient, NonVerifiableClientBlindResult,
NonVerifiableServer, NonVerifiableServerEvaluateResult, VerifiableClient,
VerifiableClientBlindResult, VerifiableServer, VerifiableServerEvaluateResult,
NonVerifiableServer, NonVerifiableServerEvaluateResult, Proof, VerifiableClient,
VerifiableClientBatchFinalizeResult, VerifiableClientBlindResult, VerifiableServer,
VerifiableServerEvaluateResult,
};
Loading