Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove allocation in slh-dsa #860

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions slh-dsa/src/hashes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ pub(crate) trait HashSuite: Sized + Clone + Debug + PartialEq + Eq {
fn prf_msg(
sk_prf: &SkPrf<Self::N>,
opt_rand: &Array<u8, Self::N>,
msg: impl AsRef<[u8]>,
msg: &[impl AsRef<[u8]>],
) -> Array<u8, Self::N>;

/// Hashes a message using a given randomizer
fn h_msg(
rand: &Array<u8, Self::N>,
pk_seed: &PkSeed<Self::N>,
pk_root: &Array<u8, Self::N>,
msg: impl AsRef<[u8]>,
msg: &[impl AsRef<[u8]>],
) -> Array<u8, Self::M>;

/// PRF that is used to generate the secret values in WOTS+ and FORS private keys.
Expand Down Expand Up @@ -76,7 +76,7 @@ mod tests {
let opt_rand = Array::<u8, H::N>::from_fn(|_| 1);
let msg = [2u8; 32];

let result = H::prf_msg(&sk_prf, &opt_rand, msg);
let result = H::prf_msg(&sk_prf, &opt_rand, &[msg]);

assert_eq!(result.as_slice(), expected);
}
Expand All @@ -87,7 +87,7 @@ mod tests {
let pk_root = Array::<u8, H::N>::from_fn(|_| 2);
let msg = [3u8; 32];

let result = H::h_msg(&rand, &pk_seed, &pk_root, msg);
let result = H::h_msg(&rand, &pk_seed, &pk_root, &[msg]);

assert_eq!(result.as_slice(), expected);
}
Expand Down
18 changes: 10 additions & 8 deletions slh-dsa/src/hashes/sha2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,12 @@ where
fn prf_msg(
sk_prf: &SkPrf<Self::N>,
opt_rand: &Array<u8, Self::N>,
msg: impl AsRef<[u8]>,
msg: &[impl AsRef<[u8]>],
) -> Array<u8, Self::N> {
let mut mac = Hmac::<Sha256>::new_from_slice(sk_prf.as_ref()).unwrap();
mac.update(opt_rand.as_slice());
mac.update(msg.as_ref());
msg.iter()
.for_each(|msg_part| mac.update(msg_part.as_ref()));
let result = mac.finalize().into_bytes();
Array::clone_from_slice(&result[..Self::N::USIZE])
}
Expand All @@ -73,13 +74,13 @@ where
rand: &Array<u8, Self::N>,
pk_seed: &PkSeed<Self::N>,
pk_root: &Array<u8, Self::N>,
msg: impl AsRef<[u8]>,
msg: &[impl AsRef<[u8]>],
) -> Array<u8, Self::M> {
let mut h = Sha256::new();
h.update(rand);
h.update(pk_seed);
h.update(pk_root);
h.update(msg.as_ref());
msg.iter().for_each(|msg_part| h.update(msg_part.as_ref()));
let result = Array(h.finalize().into());
let seed = rand.clone().concat(pk_seed.0.clone()).concat(result);
mgf1::<Sha256, Self::M>(&seed)
Expand Down Expand Up @@ -220,11 +221,12 @@ where
fn prf_msg(
sk_prf: &SkPrf<Self::N>,
opt_rand: &Array<u8, Self::N>,
msg: impl AsRef<[u8]>,
msg: &[impl AsRef<[u8]>],
) -> Array<u8, Self::N> {
let mut mac = Hmac::<Sha512>::new_from_slice(sk_prf.as_ref()).unwrap();
mac.update(opt_rand.as_slice());
mac.update(msg.as_ref());
msg.iter()
.for_each(|msg_part| mac.update(msg_part.as_ref()));
let result = mac.finalize().into_bytes();
Array::clone_from_slice(&result[..Self::N::USIZE])
}
Expand All @@ -233,13 +235,13 @@ where
rand: &Array<u8, Self::N>,
pk_seed: &PkSeed<Self::N>,
pk_root: &Array<u8, Self::N>,
msg: impl AsRef<[u8]>,
msg: &[impl AsRef<[u8]>],
) -> Array<u8, Self::M> {
let mut h = Sha512::new();
h.update(rand);
h.update(pk_seed);
h.update(pk_root);
h.update(msg.as_ref());
msg.iter().for_each(|msg_part| h.update(msg_part.as_ref()));
let result = Array(h.finalize().into());
let seed = rand.clone().concat(pk_seed.0.clone()).concat(result);
mgf1::<Sha512, Self::M>(&seed)
Expand Down
12 changes: 7 additions & 5 deletions slh-dsa/src/hashes/shake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@ where
fn prf_msg(
sk_prf: &SkPrf<Self::N>,
opt_rand: &Array<u8, Self::N>,
msg: impl AsRef<[u8]>,
msg: &[impl AsRef<[u8]>],
) -> Array<u8, Self::N> {
let mut hasher = Shake256::default();
hasher.update(sk_prf.as_ref());
hasher.update(opt_rand.as_slice());
hasher.update(msg.as_ref());
msg.iter()
.for_each(|msg_part| hasher.update(msg_part.as_ref()));
let mut output = Array::<u8, Self::N>::default();
hasher.finalize_xof_into(&mut output);
output
Expand All @@ -49,13 +50,14 @@ where
rand: &Array<u8, Self::N>,
pk_seed: &PkSeed<Self::N>,
pk_root: &Array<u8, Self::N>,
msg: impl AsRef<[u8]>,
msg: &[impl AsRef<[u8]>],
) -> Array<u8, Self::M> {
let mut hasher = Shake256::default();
hasher.update(rand.as_slice());
hasher.update(pk_seed.as_ref());
hasher.update(pk_root.as_ref());
hasher.update(msg.as_ref());
msg.iter()
.for_each(|msg_part| hasher.update(msg_part.as_ref()));
let mut output = Array::<u8, Self::M>::default();
hasher.finalize_xof_into(&mut output);
output
Expand Down Expand Up @@ -267,7 +269,7 @@ mod tests {

let expected = hex!("bc5c062307df0a41aeeae19ad655f7b2");

let result = H::prf_msg(&sk_prf, &opt_rand, msg);
let result = H::prf_msg(&sk_prf, &opt_rand, &[msg]);

assert_eq!(result.as_slice(), expected);
}
Expand Down
5 changes: 2 additions & 3 deletions slh-dsa/src/signing_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ impl<P: ParameterSet> SigningKey<P> {
/// Implements [slh_sign_internal] as defined in FIPS-205.
/// Published for KAT validation purposes but not intended for general use.
/// opt_rand must be a P::N length slice, panics otherwise.
pub fn slh_sign_internal(&self, msg: &[u8], opt_rand: Option<&[u8]>) -> Signature<P> {
pub fn slh_sign_internal(&self, msg: &[&[u8]], opt_rand: Option<&[u8]>) -> Signature<P> {
tarcieri marked this conversation as resolved.
Show resolved Hide resolved
let rand = opt_rand
.unwrap_or(&self.verifying_key.pk_seed.0)
.try_into()
Expand Down Expand Up @@ -142,8 +142,7 @@ impl<P: ParameterSet> SigningKey<P> {
let ctx_len = u8::try_from(ctx.len()).map_err(|_| Error::new())?;
let ctx_len_bytes = ctx_len.to_be_bytes();

// TODO - figure out what to do about this allocation. Maybe pass a chained iterator to slh_sign_internal?
let ctx_msg = [&[0], &ctx_len_bytes, ctx, msg].concat();
let ctx_msg = [&[0], &ctx_len_bytes, ctx, msg];
Ok(self.slh_sign_internal(&ctx_msg, opt_rand))
}

Expand Down
9 changes: 6 additions & 3 deletions slh-dsa/src/verifying_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ impl<P: ParameterSet + VerifyingKeyLen> VerifyingKey<P> {
/// Verify a raw message (without context).
/// Implements [slh_verify_internal] as defined in FIPS-205.
/// Published for KAT validation purposes but not intended for general use.
pub fn slh_verify_internal(&self, msg: &[u8], signature: &Signature<P>) -> Result<(), Error> {
pub fn slh_verify_internal(
&self,
msg: &[&[u8]],
signature: &Signature<P>,
) -> Result<(), Error> {
let pk_seed = &self.pk_seed;
let randomizer = &signature.randomizer;
let fors_sig = &signature.fors_sig;
Expand Down Expand Up @@ -79,8 +83,7 @@ impl<P: ParameterSet + VerifyingKeyLen> VerifyingKey<P> {
let ctx_len = u8::try_from(ctx.len()).map_err(|_| Error::new())?;
let ctx_len_bytes = ctx_len.to_be_bytes();

// TODO - figure out what to do about this allocation. Maybe pass a chained iterator to slh_sign_internal?
let ctx_msg = [&[0], &ctx_len_bytes, ctx, msg].concat();
let ctx_msg = [&[0], &ctx_len_bytes, ctx, msg];
self.slh_verify_internal(&ctx_msg, signature) // TODO - context processing
}

Expand Down
2 changes: 1 addition & 1 deletion slh-dsa/tests/acvp_sig.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ macro_rules! parameter_case {
.additionalRandomness
.as_ref()
.map(|x| x.data.as_slice());
let sig = sk.slh_sign_internal($test_case.message.data.as_slice(), opt_rand);
let sig = sk.slh_sign_internal(&[$test_case.message.data.as_slice()], opt_rand);
assert_eq!(sig.to_vec(), $test_case.signature.data);
}};
}
Expand Down
2 changes: 1 addition & 1 deletion slh-dsa/tests/acvp_ver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ macro_rules! parameter_case {
($param:ident, $test_case:expr) => {{
let sk = VerifyingKey::<$param>::try_from($test_case.pk.data.as_slice()).unwrap();
if let Ok(sig) = $test_case.signature.data.as_slice().try_into() {
let success = sk.slh_verify_internal($test_case.message.data.as_slice(), &sig);
let success = sk.slh_verify_internal(&[$test_case.message.data.as_slice()], &sig);
assert_eq!($test_case.testPassed, success.is_ok());
} else {
assert!(!$test_case.testPassed);
Expand Down
2 changes: 1 addition & 1 deletion slh-dsa/tests/known_answer_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ where
let mut opt_rand = vec![0; P::VkLen::USIZE / 2];
rng.fill_bytes(opt_rand.as_mut());

let sig = sk.slh_sign_internal(msg, Some(&opt_rand)).to_bytes();
let sig = sk.slh_sign_internal(&[msg], Some(&opt_rand)).to_bytes();
writeln!(resp, "smlen = {}", sig.as_slice().len() + msg.len()).unwrap();
writeln!(
resp,
Expand Down