diff --git a/src/aead.rs b/src/aead.rs index 20b460be98..2760c76f23 100644 --- a/src/aead.rs +++ b/src/aead.rs @@ -1,4 +1,4 @@ -// Copyright 2015-2021 Brian Smith. +// Copyright 2015-2024 Brian Smith. // // Permission to use, copy, modify, and/or distribute this software for any // purpose with or without fee is hereby granted, provided that the above @@ -34,6 +34,7 @@ pub use self::{ sealing_key::SealingKey, unbound_key::UnboundKey, }; +use inout::InOut; /// A sequences of unique nonces. /// @@ -175,6 +176,7 @@ mod chacha; mod chacha20_poly1305; pub mod chacha20_poly1305_openssh; mod gcm; +mod inout; mod less_safe_key; mod nonce; mod opening_key; diff --git a/src/aead/aes.rs b/src/aead/aes.rs index f3cd35be52..a1bf4094be 100644 --- a/src/aead/aes.rs +++ b/src/aead/aes.rs @@ -12,14 +12,13 @@ // OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -use super::{nonce::Nonce, quic::Sample, NONCE_LEN}; +use super::{nonce::Nonce, quic::Sample, InOut, NONCE_LEN}; use crate::{ constant_time, cpu::{self, GetFeature as _}, error, }; use cfg_if::cfg_if; -use core::ops::RangeFrom; pub(super) use ffi::Counter; @@ -158,7 +157,7 @@ pub(super) trait EncryptBlock { } pub(super) trait EncryptCtr32 { - fn ctr32_encrypt_within(&self, in_out: &mut [u8], src: RangeFrom, ctr: &mut Counter); + fn ctr32_encrypt_within(&self, in_out: InOut<'_>, ctr: &mut Counter); } #[allow(dead_code)] @@ -178,7 +177,7 @@ fn encrypt_iv_xor_block_using_encrypt_block( #[allow(dead_code)] fn encrypt_iv_xor_block_using_ctr32(key: &impl EncryptCtr32, iv: Iv, mut block: Block) -> Block { let mut ctr = Counter(iv.0); // This is OK because we're only encrypting one block. - key.ctr32_encrypt_within(&mut block, 0.., &mut ctr); + key.ctr32_encrypt_within(InOut::in_place(&mut block), &mut ctr); block } diff --git a/src/aead/aes/bs.rs b/src/aead/aes/bs.rs index f1c5408a26..884b7ec470 100644 --- a/src/aead/aes/bs.rs +++ b/src/aead/aes/bs.rs @@ -14,8 +14,7 @@ #![cfg(target_arch = "arm")] -use super::{Counter, AES_KEY}; -use core::ops::RangeFrom; +use super::{Counter, InOut, AES_KEY}; /// SAFETY: /// * The caller must ensure that if blocks > 0 then either `input` and @@ -28,8 +27,7 @@ use core::ops::RangeFrom; /// * Upon returning, `blocks` blocks will have been read from `input` and /// written to `output`. pub(super) unsafe fn ctr32_encrypt_blocks_with_vpaes_key( - in_out: &mut [u8], - src: RangeFrom, + in_out: InOut<'_>, vpaes_key: &AES_KEY, ctr: &mut Counter, ) { @@ -57,6 +55,6 @@ pub(super) unsafe fn ctr32_encrypt_blocks_with_vpaes_key( // * `bsaes_ctr32_encrypt_blocks` satisfies the contract for // `ctr32_encrypt_blocks`. unsafe { - ctr32_encrypt_blocks!(bsaes_ctr32_encrypt_blocks, in_out, src, &bsaes_key, ctr); + ctr32_encrypt_blocks!(bsaes_ctr32_encrypt_blocks, in_out, &bsaes_key, ctr); } } diff --git a/src/aead/aes/fallback.rs b/src/aead/aes/fallback.rs index 00caa694ab..0a475768c1 100644 --- a/src/aead/aes/fallback.rs +++ b/src/aead/aes/fallback.rs @@ -12,9 +12,8 @@ // OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -use super::{Block, Counter, EncryptBlock, EncryptCtr32, Iv, KeyBytes, AES_KEY}; +use super::{Block, Counter, EncryptBlock, EncryptCtr32, InOut, Iv, KeyBytes, AES_KEY}; use crate::error; -use core::ops::RangeFrom; #[derive(Clone)] pub struct Key { @@ -39,9 +38,7 @@ impl EncryptBlock for Key { } impl EncryptCtr32 for Key { - fn ctr32_encrypt_within(&self, in_out: &mut [u8], src: RangeFrom, ctr: &mut Counter) { - unsafe { - ctr32_encrypt_blocks!(aes_nohw_ctr32_encrypt_blocks, in_out, src, &self.inner, ctr) - } + fn ctr32_encrypt_within(&self, in_out: InOut<'_>, ctr: &mut Counter) { + unsafe { ctr32_encrypt_blocks!(aes_nohw_ctr32_encrypt_blocks, in_out, &self.inner, ctr) } } } diff --git a/src/aead/aes/ffi.rs b/src/aead/aes/ffi.rs index 840845059b..c371582e66 100644 --- a/src/aead/aes/ffi.rs +++ b/src/aead/aes/ffi.rs @@ -12,9 +12,9 @@ // OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -use super::{Block, KeyBytes, BLOCK_LEN}; -use crate::{bits::BitLength, c, error, polyfill::slice}; -use core::{num::NonZeroUsize, ops::RangeFrom}; +use super::{Block, InOut, KeyBytes, BLOCK_LEN}; +use crate::{bits::BitLength, c, error}; +use core::num::NonZeroUsize; /// nonce || big-endian counter. #[repr(transparent)] @@ -127,7 +127,7 @@ impl AES_KEY { /// * The caller must ensure that fhe function `$name` satisfies the conditions /// for the `f` parameter to `ctr32_encrypt_blocks`. macro_rules! ctr32_encrypt_blocks { - ($name:ident, $in_out:expr, $src:expr, $key:expr, $ctr:expr $(,)? ) => {{ + ($name:ident, $in_out:expr, $key:expr, $ctr:expr $(,)? ) => {{ use crate::{ aead::aes::{ffi::AES_KEY, Counter, BLOCK_LEN}, c, @@ -141,7 +141,7 @@ macro_rules! ctr32_encrypt_blocks { ivec: &Counter, ); } - $key.ctr32_encrypt_blocks($name, $in_out, $src, $ctr) + $key.ctr32_encrypt_blocks($name, $in_out, $ctr) }}; } @@ -167,25 +167,23 @@ impl AES_KEY { key: &AES_KEY, ivec: &Counter, ), - in_out: &mut [u8], - src: RangeFrom, + mut in_out: InOut<'_>, ctr: &mut Counter, ) { - let (input, leftover) = slice::as_chunks(&in_out[src]); - debug_assert_eq!(leftover.len(), 0); + let (input, output, len) = in_out.input_output_len(); + debug_assert_eq!(len % BLOCK_LEN, 0); - let blocks = match NonZeroUsize::new(input.len()) { + let blocks = match NonZeroUsize::new(len / BLOCK_LEN) { Some(blocks) => blocks, None => { return; } }; + let input: *const [u8; BLOCK_LEN] = input.cast(); + let output: *mut [u8; BLOCK_LEN] = output.cast(); let blocks_u32: u32 = blocks.get().try_into().unwrap(); - let input = input.as_ptr(); - let output: *mut [u8; BLOCK_LEN] = in_out.as_mut_ptr().cast(); - // SAFETY: // * `input` points to `blocks` blocks. // * `output` points to space for `blocks` blocks to be written. diff --git a/src/aead/aes/hw.rs b/src/aead/aes/hw.rs index c7b1e51de7..d5241b1699 100644 --- a/src/aead/aes/hw.rs +++ b/src/aead/aes/hw.rs @@ -14,9 +14,8 @@ #![cfg(any(target_arch = "aarch64", target_arch = "x86", target_arch = "x86_64"))] -use super::{Block, Counter, EncryptBlock, EncryptCtr32, Iv, KeyBytes, AES_KEY}; +use super::{Block, Counter, EncryptBlock, EncryptCtr32, InOut, Iv, KeyBytes, AES_KEY}; use crate::{cpu, error}; -use core::ops::RangeFrom; #[cfg(target_arch = "aarch64")] pub(in super::super) type RequiredCpuFeatures = cpu::arm::Aes; @@ -56,9 +55,9 @@ impl EncryptBlock for Key { } impl EncryptCtr32 for Key { - fn ctr32_encrypt_within(&self, in_out: &mut [u8], src: RangeFrom, ctr: &mut Counter) { + fn ctr32_encrypt_within(&self, in_out: InOut<'_>, ctr: &mut Counter) { #[cfg(target_arch = "x86_64")] let _: cpu::Features = cpu::features(); - unsafe { ctr32_encrypt_blocks!(aes_hw_ctr32_encrypt_blocks, in_out, src, &self.inner, ctr) } + unsafe { ctr32_encrypt_blocks!(aes_hw_ctr32_encrypt_blocks, in_out, &self.inner, ctr) } } } diff --git a/src/aead/aes/vp.rs b/src/aead/aes/vp.rs index 0893a9873c..6b0bdbfa15 100644 --- a/src/aead/aes/vp.rs +++ b/src/aead/aes/vp.rs @@ -19,9 +19,8 @@ target_arch = "x86_64" ))] -use super::{Block, Counter, EncryptBlock, EncryptCtr32, Iv, KeyBytes, AES_KEY}; +use super::{Block, Counter, EncryptBlock, EncryptCtr32, InOut, Iv, KeyBytes, AES_KEY}; use crate::{cpu, error}; -use core::ops::RangeFrom; #[cfg(any(target_arch = "aarch64", target_arch = "arm"))] type RequiredCpuFeatures = cpu::arm::Neon; @@ -57,17 +56,18 @@ impl EncryptBlock for Key { #[cfg(any(target_arch = "aarch64", target_arch = "x86_64"))] impl EncryptCtr32 for Key { - fn ctr32_encrypt_within(&self, in_out: &mut [u8], src: RangeFrom, ctr: &mut Counter) { - unsafe { ctr32_encrypt_blocks!(vpaes_ctr32_encrypt_blocks, in_out, src, &self.inner, ctr) } + fn ctr32_encrypt_within(&self, in_out: InOut<'_>, ctr: &mut Counter) { + unsafe { ctr32_encrypt_blocks!(vpaes_ctr32_encrypt_blocks, in_out, &self.inner, ctr) } } } #[cfg(target_arch = "arm")] impl EncryptCtr32 for Key { - fn ctr32_encrypt_within(&self, in_out: &mut [u8], src: RangeFrom, ctr: &mut Counter) { + fn ctr32_encrypt_within(&self, in_out: InOut<'_>, ctr: &mut Counter) { use super::{bs, BLOCK_LEN}; let in_out = { + let (in_out, src) = in_out.into_slice_src_mut(); let blocks = in_out[src.clone()].len() / BLOCK_LEN; // bsaes operates in batches of 8 blocks. @@ -84,20 +84,18 @@ impl EncryptCtr32 for Key { 0 }; let bsaes_in_out_len = bsaes_blocks * BLOCK_LEN; + let bs_in_out = + InOut::overlapping(&mut in_out[..(src.start + bsaes_in_out_len)], src.clone()) + .unwrap(); // SAFETY: // * self.inner was initialized with `vpaes_set_encrypt_key` above, // as required by `bsaes_ctr32_encrypt_blocks_with_vpaes_key`. unsafe { - bs::ctr32_encrypt_blocks_with_vpaes_key( - &mut in_out[..(src.start + bsaes_in_out_len)], - src.clone(), - &self.inner, - ctr, - ); + bs::ctr32_encrypt_blocks_with_vpaes_key(bs_in_out, &self.inner, ctr); } - &mut in_out[bsaes_in_out_len..] + InOut::overlapping(&mut in_out[bsaes_in_out_len..], src).unwrap() }; // SAFETY: @@ -105,7 +103,7 @@ impl EncryptCtr32 for Key { // as required by `vpaes_ctr32_encrypt_blocks`. // * `vpaes_ctr32_encrypt_blocks` satisfies the contract for // `ctr32_encrypt_blocks`. - unsafe { ctr32_encrypt_blocks!(vpaes_ctr32_encrypt_blocks, in_out, src, &self.inner, ctr) } + unsafe { ctr32_encrypt_blocks!(vpaes_ctr32_encrypt_blocks, in_out, &self.inner, ctr) } } } @@ -122,8 +120,8 @@ impl EncryptBlock for Key { #[cfg(target_arch = "x86")] impl EncryptCtr32 for Key { - fn ctr32_encrypt_within(&self, in_out: &mut [u8], src: RangeFrom, ctr: &mut Counter) { - super::super::shift::shift_full_blocks(in_out, src, |input| { + fn ctr32_encrypt_within(&self, in_out: InOut<'_>, ctr: &mut Counter) { + super::super::shift::shift_full_blocks(in_out, |input| { self.encrypt_iv_xor_block(ctr.increment(), *input) }); } diff --git a/src/aead/aes_gcm.rs b/src/aead/aes_gcm.rs index feb4df1e28..23cecf2427 100644 --- a/src/aead/aes_gcm.rs +++ b/src/aead/aes_gcm.rs @@ -14,7 +14,7 @@ use super::{ aes::{self, Counter, BLOCK_LEN, ZERO_BLOCK}, - gcm, shift, Aad, Nonce, Tag, + gcm, shift, Aad, InOut, Nonce, Tag, }; use crate::{ cpu, error, @@ -160,7 +160,7 @@ pub(super) fn seal( } }; let (whole, remainder) = slice::as_chunks_mut(ramaining); - aes_key.ctr32_encrypt_within(slice::flatten_mut(whole), 0.., &mut ctr); + aes_key.ctr32_encrypt_within(InOut::in_place(slice::flatten_mut(whole)), &mut ctr); auth.update_blocks(whole); seal_finish(aes_key, auth, remainder, ctr, tag_iv) } @@ -240,7 +240,7 @@ fn seal_strided partial, @@ -450,11 +447,11 @@ fn open_strided { + in_out: &'i mut [u8], + src: RangeFrom, +} + +impl<'i> InOut<'i> { + pub fn in_place(in_out: &'i mut [u8]) -> Self { + Self { in_out, src: 0.. } + } + + pub fn overlapping(in_out: &'i mut [u8], src: RangeFrom) -> Result { + match in_out.get(src.clone()) { + Some(_) => Ok(Self { in_out, src }), + None => Err(SrcIndexError::new(src)), + } + } + + #[cfg(any(target_arch = "arm", target_arch = "x86"))] + pub fn into_slice_src_mut(self) -> (&'i mut [u8], RangeFrom) { + (self.in_out, self.src) + } +} + +impl InOut<'_> { + pub fn len(&self) -> usize { + self.in_out[self.src.clone()].len() + } + pub fn input_output_len(&mut self) -> (*const u8, *mut u8, usize) { + let len = self.len(); + let output = self.in_out.as_mut_ptr(); + // TODO: MSRV(1.65): use `output.cast_const()` + let output_const: *const u8 = output; + // SAFETY: The constructor ensures that `src` is a valid range. + // Equivalent to `self.in_out[src.clone()].as_ptr()` but without + // worries about compatibility with the stacked borrows model. + let input = unsafe { output_const.add(self.src.start) }; + (input, output, len) + } +} + +#[derive(Debug)] +pub struct SrcIndexError(#[allow(dead_code)] RangeFrom); + +impl SrcIndexError { + #[cold] + fn new(src: RangeFrom) -> Self { + Self(src) + } +} + +impl From for error::Unspecified { + fn from(_: SrcIndexError) -> Self { + Self + } +} diff --git a/src/aead/shift.rs b/src/aead/shift.rs index fc2227378f..4b6648f646 100644 --- a/src/aead/shift.rs +++ b/src/aead/shift.rs @@ -16,10 +16,10 @@ use crate::polyfill::sliceutil::overwrite_at_start; #[cfg(target_arch = "x86")] pub fn shift_full_blocks( - in_out: &mut [u8], - src: core::ops::RangeFrom, + in_out: super::InOut<'_>, mut transform: impl FnMut(&[u8; BLOCK_LEN]) -> [u8; BLOCK_LEN], ) { + let (in_out, src) = in_out.into_slice_src_mut(); let in_out_len = in_out[src.clone()].len(); for i in (0..in_out_len).step_by(BLOCK_LEN) {