Skip to content

Commit

Permalink
Uses conditional compilation to make sure postprocessing is only avai…
Browse files Browse the repository at this point in the history
…lable in tests

See
rust-lang/rust#64010
  • Loading branch information
huitseeker authored and kevinlewi committed Nov 5, 2020
1 parent 8d29f72 commit 103607c
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 72 deletions.
60 changes: 22 additions & 38 deletions src/opaque.rs
Original file line number Diff line number Diff line change
Expand Up @@ -561,37 +561,25 @@ impl<CS: CipherSuite> ClientRegistration<CS> {
&Vec::new(),
password,
blinding_factor_rng,
)
}

/// Same as ClientRegistration::start, but also accepts a username and server name as input
pub fn start_with_user_and_server_name<R: RngCore + CryptoRng>(
user_name: &[u8],
server_name: &[u8],
password: &[u8],
blinding_factor_rng: &mut R,
) -> Result<(RegisterFirstMessage<CS::Group>, Self), ProtocolError> {
Self::start_with_user_and_server_name_and_postprocessing(
user_name,
server_name,
password,
blinding_factor_rng,
#[cfg(test)]
std::convert::identity,
)
}

/// Same as ClientRegistration::start, but also accepts a username and server name as input as well as
/// an optional postprocessing function for the blinding factor
pub fn start_with_user_and_server_name_and_postprocessing<R: RngCore + CryptoRng>(
/// Same as ClientRegistration::start, but also accepts a username and
/// server name as input
/// as well as an optional postprocessing function for the blinding factor(used in tests)
pub fn start_with_user_and_server_name<R: RngCore + CryptoRng>(
user_name: &[u8],
server_name: &[u8],
password: &[u8],
blinding_factor_rng: &mut R,
postprocess: fn(<CS::Group as Group>::Scalar) -> <CS::Group as Group>::Scalar,
#[cfg(test)] postprocess: fn(<CS::Group as Group>::Scalar) -> <CS::Group as Group>::Scalar,
) -> Result<(RegisterFirstMessage<CS::Group>, Self), ProtocolError> {
let (token, alpha) = oprf::blind_with_postprocessing::<R, CS::Group>(
let (token, alpha) = oprf::blind::<R, CS::Group>(
&password,
blinding_factor_rng,
#[cfg(test)]
postprocess,
)?;

Expand Down Expand Up @@ -1018,35 +1006,31 @@ impl<CS: CipherSuite> ClientLogin<CS> {
password: &[u8],
rng: &mut R,
) -> Result<(LoginFirstMessage<CS>, Self), ProtocolError> {
Self::start_with_user_and_server_name(&Vec::new(), &Vec::new(), password, rng)
}

/// Same as start, but allows the user to supply a username and server name
pub fn start_with_user_and_server_name<R: RngCore + CryptoRng>(
user_name: &[u8],
server_name: &[u8],
password: &[u8],
rng: &mut R,
) -> Result<(LoginFirstMessage<CS>, Self), ProtocolError> {
Self::start_with_user_and_server_name_and_postprocessing(
user_name,
server_name,
Self::start_with_user_and_server_name(
&Vec::new(),
&Vec::new(),
password,
rng,
#[cfg(test)]
std::convert::identity,
)
}

/// Same as start, but allows the user to supply a username and server name and postprocessing function
pub fn start_with_user_and_server_name_and_postprocessing<R: RngCore + CryptoRng>(
/// Same as start, but allows the user to supply a username and server name
/// and, in tests, a postprocessing function
pub fn start_with_user_and_server_name<R: RngCore + CryptoRng>(
user_name: &[u8],
server_name: &[u8],
password: &[u8],
rng: &mut R,
postprocess: fn(<CS::Group as Group>::Scalar) -> <CS::Group as Group>::Scalar,
#[cfg(test)] postprocess: fn(<CS::Group as Group>::Scalar) -> <CS::Group as Group>::Scalar,
) -> Result<(LoginFirstMessage<CS>, Self), ProtocolError> {
let (token, alpha) =
oprf::blind_with_postprocessing::<R, CS::Group>(&password, rng, postprocess)?;
let (token, alpha) = oprf::blind::<R, CS::Group>(
&password,
rng,
#[cfg(test)]
postprocess,
)?;

let (ke1_state, ke1_message) = CS::KeyExchange::generate_ke1(alpha.to_arr().to_vec(), rng)?;

Expand Down
44 changes: 27 additions & 17 deletions src/oprf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,18 @@ static STR_VOPRF: &[u8] = b"VOPRF05";
/// message is sent from the client (who holds the input) to the server (who holds the OPRF key).
/// The client can also pass in an optional "pepper" string to be mixed in with the input through
/// an HKDF computation.
pub(crate) fn blind_with_postprocessing<R: RngCore + CryptoRng, G: GroupWithMapToCurve>(
pub(crate) fn blind<R: RngCore + CryptoRng, G: GroupWithMapToCurve>(
input: &[u8],
blinding_factor_rng: &mut R,
postprocess: fn(G::Scalar) -> G::Scalar,
#[cfg(test)] postprocess: fn(G::Scalar) -> G::Scalar,
) -> Result<(Token<G>, G), InternalPakeError> {
let mapped_point = G::map_to_curve(input, Some(STR_VOPRF)); // TODO: add contextString from RFC
let blinding_factor = G::random_scalar(blinding_factor_rng);
#[cfg(test)]
let blind = postprocess(blinding_factor);
#[cfg(not(test))]
let blind = blinding_factor;

let blind_token = mapped_point * &blind;
Ok((
Token {
Expand Down Expand Up @@ -60,23 +64,34 @@ pub(crate) fn unblind_and_finalize<G: Group, H: Hash>(
Ok(prk)
}

// Benchmarking shims
////////////////////////
// Benchmarking shims //
////////////////////////

#[cfg(feature = "bench")]
#[doc(hidden)]
#[inline]
pub fn blind_shim<R: RngCore + CryptoRng, G: GroupWithMapToCurve>(
input: &[u8],
blinding_factor_rng: &mut R,
) -> Result<(Token<G>, G), InternalPakeError> {
blind_with_postprocessing(input, blinding_factor_rng, std::convert::identity)
blind(
input,
blinding_factor_rng,
#[cfg(test)]
std::convert::identity,
)
}

#[cfg(feature = "bench")]
#[doc(hidden)]
#[inline]
pub fn evaluate_shim<G: Group>(point: G, oprf_key: &G::Scalar) -> Result<G, InternalPakeError> {
evaluate(point, oprf_key)
}

#[cfg(feature = "bench")]
#[doc(hidden)]
#[inline]
pub fn unblind_and_finalize_shim<G: Group, H: Hash>(
token: &Token<G>,
Expand All @@ -85,8 +100,10 @@ pub fn unblind_and_finalize_shim<G: Group, H: Hash>(
unblind_and_finalize::<G, H>(token, point)
}

// Tests
// =====
///////////
// Tests //
// ===== //
///////////

#[cfg(test)]
mod tests {
Expand Down Expand Up @@ -117,11 +134,8 @@ mod tests {
fn oprf_retrieval() -> Result<(), InternalPakeError> {
let input = b"hunter2";
let mut rng = OsRng;
let (token, alpha) = blind_with_postprocessing::<_, RistrettoPoint>(
&input[..],
&mut rng,
std::convert::identity,
)?;
let (token, alpha) =
blind::<_, RistrettoPoint>(&input[..], &mut rng, std::convert::identity)?;
let oprf_key_bytes = arr![
u8; 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
24, 25, 26, 27, 28, 29, 30, 31, 32,
Expand All @@ -139,12 +153,8 @@ mod tests {
let mut rng = OsRng;
let mut input = vec![0u8; 64];
rng.fill_bytes(&mut input);
let (token, alpha) = blind_with_postprocessing::<_, RistrettoPoint>(
&input,
&mut rng,
std::convert::identity,
)
.unwrap();
let (token, alpha) =
blind::<_, RistrettoPoint>(&input, &mut rng, std::convert::identity).unwrap();
let res = unblind_and_finalize::<RistrettoPoint, sha2::Sha256>(&token, alpha).unwrap();

let (hashed_input, _) = Hkdf::<Sha512>::extract(Some(STR_VOPRF), &input);
Expand Down
36 changes: 19 additions & 17 deletions src/tests/opaque_ke_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ where
id_s,
password,
&mut blinding_factor_registration_rng,
std::convert::identity,
)
.unwrap();
let r1_bytes = r1.serialize().to_vec();
Expand Down Expand Up @@ -292,6 +293,7 @@ where
id_s,
password,
&mut client_login_start_rng,
std::convert::identity,
)
.unwrap();
let l1_bytes = l1.serialize().to_vec();
Expand Down Expand Up @@ -363,14 +365,15 @@ fn postprocess_blinding_factor<G: Group>(_: G::Scalar) -> G::Scalar {
fn test_r1() -> Result<(), PakeError> {
let parameters = populate_test_vectors(&serde_json::from_str(TEST_VECTOR).unwrap());
let mut rng = OsRng;
let (r1, client_registration) = ClientRegistration::<X255193dhNoSlowHash>::start_with_user_and_server_name_and_postprocessing(
&parameters.id_u,
&parameters.id_s,
&parameters.password,
&mut rng,
postprocess_blinding_factor::<<X255193dhNoSlowHash as CipherSuite>::Group>,
)
.unwrap();
let (r1, client_registration) =
ClientRegistration::<X255193dhNoSlowHash>::start_with_user_and_server_name(
&parameters.id_u,
&parameters.id_s,
&parameters.password,
&mut rng,
postprocess_blinding_factor::<<X255193dhNoSlowHash as CipherSuite>::Group>,
)
.unwrap();
assert_eq!(hex::encode(&parameters.r1), hex::encode(r1.serialize()));
assert_eq!(
hex::encode(&parameters.client_registration_state),
Expand Down Expand Up @@ -453,15 +456,14 @@ fn test_l1() -> Result<(), PakeError> {
]
.concat();
let mut client_login_start_rng = CycleRng::new(client_login_start);
let (l1, client_login) =
ClientLogin::<X255193dhNoSlowHash>::start_with_user_and_server_name_and_postprocessing(
&parameters.id_u,
&parameters.id_s,
&parameters.password,
&mut client_login_start_rng,
postprocess_blinding_factor::<<X255193dhNoSlowHash as CipherSuite>::Group>,
)
.unwrap();
let (l1, client_login) = ClientLogin::<X255193dhNoSlowHash>::start_with_user_and_server_name(
&parameters.id_u,
&parameters.id_s,
&parameters.password,
&mut client_login_start_rng,
postprocess_blinding_factor::<<X255193dhNoSlowHash as CipherSuite>::Group>,
)
.unwrap();
assert_eq!(hex::encode(&parameters.l1), hex::encode(l1.serialize()));
assert_eq!(
hex::encode(&parameters.client_login_state),
Expand Down

0 comments on commit 103607c

Please sign in to comment.