Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

aead: Support stacked borrows model using a new InOut type. #2164

Merged
merged 1 commit into from
Dec 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/aead.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2015-2021 Brian Smith.
// Copyright 2015-2024 Brian Smith.
//
// Permission to use, copy, modify, and/or distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
Expand Down Expand Up @@ -34,6 +34,7 @@ pub use self::{
sealing_key::SealingKey,
unbound_key::UnboundKey,
};
use inout::InOut;

/// A sequences of unique nonces.
///
Expand Down Expand Up @@ -175,6 +176,7 @@ mod chacha;
mod chacha20_poly1305;
pub mod chacha20_poly1305_openssh;
mod gcm;
mod inout;
mod less_safe_key;
mod nonce;
mod opening_key;
Expand Down
7 changes: 3 additions & 4 deletions src/aead/aes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

use super::{nonce::Nonce, quic::Sample, NONCE_LEN};
use super::{nonce::Nonce, quic::Sample, InOut, NONCE_LEN};
use crate::{
constant_time,
cpu::{self, GetFeature as _},
error,
};
use cfg_if::cfg_if;
use core::ops::RangeFrom;

pub(super) use ffi::Counter;

Expand Down Expand Up @@ -158,7 +157,7 @@ pub(super) trait EncryptBlock {
}

pub(super) trait EncryptCtr32 {
fn ctr32_encrypt_within(&self, in_out: &mut [u8], src: RangeFrom<usize>, ctr: &mut Counter);
fn ctr32_encrypt_within(&self, in_out: InOut<'_>, ctr: &mut Counter);
}

#[allow(dead_code)]
Expand All @@ -178,7 +177,7 @@ fn encrypt_iv_xor_block_using_encrypt_block(
#[allow(dead_code)]
fn encrypt_iv_xor_block_using_ctr32(key: &impl EncryptCtr32, iv: Iv, mut block: Block) -> Block {
let mut ctr = Counter(iv.0); // This is OK because we're only encrypting one block.
key.ctr32_encrypt_within(&mut block, 0.., &mut ctr);
key.ctr32_encrypt_within(InOut::in_place(&mut block), &mut ctr);
block
}

Expand Down
8 changes: 3 additions & 5 deletions src/aead/aes/bs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@

#![cfg(target_arch = "arm")]

use super::{Counter, AES_KEY};
use core::ops::RangeFrom;
use super::{Counter, InOut, AES_KEY};

/// SAFETY:
/// * The caller must ensure that if blocks > 0 then either `input` and
Expand All @@ -28,8 +27,7 @@ use core::ops::RangeFrom;
/// * Upon returning, `blocks` blocks will have been read from `input` and
/// written to `output`.
pub(super) unsafe fn ctr32_encrypt_blocks_with_vpaes_key(
in_out: &mut [u8],
src: RangeFrom<usize>,
in_out: InOut<'_>,
vpaes_key: &AES_KEY,
ctr: &mut Counter,
) {
Expand Down Expand Up @@ -57,6 +55,6 @@ pub(super) unsafe fn ctr32_encrypt_blocks_with_vpaes_key(
// * `bsaes_ctr32_encrypt_blocks` satisfies the contract for
// `ctr32_encrypt_blocks`.
unsafe {
ctr32_encrypt_blocks!(bsaes_ctr32_encrypt_blocks, in_out, src, &bsaes_key, ctr);
ctr32_encrypt_blocks!(bsaes_ctr32_encrypt_blocks, in_out, &bsaes_key, ctr);
}
}
9 changes: 3 additions & 6 deletions src/aead/aes/fallback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

use super::{Block, Counter, EncryptBlock, EncryptCtr32, Iv, KeyBytes, AES_KEY};
use super::{Block, Counter, EncryptBlock, EncryptCtr32, InOut, Iv, KeyBytes, AES_KEY};
use crate::error;
use core::ops::RangeFrom;

#[derive(Clone)]
pub struct Key {
Expand All @@ -39,9 +38,7 @@ impl EncryptBlock for Key {
}

impl EncryptCtr32 for Key {
fn ctr32_encrypt_within(&self, in_out: &mut [u8], src: RangeFrom<usize>, ctr: &mut Counter) {
unsafe {
ctr32_encrypt_blocks!(aes_nohw_ctr32_encrypt_blocks, in_out, src, &self.inner, ctr)
}
fn ctr32_encrypt_within(&self, in_out: InOut<'_>, ctr: &mut Counter) {
unsafe { ctr32_encrypt_blocks!(aes_nohw_ctr32_encrypt_blocks, in_out, &self.inner, ctr) }
}
}
24 changes: 11 additions & 13 deletions src/aead/aes/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

use super::{Block, KeyBytes, BLOCK_LEN};
use crate::{bits::BitLength, c, error, polyfill::slice};
use core::{num::NonZeroUsize, ops::RangeFrom};
use super::{Block, InOut, KeyBytes, BLOCK_LEN};
use crate::{bits::BitLength, c, error};
use core::num::NonZeroUsize;

/// nonce || big-endian counter.
#[repr(transparent)]
Expand Down Expand Up @@ -127,7 +127,7 @@ impl AES_KEY {
/// * The caller must ensure that fhe function `$name` satisfies the conditions
/// for the `f` parameter to `ctr32_encrypt_blocks`.
macro_rules! ctr32_encrypt_blocks {
($name:ident, $in_out:expr, $src:expr, $key:expr, $ctr:expr $(,)? ) => {{
($name:ident, $in_out:expr, $key:expr, $ctr:expr $(,)? ) => {{
use crate::{
aead::aes::{ffi::AES_KEY, Counter, BLOCK_LEN},
c,
Expand All @@ -141,7 +141,7 @@ macro_rules! ctr32_encrypt_blocks {
ivec: &Counter,
);
}
$key.ctr32_encrypt_blocks($name, $in_out, $src, $ctr)
$key.ctr32_encrypt_blocks($name, $in_out, $ctr)
}};
}

Expand All @@ -167,25 +167,23 @@ impl AES_KEY {
key: &AES_KEY,
ivec: &Counter,
),
in_out: &mut [u8],
src: RangeFrom<usize>,
mut in_out: InOut<'_>,
ctr: &mut Counter,
) {
let (input, leftover) = slice::as_chunks(&in_out[src]);
debug_assert_eq!(leftover.len(), 0);
let (input, output, len) = in_out.input_output_len();
debug_assert_eq!(len % BLOCK_LEN, 0);

let blocks = match NonZeroUsize::new(input.len()) {
let blocks = match NonZeroUsize::new(len / BLOCK_LEN) {
Some(blocks) => blocks,
None => {
return;
}
};

let input: *const [u8; BLOCK_LEN] = input.cast();
let output: *mut [u8; BLOCK_LEN] = output.cast();
let blocks_u32: u32 = blocks.get().try_into().unwrap();

let input = input.as_ptr();
let output: *mut [u8; BLOCK_LEN] = in_out.as_mut_ptr().cast();

// SAFETY:
// * `input` points to `blocks` blocks.
// * `output` points to space for `blocks` blocks to be written.
Expand Down
7 changes: 3 additions & 4 deletions src/aead/aes/hw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@

#![cfg(any(target_arch = "aarch64", target_arch = "x86", target_arch = "x86_64"))]

use super::{Block, Counter, EncryptBlock, EncryptCtr32, Iv, KeyBytes, AES_KEY};
use super::{Block, Counter, EncryptBlock, EncryptCtr32, InOut, Iv, KeyBytes, AES_KEY};
use crate::{cpu, error};
use core::ops::RangeFrom;

#[cfg(target_arch = "aarch64")]
pub(in super::super) type RequiredCpuFeatures = cpu::arm::Aes;
Expand Down Expand Up @@ -56,9 +55,9 @@ impl EncryptBlock for Key {
}

impl EncryptCtr32 for Key {
fn ctr32_encrypt_within(&self, in_out: &mut [u8], src: RangeFrom<usize>, ctr: &mut Counter) {
fn ctr32_encrypt_within(&self, in_out: InOut<'_>, ctr: &mut Counter) {
#[cfg(target_arch = "x86_64")]
let _: cpu::Features = cpu::features();
unsafe { ctr32_encrypt_blocks!(aes_hw_ctr32_encrypt_blocks, in_out, src, &self.inner, ctr) }
unsafe { ctr32_encrypt_blocks!(aes_hw_ctr32_encrypt_blocks, in_out, &self.inner, ctr) }
}
}
28 changes: 13 additions & 15 deletions src/aead/aes/vp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@
target_arch = "x86_64"
))]

use super::{Block, Counter, EncryptBlock, EncryptCtr32, Iv, KeyBytes, AES_KEY};
use super::{Block, Counter, EncryptBlock, EncryptCtr32, InOut, Iv, KeyBytes, AES_KEY};
use crate::{cpu, error};
use core::ops::RangeFrom;

#[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
type RequiredCpuFeatures = cpu::arm::Neon;
Expand Down Expand Up @@ -57,17 +56,18 @@ impl EncryptBlock for Key {

#[cfg(any(target_arch = "aarch64", target_arch = "x86_64"))]
impl EncryptCtr32 for Key {
fn ctr32_encrypt_within(&self, in_out: &mut [u8], src: RangeFrom<usize>, ctr: &mut Counter) {
unsafe { ctr32_encrypt_blocks!(vpaes_ctr32_encrypt_blocks, in_out, src, &self.inner, ctr) }
fn ctr32_encrypt_within(&self, in_out: InOut<'_>, ctr: &mut Counter) {
unsafe { ctr32_encrypt_blocks!(vpaes_ctr32_encrypt_blocks, in_out, &self.inner, ctr) }
}
}

#[cfg(target_arch = "arm")]
impl EncryptCtr32 for Key {
fn ctr32_encrypt_within(&self, in_out: &mut [u8], src: RangeFrom<usize>, ctr: &mut Counter) {
fn ctr32_encrypt_within(&self, in_out: InOut<'_>, ctr: &mut Counter) {
use super::{bs, BLOCK_LEN};

let in_out = {
let (in_out, src) = in_out.into_slice_src_mut();
let blocks = in_out[src.clone()].len() / BLOCK_LEN;

// bsaes operates in batches of 8 blocks.
Expand All @@ -84,28 +84,26 @@ impl EncryptCtr32 for Key {
0
};
let bsaes_in_out_len = bsaes_blocks * BLOCK_LEN;
let bs_in_out =
InOut::overlapping(&mut in_out[..(src.start + bsaes_in_out_len)], src.clone())
.unwrap();

// SAFETY:
// * self.inner was initialized with `vpaes_set_encrypt_key` above,
// as required by `bsaes_ctr32_encrypt_blocks_with_vpaes_key`.
unsafe {
bs::ctr32_encrypt_blocks_with_vpaes_key(
&mut in_out[..(src.start + bsaes_in_out_len)],
src.clone(),
&self.inner,
ctr,
);
bs::ctr32_encrypt_blocks_with_vpaes_key(bs_in_out, &self.inner, ctr);
}

&mut in_out[bsaes_in_out_len..]
InOut::overlapping(&mut in_out[bsaes_in_out_len..], src).unwrap()
};

// SAFETY:
// * self.inner was initialized with `vpaes_set_encrypt_key` above,
// as required by `vpaes_ctr32_encrypt_blocks`.
// * `vpaes_ctr32_encrypt_blocks` satisfies the contract for
// `ctr32_encrypt_blocks`.
unsafe { ctr32_encrypt_blocks!(vpaes_ctr32_encrypt_blocks, in_out, src, &self.inner, ctr) }
unsafe { ctr32_encrypt_blocks!(vpaes_ctr32_encrypt_blocks, in_out, &self.inner, ctr) }
}
}

Expand All @@ -122,8 +120,8 @@ impl EncryptBlock for Key {

#[cfg(target_arch = "x86")]
impl EncryptCtr32 for Key {
fn ctr32_encrypt_within(&self, in_out: &mut [u8], src: RangeFrom<usize>, ctr: &mut Counter) {
super::super::shift::shift_full_blocks(in_out, src, |input| {
fn ctr32_encrypt_within(&self, in_out: InOut<'_>, ctr: &mut Counter) {
super::super::shift::shift_full_blocks(in_out, |input| {
self.encrypt_iv_xor_block(ctr.increment(), *input)
});
}
Expand Down
19 changes: 8 additions & 11 deletions src/aead/aes_gcm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

use super::{
aes::{self, Counter, BLOCK_LEN, ZERO_BLOCK},
gcm, shift, Aad, Nonce, Tag,
gcm, shift, Aad, InOut, Nonce, Tag,
};
use crate::{
cpu, error,
Expand Down Expand Up @@ -160,7 +160,7 @@ pub(super) fn seal(
}
};
let (whole, remainder) = slice::as_chunks_mut(ramaining);
aes_key.ctr32_encrypt_within(slice::flatten_mut(whole), 0.., &mut ctr);
aes_key.ctr32_encrypt_within(InOut::in_place(slice::flatten_mut(whole)), &mut ctr);
auth.update_blocks(whole);
seal_finish(aes_key, auth, remainder, ctr, tag_iv)
}
Expand Down Expand Up @@ -240,7 +240,7 @@ fn seal_strided<A: aes::EncryptBlock + aes::EncryptCtr32, G: gcm::UpdateBlocks +
let (whole, remainder) = slice::as_chunks_mut(in_out);

for chunk in whole.chunks_mut(CHUNK_BLOCKS) {
aes_key.ctr32_encrypt_within(slice::flatten_mut(chunk), 0.., &mut ctr);
aes_key.ctr32_encrypt_within(InOut::in_place(slice::flatten_mut(chunk)), &mut ctr);
auth.update_blocks(chunk);
}

Expand Down Expand Up @@ -331,11 +331,8 @@ pub(super) fn open(
let whole_len = slice::flatten(whole).len();

// Decrypt any remaining whole blocks.
aes_key.ctr32_encrypt_within(
&mut in_out[..(src.start + whole_len)],
src.clone(),
&mut ctr,
);
let whole = InOut::overlapping(&mut in_out[..(src.start + whole_len)], src.clone())?;
aes_key.ctr32_encrypt_within(whole, &mut ctr);

let in_out = match in_out.get_mut(whole_len..) {
Some(partial) => partial,
Expand Down Expand Up @@ -450,11 +447,11 @@ fn open_strided<A: aes::EncryptBlock + aes::EncryptCtr32, G: gcm::UpdateBlocks +
}
auth.update_blocks(ciphertext);

aes_key.ctr32_encrypt_within(
let chunk = InOut::overlapping(
&mut in_out[output..][..(chunk_len + in_prefix_len)],
in_prefix_len..,
&mut ctr,
);
)?;
aes_key.ctr32_encrypt_within(chunk, &mut ctr);
output += chunk_len;
input += chunk_len;
}
Expand Down
Loading