diff --git a/esp-hal/CHANGELOG.md b/esp-hal/CHANGELOG.md index 8cc18aba81b..19ca541783f 100644 --- a/esp-hal/CHANGELOG.md +++ b/esp-hal/CHANGELOG.md @@ -31,6 +31,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Remove `fn free(self)` in HMAC which goes against esp-hal API guidelines (#1972) - PARL_IO use ReadBuffer and WriteBuffer for Async DMA (#1996) - `AnyPin`, `AnyInputOnyPin` and `DummyPin` are now accessible from `gpio` module (#1918) +- Changed the RSA modular multiplication API to be consistent across devices (#2002) ### Fixed @@ -45,6 +46,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Reset peripherals in driver constructors where missing (#1893, #1961) - Fixed ESP32-S2 systimer interrupts (#1979) - Software interrupt 3 is no longer available when it is required by `esp-hal-embassy`. (#2011) +- ESP32: Fixed async RSA (#2002) ### Removed diff --git a/esp-hal/src/rsa/esp32.rs b/esp-hal/src/rsa/esp32.rs index bb9a3796dc8..c5b9938d824 100644 --- a/esp-hal/src/rsa/esp32.rs +++ b/esp-hal/src/rsa/esp32.rs @@ -1,8 +1,4 @@ -use core::{ - convert::Infallible, - marker::PhantomData, - ptr::{copy_nonoverlapping, write_bytes}, -}; +use core::convert::Infallible; use crate::rsa::{ implement_op, @@ -37,35 +33,30 @@ impl<'d, DM: crate::Mode> Rsa<'d, DM> { } /// Starts the modular exponentiation operation. - pub(super) fn write_modexp_start(&mut self) { + pub(super) fn write_modexp_start(&self) { self.rsa .modexp_start() .write(|w| w.modexp_start().set_bit()); } /// Starts the multiplication operation. - pub(super) fn write_multi_start(&mut self) { + pub(super) fn write_multi_start(&self) { self.rsa.mult_start().write(|w| w.mult_start().set_bit()); } + /// Starts the modular multiplication operation. + pub(super) fn write_modmulti_start(&self) { + self.write_multi_start(); + } + /// Clears the RSA interrupt flag. pub(super) fn clear_interrupt(&mut self) { self.rsa.interrupt().write(|w| w.interrupt().set_bit()); } /// Checks if the RSA peripheral is idle. - pub(super) fn is_idle(&mut self) -> bool { - self.rsa.interrupt().read().bits() == 1 - } - - unsafe fn write_multi_operand_a(&mut self, operand_a: &[u32; N]) { - copy_nonoverlapping(operand_a.as_ptr(), self.rsa.x_mem(0).as_ptr(), N); - write_bytes(self.rsa.x_mem(0).as_ptr().add(N), 0, N); - } - - unsafe fn write_multi_operand_b(&mut self, operand_b: &[u32; N]) { - write_bytes(self.rsa.z_mem(0).as_ptr(), 0, N); - copy_nonoverlapping(operand_b.as_ptr(), self.rsa.z_mem(0).as_ptr().add(N), N); + pub(super) fn is_idle(&self) -> bool { + self.rsa.interrupt().read().interrupt().bit_is_set() } } @@ -92,59 +83,18 @@ impl<'a, 'd, T: RsaMode, DM: crate::Mode, const N: usize> RsaModularMultiplicati where T: RsaMode, { - /// Creates an instance of `RsaMultiplication`. - /// - /// `m_prime` can be calculated using `-(modular multiplicative inverse of - /// modulus) mod 2^32`. - /// - /// For more information refer to 24.3.2 of . - pub fn new(rsa: &'a mut Rsa<'d, DM>, modulus: &T::InputType, m_prime: u32) -> Self { - Self::set_mode(rsa); - unsafe { - rsa.write_modulus(modulus); - } - rsa.write_mprime(m_prime); - - Self { - rsa, - phantom: PhantomData, - } - } - - fn set_mode(rsa: &mut Rsa<'d, DM>) { + pub(super) fn write_mode(rsa: &mut Rsa<'d, DM>) { rsa.write_multi_mode((N / 16 - 1) as u32) } - /// Starts the first step of modular multiplication operation. - /// - /// `r` can be calculated using `2 ^ ( bitlength * 2 ) mod modulus`. + /// Starts the modular multiplication operation. /// /// For more information refer to 24.3.2 of . - pub fn start_step1(&mut self, operand_a: &T::InputType, r: &T::InputType) { - unsafe { - self.rsa.write_operand_a(operand_a); - self.rsa.write_r(r); - } - self.start(); - } - - /// Starts the second step of modular multiplication operation. - /// - /// This is a non blocking function that returns without an error if - /// operation is completed successfully. `start_step1` must be called - /// before calling this function. - pub fn start_step2(&mut self, operand_b: &T::InputType) { - while !self.rsa.is_idle() {} - - self.rsa.clear_interrupt(); - unsafe { - self.rsa.write_operand_a(operand_b); - } - self.start(); - } - - fn start(&mut self) { + pub(super) fn set_up_modular_multiplication(&mut self, operand_b: &T::InputType) { self.rsa.write_multi_start(); + self.rsa.wait_for_idle(); + + self.rsa.write_operand_a(operand_b); } } @@ -152,70 +102,22 @@ impl<'a, 'd, T: RsaMode, DM: crate::Mode, const N: usize> RsaModularExponentiati where T: RsaMode, { - /// Creates an instance of `RsaModularExponentiation`. - /// - /// `m_prime` can be calculated using `-(modular multiplicative inverse of - /// modulus) mod 2^32`. - /// - /// For more information refer to 24.3.2 of . - pub fn new( - rsa: &'a mut Rsa<'d, DM>, - exponent: &T::InputType, - modulus: &T::InputType, - m_prime: u32, - ) -> Self { - Self::set_mode(rsa); - unsafe { - rsa.write_operand_b(exponent); - rsa.write_modulus(modulus); - } - rsa.write_mprime(m_prime); - Self { - rsa, - phantom: PhantomData, - } - } - /// Sets the modular exponentiation mode for the RSA hardware. - pub(super) fn set_mode(rsa: &mut Rsa<'d, DM>) { + pub(super) fn write_mode(rsa: &mut Rsa<'d, DM>) { rsa.write_modexp_mode((N / 16 - 1) as u32) } - - /// Starts the modular exponentiation operation on the RSA hardware. - pub(super) fn start(&mut self) { - self.rsa.write_modexp_start(); - } } impl<'a, 'd, T: RsaMode + Multi, DM: crate::Mode, const N: usize> RsaMultiplication<'a, 'd, T, DM> where T: RsaMode, { - /// Creates an instance of `RsaMultiplication`. - pub fn new(rsa: &'a mut Rsa<'d, DM>) -> Self { - Self::set_mode(rsa); - Self { - rsa, - phantom: PhantomData, - } - } - - /// Starts the multiplication operation. - pub fn start_multiplication(&mut self, operand_a: &T::InputType, operand_b: &T::InputType) { - unsafe { - self.rsa.write_multi_operand_a(operand_a); - self.rsa.write_multi_operand_b(operand_b); - } - self.start(); - } - /// Sets the multiplication mode for the RSA hardware. - pub(super) fn set_mode(rsa: &mut Rsa<'d, DM>) { + pub(super) fn write_mode(rsa: &mut Rsa<'d, DM>) { rsa.write_multi_mode(((N * 2) / 16 + 7) as u32) } - /// Starts the multiplication operation on the RSA hardware. - pub(super) fn start(&mut self) { - self.rsa.write_multi_start(); + pub(super) fn set_up_multiplication(&mut self, operand_b: &T::InputType) { + self.rsa.write_multi_operand_b(operand_b); } } diff --git a/esp-hal/src/rsa/esp32cX.rs b/esp-hal/src/rsa/esp32cX.rs index c09fa73902e..ea6d79d44c2 100644 --- a/esp-hal/src/rsa/esp32cX.rs +++ b/esp-hal/src/rsa/esp32cX.rs @@ -1,4 +1,4 @@ -use core::{convert::Infallible, marker::PhantomData, ptr::copy_nonoverlapping}; +use core::convert::Infallible; use crate::rsa::{ implement_op, @@ -94,21 +94,21 @@ impl<'d, DM: crate::Mode> Rsa<'d, DM> { } /// Starts the modular exponentiation operation. - pub(super) fn write_modexp_start(&mut self) { + pub(super) fn write_modexp_start(&self) { self.rsa .set_start_modexp() .write(|w| w.set_start_modexp().set_bit()); } /// Starts the multiplication operation. - pub(super) fn write_multi_start(&mut self) { + pub(super) fn write_multi_start(&self) { self.rsa .set_start_mult() .write(|w| w.set_start_mult().set_bit()); } /// Starts the modular multiplication operation. - fn write_modmulti_start(&mut self) { + pub(super) fn write_modmulti_start(&self) { self.rsa .set_start_modmult() .write(|w| w.set_start_modmult().set_bit()); @@ -120,13 +120,9 @@ impl<'d, DM: crate::Mode> Rsa<'d, DM> { } /// Checks if the RSA peripheral is idle. - pub(super) fn is_idle(&mut self) -> bool { + pub(super) fn is_idle(&self) -> bool { self.rsa.query_idle().read().query_idle().bit_is_set() } - - unsafe fn write_multi_operand_b(&mut self, operand_b: &[u32; N]) { - copy_nonoverlapping(operand_b.as_ptr(), self.rsa.z_mem(0).as_ptr().add(N), N); - } } /// Module defining marker types for various RSA operand sizes. @@ -240,34 +236,7 @@ impl<'a, 'd, T: RsaMode, DM: crate::Mode, const N: usize> RsaModularExponentiati where T: RsaMode, { - /// Creates an instance of `RsaModularExponentiation`. - /// - /// `m_prime` could be calculated using `-(modular multiplicative inverse of - /// modulus) mod 2^32`. - /// - /// For more information refer to 19.3.1 of . - pub fn new( - rsa: &'a mut Rsa<'d, DM>, - exponent: &T::InputType, - modulus: &T::InputType, - m_prime: u32, - ) -> Self { - Self::set_mode(rsa); - unsafe { - rsa.write_operand_b(exponent); - rsa.write_modulus(modulus); - } - rsa.write_mprime(m_prime); - if rsa.is_search_enabled() { - rsa.write_search_position(Self::find_search_pos(exponent)); - } - Self { - rsa, - phantom: PhantomData, - } - } - - fn find_search_pos(exponent: &T::InputType) -> u32 { + pub(super) fn find_search_pos(exponent: &T::InputType) -> u32 { for (i, byte) in exponent.iter().rev().enumerate() { if *byte == 0 { continue; @@ -278,64 +247,21 @@ where } /// Sets the modular exponentiation mode for the RSA hardware. - pub(super) fn set_mode(rsa: &mut Rsa<'d, DM>) { + pub(super) fn write_mode(rsa: &mut Rsa<'d, DM>) { rsa.write_mode((N - 1) as u32) } - - /// Starts the modular exponentiation operation on the RSA hardware. - pub(super) fn start(&mut self) { - self.rsa.write_modexp_start(); - } } impl<'a, 'd, T: RsaMode, DM: crate::Mode, const N: usize> RsaModularMultiplication<'a, 'd, T, DM> where T: RsaMode, { - fn write_mode(rsa: &mut Rsa<'d, DM>) { + pub(super) fn write_mode(rsa: &mut Rsa<'d, DM>) { rsa.write_mode((N - 1) as u32) } - /// Creates an instance of `RsaModularMultiplication`. - /// - /// `m_prime` can be calculated using `-(modular multiplicative inverse of - /// modulus) mod 2^32`. - /// - /// For more information refer to 19.3.1 of . - pub fn new( - rsa: &'a mut Rsa<'d, DM>, - operand_a: &T::InputType, - operand_b: &T::InputType, - modulus: &T::InputType, - m_prime: u32, - ) -> Self { - Self::write_mode(rsa); - rsa.write_mprime(m_prime); - unsafe { - rsa.write_modulus(modulus); - rsa.write_operand_a(operand_a); - rsa.write_operand_b(operand_b); - } - Self { - rsa, - phantom: PhantomData, - } - } - - /// Starts the modular multiplication operation. - /// - /// `r` could be calculated using `2 ^ ( bitlength * 2 ) mod modulus`. - /// - /// For more information refer to 19.3.1 of . - pub fn start_modular_multiplication(&mut self, r: &T::InputType) { - unsafe { - self.rsa.write_r(r); - } - self.start(); - } - - fn start(&mut self) { - self.rsa.write_modmulti_start(); + pub(super) fn set_up_modular_multiplication(&mut self, operand_b: &T::InputType) { + self.rsa.write_operand_b(operand_b); } } @@ -343,33 +269,12 @@ impl<'a, 'd, T: RsaMode + Multi, DM: crate::Mode, const N: usize> RsaMultiplicat where T: RsaMode, { - /// Creates an instance of `RsaMultiplication`. - pub fn new(rsa: &'a mut Rsa<'d, DM>, operand_a: &T::InputType) -> Self { - Self::set_mode(rsa); - unsafe { - rsa.write_operand_a(operand_a); - } - Self { - rsa, - phantom: PhantomData, - } - } - - /// Starts the multiplication operation. - pub fn start_multiplication(&mut self, operand_b: &T::InputType) { - unsafe { - self.rsa.write_multi_operand_b(operand_b); - } - self.start(); + pub(super) fn set_up_multiplication(&mut self, operand_b: &T::InputType) { + self.rsa.write_multi_operand_b(operand_b); } /// Sets the multiplication mode for the RSA hardware. - pub(super) fn set_mode(rsa: &mut Rsa<'d, DM>) { + pub(super) fn write_mode(rsa: &mut Rsa<'d, DM>) { rsa.write_mode((N * 2 - 1) as u32) } - - /// Starts the multiplication operation on the RSA hardware. - pub(super) fn start(&mut self) { - self.rsa.write_multi_start(); - } } diff --git a/esp-hal/src/rsa/esp32sX.rs b/esp-hal/src/rsa/esp32sX.rs index aefab89a668..956ec3fe4e8 100644 --- a/esp-hal/src/rsa/esp32sX.rs +++ b/esp-hal/src/rsa/esp32sX.rs @@ -1,4 +1,4 @@ -use core::{convert::Infallible, marker::PhantomData, ptr::copy_nonoverlapping}; +use core::convert::Infallible; use crate::rsa::{ implement_op, @@ -101,19 +101,19 @@ impl<'d, DM: crate::Mode> Rsa<'d, DM> { } /// Starts the modular exponentiation operation. - pub(super) fn write_modexp_start(&mut self) { + pub(super) fn write_modexp_start(&self) { self.rsa .modexp_start() .write(|w| w.modexp_start().set_bit()); } /// Starts the multiplication operation. - pub(super) fn write_multi_start(&mut self) { + pub(super) fn write_multi_start(&self) { self.rsa.mult_start().write(|w| w.mult_start().set_bit()); } /// Starts the modular multiplication operation. - fn write_modmulti_start(&mut self) { + pub(super) fn write_modmulti_start(&self) { self.rsa .modmult_start() .write(|w| w.modmult_start().set_bit()); @@ -127,13 +127,9 @@ impl<'d, DM: crate::Mode> Rsa<'d, DM> { } /// Checks if the RSA peripheral is idle. - pub(super) fn is_idle(&mut self) -> bool { + pub(super) fn is_idle(&self) -> bool { self.rsa.idle().read().idle().bit_is_set() } - - unsafe fn write_multi_operand_b(&mut self, operand_b: &[u32; N]) { - copy_nonoverlapping(operand_b.as_ptr(), self.rsa.z_mem(0).as_ptr().add(N), N); - } } pub mod operand_sizes { @@ -281,34 +277,7 @@ impl<'a, 'd, T: RsaMode, DM: crate::Mode, const N: usize> RsaModularExponentiati where T: RsaMode, { - /// Creates an instance of `RsaModularExponentiation`. - /// - /// `m_prime` can be calculated using `-(modular multiplicative inverse of - /// modulus) mod 2^32`. - /// - /// For more information refer to 20.3.1 of . - pub fn new( - rsa: &'a mut Rsa<'d, DM>, - exponent: &T::InputType, - modulus: &T::InputType, - m_prime: u32, - ) -> Self { - Self::set_mode(rsa); - unsafe { - rsa.write_operand_b(exponent); - rsa.write_modulus(modulus); - } - rsa.write_mprime(m_prime); - if rsa.is_search_enabled() { - rsa.write_search_position(Self::find_search_pos(exponent)); - } - Self { - rsa, - phantom: PhantomData, - } - } - - fn find_search_pos(exponent: &T::InputType) -> u32 { + pub(super) fn find_search_pos(exponent: &T::InputType) -> u32 { for (i, byte) in exponent.iter().rev().enumerate() { if *byte == 0 { continue; @@ -319,64 +288,21 @@ where } /// Sets the modular exponentiation mode for the RSA hardware. - pub(super) fn set_mode(rsa: &mut Rsa<'d, DM>) { + pub(super) fn write_mode(rsa: &mut Rsa<'d, DM>) { rsa.write_mode((N - 1) as u32) } - - /// Starts the modular exponentiation operation on the RSA hardware. - pub(super) fn start(&mut self) { - self.rsa.write_modexp_start(); - } } impl<'a, 'd, T: RsaMode, DM: crate::Mode, const N: usize> RsaModularMultiplication<'a, 'd, T, DM> where T: RsaMode, { - /// Creates an instance of `RsaModularMultiplication`. - /// - /// `m_prime` could be calculated using `-(modular multiplicative inverse of - /// modulus) mod 2^32`. - /// - /// For more information refer to 20.3.1 of . - pub fn new( - rsa: &'a mut Rsa<'d, DM>, - operand_a: &T::InputType, - operand_b: &T::InputType, - modulus: &T::InputType, - m_prime: u32, - ) -> Self { - Self::write_mode(rsa); - rsa.write_mprime(m_prime); - unsafe { - rsa.write_modulus(modulus); - rsa.write_operand_a(operand_a); - rsa.write_operand_b(operand_b); - } - Self { - rsa, - phantom: PhantomData, - } - } - - fn write_mode(rsa: &mut Rsa<'d, DM>) { + pub(super) fn write_mode(rsa: &mut Rsa<'d, DM>) { rsa.write_mode((N - 1) as u32) } - /// Starts the modular multiplication operation. - /// - /// `r` could be calculated using `2 ^ ( bitlength * 2 ) mod modulus`. - /// - /// For more information refer to 19.3.1 of . - pub fn start_modular_multiplication(&mut self, r: &T::InputType) { - unsafe { - self.rsa.write_r(r); - } - self.start(); - } - - fn start(&mut self) { - self.rsa.write_modmulti_start(); + pub(super) fn set_up_modular_multiplication(&mut self, operand_b: &T::InputType) { + self.rsa.write_operand_b(operand_b); } } @@ -384,33 +310,12 @@ impl<'a, 'd, T: RsaMode + Multi, DM: crate::Mode, const N: usize> RsaMultiplicat where T: RsaMode, { - /// Creates an instance of `RsaMultiplication`. - pub fn new(rsa: &'a mut Rsa<'d, DM>, operand_a: &T::InputType) -> Self { - Self::set_mode(rsa); - unsafe { - rsa.write_operand_a(operand_a); - } - Self { - rsa, - phantom: PhantomData, - } - } - - /// Starts the multiplication operation. - pub fn start_multiplication(&mut self, operand_b: &T::InputType) { - unsafe { - self.rsa.write_multi_operand_b(operand_b); - } - self.start(); - } - /// Sets the multiplication mode for the RSA hardware. - pub(super) fn set_mode(rsa: &mut Rsa<'d, DM>) { + pub(super) fn write_mode(rsa: &mut Rsa<'d, DM>) { rsa.write_mode((N * 2 - 1) as u32) } - /// Starts the multiplication operation on the RSA hardware. - pub(super) fn start(&mut self) { - self.rsa.write_multi_start(); + pub(super) fn set_up_multiplication(&mut self, operand_b: &T::InputType) { + self.rsa.write_multi_operand_b(operand_b); } } diff --git a/esp-hal/src/rsa/mod.rs b/esp-hal/src/rsa/mod.rs index e8c54c550b2..8bf11d0eafa 100644 --- a/esp-hal/src/rsa/mod.rs +++ b/esp-hal/src/rsa/mod.rs @@ -16,16 +16,10 @@ //! ## Examples //! //! ### Modular Exponentiation, Modular Multiplication, and Multiplication -//! Visit the [RSA test] for an example of using the peripheral. -//! -//! ## Implementation State -//! -//! - The [nb] crate is used to handle non-blocking operations. -//! - This peripheral supports `async` on every available chip except of `esp32` -//! (to be solved). +//! Visit the [RSA test suite] for an example of using the peripheral. //! //! [nb]: https://docs.rs/nb/1.1.0/nb/ -//! [RSA test]: https://github.com/esp-rs/esp-hal/blob/main/hil-test/tests/rsa.rs +//! [RSA test suite]: https://github.com/esp-rs/esp-hal/blob/main/hil-test/tests/rsa.rs use core::{marker::PhantomData, ptr::copy_nonoverlapping}; @@ -53,24 +47,6 @@ pub struct Rsa<'d, DM: crate::Mode> { phantom: PhantomData, } -impl<'d, DM: crate::Mode> Rsa<'d, DM> { - fn internal_set_interrupt_handler(&mut self, handler: InterruptHandler) { - unsafe { - crate::interrupt::bind_interrupt(crate::peripherals::Interrupt::RSA, handler.handler()); - crate::interrupt::enable(crate::peripherals::Interrupt::RSA, handler.priority()) - .unwrap(); - } - } - - fn read_results(&mut self, outbuf: &mut [u32; N]) { - while !self.is_idle() {} - unsafe { - self.read_out(outbuf); - } - self.clear_interrupt(); - } -} - impl<'d> Rsa<'d, crate::Blocking> { /// Create a new instance in [crate::Blocking] mode. /// @@ -111,32 +87,66 @@ impl<'d, DM: crate::Mode> Rsa<'d, DM> { } } - unsafe fn write_operand_b(&mut self, operand_b: &[u32; N]) { - copy_nonoverlapping(operand_b.as_ptr(), self.rsa.y_mem(0).as_ptr(), N); + fn write_operand_b(&mut self, operand_b: &[u32; N]) { + unsafe { + copy_nonoverlapping(operand_b.as_ptr(), self.rsa.y_mem(0).as_ptr(), N); + } } - unsafe fn write_modulus(&mut self, modulus: &[u32; N]) { - copy_nonoverlapping(modulus.as_ptr(), self.rsa.m_mem(0).as_ptr(), N); + fn write_modulus(&mut self, modulus: &[u32; N]) { + unsafe { + copy_nonoverlapping(modulus.as_ptr(), self.rsa.m_mem(0).as_ptr(), N); + } } fn write_mprime(&mut self, m_prime: u32) { self.rsa.m_prime().write(|w| unsafe { w.bits(m_prime) }); } - unsafe fn write_operand_a(&mut self, operand_a: &[u32; N]) { - copy_nonoverlapping(operand_a.as_ptr(), self.rsa.x_mem(0).as_ptr(), N); + fn write_operand_a(&mut self, operand_a: &[u32; N]) { + unsafe { + copy_nonoverlapping(operand_a.as_ptr(), self.rsa.x_mem(0).as_ptr(), N); + } } - unsafe fn write_r(&mut self, r: &[u32; N]) { - copy_nonoverlapping(r.as_ptr(), self.rsa.z_mem(0).as_ptr(), N); + fn write_multi_operand_b(&mut self, operand_b: &[u32; N]) { + unsafe { + copy_nonoverlapping(operand_b.as_ptr(), self.rsa.z_mem(0).as_ptr().add(N), N); + } } - unsafe fn read_out(&mut self, outbuf: &mut [u32; N]) { - copy_nonoverlapping( - self.rsa.z_mem(0).as_ptr() as *const u32, - outbuf.as_ptr() as *mut u32, - N, - ); + fn write_r(&mut self, r: &[u32; N]) { + unsafe { + copy_nonoverlapping(r.as_ptr(), self.rsa.z_mem(0).as_ptr(), N); + } + } + + fn read_out(&self, outbuf: &mut [u32; N]) { + unsafe { + copy_nonoverlapping( + self.rsa.z_mem(0).as_ptr() as *const u32, + outbuf.as_ptr() as *mut u32, + N, + ); + } + } + + fn internal_set_interrupt_handler(&mut self, handler: InterruptHandler) { + unsafe { + crate::interrupt::bind_interrupt(crate::peripherals::Interrupt::RSA, handler.handler()); + crate::interrupt::enable(crate::peripherals::Interrupt::RSA, handler.priority()) + .unwrap(); + } + } + + fn wait_for_idle(&mut self) { + while !self.is_idle() {} + self.clear_interrupt(); + } + + fn read_results(&mut self, outbuf: &mut [u32; N]) { + self.wait_for_idle(); + self.read_out(outbuf); } } @@ -155,7 +165,7 @@ pub trait Multi: RsaMode { macro_rules! implement_op { (($x:literal, multi)) => { paste! { - /// Represents an RSA operation for the given bit size with multi-output. + #[doc = concat!($x, "-bit RSA operation.")] pub struct []; impl Multi for [] { @@ -204,17 +214,47 @@ impl<'a, 'd, T: RsaMode, DM: crate::Mode, const N: usize> RsaModularExponentiati where T: RsaMode, { + /// Creates an instance of `RsaModularExponentiation`. + /// + /// `m_prime` could be calculated using `-(modular multiplicative inverse of + /// modulus) mod 2^32`. + /// + /// For more information refer to 24.3.2 of . + pub fn new( + rsa: &'a mut Rsa<'d, DM>, + exponent: &T::InputType, + modulus: &T::InputType, + m_prime: u32, + ) -> Self { + Self::write_mode(rsa); + rsa.write_operand_b(exponent); + rsa.write_modulus(modulus); + rsa.write_mprime(m_prime); + + #[cfg(not(esp32))] + if rsa.is_search_enabled() { + rsa.write_search_position(Self::find_search_pos(exponent)); + } + + Self { + rsa, + phantom: PhantomData, + } + } + + fn set_up_exponentiation(&mut self, base: &T::InputType, r: &T::InputType) { + self.rsa.write_operand_a(base); + self.rsa.write_r(r); + } + /// Starts the modular exponentiation operation. /// /// `r` can be calculated using `2 ^ ( bitlength * 2 ) mod modulus`. /// /// For more information refer to 24.3.2 of . pub fn start_exponentiation(&mut self, base: &T::InputType, r: &T::InputType) { - unsafe { - self.rsa.write_operand_a(base); - self.rsa.write_r(r); - } - self.start(); + self.set_up_exponentiation(base, r); + self.rsa.write_modexp_start(); } /// Reads the result to the given buffer. @@ -240,6 +280,40 @@ impl<'a, 'd, T: RsaMode, DM: crate::Mode, const N: usize> RsaModularMultiplicati where T: RsaMode, { + /// Creates an instance of `RsaModularMultiplication`. + /// + /// - `r` can be calculated using `2 ^ ( bitlength * 2 ) mod modulus`. + /// - `m_prime` can be calculated using `-(modular multiplicative inverse of + /// modulus) mod 2^32`. + /// + /// For more information refer to 20.3.1 of . + pub fn new( + rsa: &'a mut Rsa<'d, DM>, + operand_a: &T::InputType, + modulus: &T::InputType, + r: &T::InputType, + m_prime: u32, + ) -> Self { + Self::write_mode(rsa); + rsa.write_mprime(m_prime); + rsa.write_modulus(modulus); + rsa.write_operand_a(operand_a); + rsa.write_r(r); + + Self { + rsa, + phantom: PhantomData, + } + } + + /// Starts the modular multiplication operation. + /// + /// For more information refer to 19.3.1 of . + pub fn start_modular_multiplication(&mut self, operand_b: &T::InputType) { + self.set_up_modular_multiplication(operand_b); + self.rsa.write_modmulti_start(); + } + /// Reads the result to the given buffer. /// This is a non blocking function that returns without an error if /// operation is completed successfully. @@ -261,6 +335,23 @@ impl<'a, 'd, T: RsaMode + Multi, DM: crate::Mode, const N: usize> RsaMultiplicat where T: RsaMode, { + /// Creates an instance of `RsaMultiplication`. + pub fn new(rsa: &'a mut Rsa<'d, DM>, operand_a: &T::InputType) -> Self { + Self::write_mode(rsa); + rsa.write_operand_a(operand_a); + + Self { + rsa, + phantom: PhantomData, + } + } + + /// Starts the multiplication operation. + pub fn start_multiplication(&mut self, operand_b: &T::InputType) { + self.set_up_multiplication(operand_b); + self.rsa.write_multi_start(); + } + /// Reads the result to the given buffer. /// This is a non blocking function that returns without an error if /// operation is completed successfully. `start_multiplication` must be @@ -279,59 +370,67 @@ pub(crate) mod asynch { use core::task::Poll; use embassy_sync::waitqueue::AtomicWaker; + use portable_atomic::{AtomicBool, Ordering}; use procmacros::handler; - use crate::rsa::{ - Multi, - RsaMode, - RsaModularExponentiation, - RsaModularMultiplication, - RsaMultiplication, + use crate::{ + rsa::{ + Multi, + Rsa, + RsaMode, + RsaModularExponentiation, + RsaModularMultiplication, + RsaMultiplication, + }, + Async, }; static WAKER: AtomicWaker = AtomicWaker::new(); + static SIGNALED: AtomicBool = AtomicBool::new(false); + /// `Future` that waits for the RSA operation to complete. #[must_use = "futures do nothing unless you `.await` or poll them"] - pub(crate) struct RsaFuture<'d> { - instance: &'d crate::peripherals::RSA, + struct RsaFuture<'a, 'd> { + #[cfg_attr(esp32, allow(dead_code))] + instance: &'a Rsa<'d, Async>, } - impl<'d> RsaFuture<'d> { - /// Asynchronously initializes the RSA peripheral. - pub fn new(instance: &'d crate::peripherals::RSA) -> Self { + impl<'a, 'd> RsaFuture<'a, 'd> { + fn new(instance: &'a Rsa<'d, Async>) -> Self { + SIGNALED.store(false, Ordering::Relaxed); + cfg_if::cfg_if! { if #[cfg(esp32)] { - instance.interrupt().modify(|_, w| w.interrupt().set_bit()); } else if #[cfg(any(esp32s2, esp32s3))] { - instance.interrupt_ena().modify(|_, w| w.interrupt_ena().set_bit()); + instance.rsa.interrupt_ena().write(|w| w.interrupt_ena().set_bit()); } else { - instance.int_ena().modify(|_, w| w.int_ena().set_bit()); + instance.rsa.int_ena().write(|w| w.int_ena().set_bit()); } } Self { instance } } - fn event_bit_is_clear(&self) -> bool { + fn is_done(&self) -> bool { + SIGNALED.load(Ordering::Acquire) + } + } + + impl Drop for RsaFuture<'_, '_> { + fn drop(&mut self) { cfg_if::cfg_if! { if #[cfg(esp32)] { - self.instance.interrupt().read().interrupt().bit_is_clear() } else if #[cfg(any(esp32s2, esp32s3))] { - self - .instance - .interrupt_ena() - .read() - .interrupt_ena() - .bit_is_clear() + self.instance.rsa.interrupt_ena().write(|w| w.interrupt_ena().clear_bit()); } else { - self.instance.int_ena().read().int_ena().bit_is_clear() + self.instance.rsa.int_ena().write(|w| w.int_ena().clear_bit()); } } } } - impl<'d> core::future::Future for RsaFuture<'d> { + impl core::future::Future for RsaFuture<'_, '_> { type Output = (); fn poll( @@ -339,7 +438,7 @@ pub(crate) mod asynch { cx: &mut core::task::Context<'_>, ) -> core::task::Poll { WAKER.register(cx.waker()); - if self.event_bit_is_clear() { + if self.is_done() { Poll::Ready(()) } else { Poll::Pending @@ -347,7 +446,7 @@ pub(crate) mod asynch { } } - impl<'a, 'd, T: RsaMode, const N: usize> RsaModularExponentiation<'a, 'd, T, crate::Async> + impl<'a, 'd, T: RsaMode, const N: usize> RsaModularExponentiation<'a, 'd, T, Async> where T: RsaMode, { @@ -358,49 +457,47 @@ pub(crate) mod asynch { r: &T::InputType, outbuf: &mut T::InputType, ) { - self.start_exponentiation(base, r); - RsaFuture::new(&self.rsa.rsa).await; - self.read_results(outbuf); + self.set_up_exponentiation(base, r); + let fut = RsaFuture::new(self.rsa); + self.rsa.write_modexp_start(); + fut.await; + self.rsa.read_out(outbuf); } } - impl<'a, 'd, T: RsaMode, const N: usize> RsaModularMultiplication<'a, 'd, T, crate::Async> + impl<'a, 'd, T: RsaMode, const N: usize> RsaModularMultiplication<'a, 'd, T, Async> where T: RsaMode, { - #[cfg(not(esp32))] /// Asynchronously performs an RSA modular multiplication operation. pub async fn modular_multiplication( &mut self, - r: &T::InputType, - outbuf: &mut T::InputType, - ) { - self.start_modular_multiplication(r); - RsaFuture::new(&self.rsa.rsa).await; - self.read_results(outbuf); - } - - #[cfg(esp32)] - /// Asynchronously performs an RSA modular multiplication operation. - pub async fn modular_multiplication( - &mut self, - operand_a: &T::InputType, operand_b: &T::InputType, - r: &T::InputType, outbuf: &mut T::InputType, ) { - self.start_step1(operand_a, r); - self.start_step2(operand_b); - RsaFuture::new(&self.rsa.rsa).await; - self.read_results(outbuf); + cfg_if::cfg_if! { + if #[cfg(esp32)] { + let fut = RsaFuture::new(self.rsa); + self.rsa.write_multi_start(); + fut.await; + + self.rsa.write_operand_a(operand_b); + } else { + self.set_up_modular_multiplication(operand_b); + } + } + + let fut = RsaFuture::new(self.rsa); + self.rsa.write_modmulti_start(); + fut.await; + self.rsa.read_out(outbuf); } } - impl<'a, 'd, T: RsaMode + Multi, const N: usize> RsaMultiplication<'a, 'd, T, crate::Async> + impl<'a, 'd, T: RsaMode + Multi, const N: usize> RsaMultiplication<'a, 'd, T, Async> where T: RsaMode, { - #[cfg(not(esp32))] /// Asynchronously performs an RSA multiplication operation. pub async fn multiplication<'b, const O: usize>( &mut self, @@ -409,44 +506,28 @@ pub(crate) mod asynch { ) where T: Multi, { - self.start_multiplication(operand_b); - RsaFuture::new(&self.rsa.rsa).await; - self.read_results(outbuf); - } - - #[cfg(esp32)] - /// Asynchronously performs an RSA multiplication operation. - pub async fn multiplication<'b, const O: usize>( - &mut self, - operand_a: &T::InputType, - operand_b: &T::InputType, - outbuf: &mut T::OutputType, - ) where - T: Multi, - { - self.start_multiplication(operand_a, operand_b); - RsaFuture::new(&self.rsa.rsa).await; - self.read_results(outbuf); + self.set_up_multiplication(operand_b); + let fut = RsaFuture::new(self.rsa); + self.rsa.write_multi_start(); + fut.await; + self.rsa.read_out(outbuf); } } #[handler] /// Interrupt handler for RSA. pub(super) fn rsa_interrupt_handler() { - #[cfg(not(any(esp32, esp32s2, esp32s3)))] - unsafe { &*crate::peripherals::RSA::ptr() } - .int_ena() - .modify(|_, w| w.int_ena().clear_bit()); - - #[cfg(esp32)] - unsafe { &*crate::peripherals::RSA::ptr() } - .interrupt() - .modify(|_, w| w.interrupt().clear_bit()); - - #[cfg(any(esp32s2, esp32s3))] - unsafe { &*crate::peripherals::RSA::ptr() } - .interrupt_ena() - .modify(|_, w| w.interrupt_ena().clear_bit()); + let rsa = unsafe { &*crate::peripherals::RSA::ptr() }; + SIGNALED.store(true, Ordering::Release); + cfg_if::cfg_if! { + if #[cfg(esp32)] { + rsa.interrupt().write(|w| w.interrupt().set_bit()); + } else if #[cfg(any(esp32s2, esp32s3))] { + rsa.clear_interrupt().write(|w| w.clear_interrupt().set_bit()); + } else { + rsa.int_clr().write(|w| w.clear_interrupt().set_bit()); + } + } WAKER.wake(); } diff --git a/hil-test/Cargo.toml b/hil-test/Cargo.toml index ba898bfadc4..f6e0287abd6 100644 --- a/hil-test/Cargo.toml +++ b/hil-test/Cargo.toml @@ -107,6 +107,10 @@ harness = false name = "rsa" harness = false +[[test]] +name = "rsa_async" +harness = false + [[test]] name = "sha" harness = false diff --git a/hil-test/tests/rsa.rs b/hil-test/tests/rsa.rs index d3e728c7659..c102dd025a5 100644 --- a/hil-test/tests/rsa.rs +++ b/hil-test/tests/rsa.rs @@ -10,7 +10,7 @@ use esp_hal::{ peripherals::Peripherals, prelude::*, rsa::{ - operand_sizes, + operand_sizes::*, Rsa, RsaModularExponentiation, RsaModularMultiplication, @@ -37,16 +37,6 @@ struct Context<'a> { rsa: Rsa<'a, Blocking>, } -impl Context<'_> { - pub fn init() -> Self { - let peripherals = Peripherals::take(); - let mut rsa = Rsa::new(peripherals.RSA); - nb::block!(rsa.ready()).unwrap(); - - Context { rsa } - } -} - const fn compute_r(modulus: &U512) -> U512 { let mut d = [0_u32; U512::LIMBS * 2 + 1]; d[d.len() - 1] = 1; @@ -68,10 +58,15 @@ mod tests { #[init] fn init() -> Context<'static> { - Context::init() + let peripherals = Peripherals::take(); + let mut rsa = Rsa::new(peripherals.RSA); + nb::block!(rsa.ready()).unwrap(); + + Context { rsa } } #[test] + #[timeout(5)] fn test_modular_exponentiation(mut ctx: Context<'static>) { const EXPECTED_OUTPUT: [u32; U512::LIMBS] = [ 1601059419, 3994655875, 2600857657, 1530060852, 64828275, 4221878473, 2751381085, @@ -85,20 +80,20 @@ mod tests { ctx.rsa.enable_disable_search_acceleration(true); } let mut outbuf = [0_u32; U512::LIMBS]; - let mut mod_exp = RsaModularExponentiation::::new( + let mut mod_exp = RsaModularExponentiation::::new( &mut ctx.rsa, BIGNUM_2.as_words(), BIGNUM_3.as_words(), compute_mprime(&BIGNUM_3), ); let r = compute_r(&BIGNUM_3); - let base = &BIGNUM_1.as_words(); - mod_exp.start_exponentiation(&base, r.as_words()); + mod_exp.start_exponentiation(BIGNUM_1.as_words(), r.as_words()); mod_exp.read_results(&mut outbuf); assert_eq!(EXPECTED_OUTPUT, outbuf); } #[test] + #[timeout(5)] fn test_modular_multiplication(mut ctx: Context<'static>) { const EXPECTED_OUTPUT: [u32; U512::LIMBS] = [ 1868256644, 833470784, 4187374062, 2684021027, 191862388, 1279046003, 1929899870, @@ -107,31 +102,21 @@ mod tests { ]; let mut outbuf = [0_u32; U512::LIMBS]; - let mut mod_multi = - RsaModularMultiplication::::new( - &mut ctx.rsa, - #[cfg(not(feature = "esp32"))] - BIGNUM_1.as_words(), - #[cfg(not(feature = "esp32"))] - BIGNUM_2.as_words(), - BIGNUM_3.as_words(), - compute_mprime(&BIGNUM_3), - ); let r = compute_r(&BIGNUM_3); - #[cfg(feature = "esp32")] - { - mod_multi.start_step1(BIGNUM_1.as_words(), r.as_words()); - mod_multi.start_step2(BIGNUM_2.as_words()); - } - #[cfg(not(feature = "esp32"))] - { - mod_multi.start_modular_multiplication(r.as_words()); - } + let mut mod_multi = RsaModularMultiplication::::new( + &mut ctx.rsa, + BIGNUM_1.as_words(), + BIGNUM_3.as_words(), + r.as_words(), + compute_mprime(&BIGNUM_3), + ); + mod_multi.start_modular_multiplication(BIGNUM_2.as_words()); mod_multi.read_results(&mut outbuf); assert_eq!(EXPECTED_OUTPUT, outbuf); } #[test] + #[timeout(5)] fn test_multiplication(mut ctx: Context<'static>) { const EXPECTED_OUTPUT: [u32; U1024::LIMBS] = [ 1264702968, 3552243420, 2602501218, 498422249, 2431753435, 2307424767, 349202767, @@ -145,21 +130,10 @@ mod tests { let operand_a = BIGNUM_1.as_words(); let operand_b = BIGNUM_2.as_words(); - cfg_if::cfg_if! { - if #[cfg(feature = "esp32")] { - let mut rsamulti = - RsaMultiplication::::new(&mut ctx.rsa); - rsamulti.start_multiplication(operand_a, operand_b); - rsamulti.read_results(&mut outbuf); - } else { - let mut rsamulti = RsaMultiplication::::new( - &mut ctx.rsa, - operand_a, - ); - rsamulti.start_multiplication(operand_b); - rsamulti.read_results(&mut outbuf); - } - } + let mut rsamulti = RsaMultiplication::::new(&mut ctx.rsa, operand_a); + rsamulti.start_multiplication(operand_b); + rsamulti.read_results(&mut outbuf); + assert_eq!(EXPECTED_OUTPUT, outbuf) } } diff --git a/hil-test/tests/rsa_async.rs b/hil-test/tests/rsa_async.rs new file mode 100644 index 00000000000..5730ea13c57 --- /dev/null +++ b/hil-test/tests/rsa_async.rs @@ -0,0 +1,140 @@ +//! Async RSA Test + +//% CHIPS: esp32 esp32c3 esp32c6 esp32h2 esp32s2 esp32s3 + +#![no_std] +#![no_main] + +use crypto_bigint::{Uint, U1024, U512}; +use esp_hal::{ + peripherals::Peripherals, + prelude::*, + rsa::{ + operand_sizes::*, + Rsa, + RsaModularExponentiation, + RsaModularMultiplication, + RsaMultiplication, + }, + Async, +}; +use hil_test as _; + +const BIGNUM_1: U512 = Uint::from_be_hex( + "c7f61058f96db3bd87dbab08ab03b4f7f2f864eac249144adea6a65f97803b719d8ca980b7b3c0389c1c7c6\ + 7dc353c5e0ec11f5fc8ce7f6073796cc8f73fa878", +); +const BIGNUM_2: U512 = Uint::from_be_hex( + "1763db3344e97be15d04de4868badb12a38046bb793f7630d87cf100aa1c759afac15a01f3c4c83ec2d2f66\ + 6bd22f71c3c1f075ec0e2cb0cb29994d091b73f51", +); +const BIGNUM_3: U512 = Uint::from_be_hex( + "6b6bb3d2b6cbeb45a769eaa0384e611e1b89b0c9b45a045aca1c5fd6e8785b38df7118cf5dd45b9b63d293b\ + 67aeafa9ba25feb8712f188cb139b7d9b9af1c361", +); + +struct Context<'a> { + rsa: Rsa<'a, Async>, +} + +const fn compute_r(modulus: &U512) -> U512 { + let mut d = [0_u32; U512::LIMBS * 2 + 1]; + d[d.len() - 1] = 1; + let d = Uint::from_words(d); + d.const_rem(&modulus.resize()).0.resize() +} + +const fn compute_mprime(modulus: &U512) -> u32 { + let m_inv = modulus.inv_mod2k(32).to_words()[0]; + (-1 * m_inv as i64 % 4294967296) as u32 +} + +#[cfg(test)] +#[embedded_test::tests(executor = esp_hal_embassy::Executor::new())] +mod tests { + use defmt::assert_eq; + + use super::*; + + #[init] + fn init() -> Context<'static> { + let peripherals = Peripherals::take(); + let mut rsa = Rsa::new_async(peripherals.RSA); + nb::block!(rsa.ready()).unwrap(); + + Context { rsa } + } + + #[test] + #[timeout(5)] + async fn modular_exponentiation(mut ctx: Context<'static>) { + const EXPECTED_OUTPUT: [u32; U512::LIMBS] = [ + 1601059419, 3994655875, 2600857657, 1530060852, 64828275, 4221878473, 2751381085, + 1938128086, 625895085, 2087010412, 2133352910, 101578249, 3798099415, 3357588690, + 2065243474, 330914193, + ]; + + #[cfg(not(feature = "esp32"))] + { + ctx.rsa.enable_disable_constant_time_acceleration(true); + ctx.rsa.enable_disable_search_acceleration(true); + } + let mut outbuf = [0_u32; U512::LIMBS]; + let mut mod_exp = RsaModularExponentiation::::new( + &mut ctx.rsa, + BIGNUM_2.as_words(), + BIGNUM_3.as_words(), + compute_mprime(&BIGNUM_3), + ); + let r = compute_r(&BIGNUM_3); + mod_exp + .exponentiation(BIGNUM_1.as_words(), r.as_words(), &mut outbuf) + .await; + assert_eq!(EXPECTED_OUTPUT, outbuf); + } + + #[test] + #[timeout(5)] + async fn test_modular_multiplication(mut ctx: Context<'static>) { + const EXPECTED_OUTPUT: [u32; U512::LIMBS] = [ + 1868256644, 833470784, 4187374062, 2684021027, 191862388, 1279046003, 1929899870, + 4209598061, 3830489207, 1317083344, 2666864448, 3701382766, 3232598924, 2904609522, + 747558855, 479377985, + ]; + + let mut outbuf = [0_u32; U512::LIMBS]; + let r = compute_r(&BIGNUM_3); + let mut mod_multi = RsaModularMultiplication::::new( + &mut ctx.rsa, + BIGNUM_1.as_words(), + BIGNUM_3.as_words(), + r.as_words(), + compute_mprime(&BIGNUM_3), + ); + mod_multi + .modular_multiplication(BIGNUM_2.as_words(), &mut outbuf) + .await; + assert_eq!(EXPECTED_OUTPUT, outbuf); + } + + #[test] + #[timeout(5)] + async fn test_multiplication(mut ctx: Context<'static>) { + const EXPECTED_OUTPUT: [u32; U1024::LIMBS] = [ + 1264702968, 3552243420, 2602501218, 498422249, 2431753435, 2307424767, 349202767, + 2269697177, 1525551459, 3623276361, 3146383138, 191420847, 4252021895, 9176459, + 301757643, 4220806186, 434407318, 3722444851, 1850128766, 928651940, 107896699, + 563405838, 1834067613, 1289630401, 3145128058, 3300293535, 3077505758, 1926648662, + 1264151247, 3626086486, 3701894076, 306518743, + ]; + let mut outbuf = [0_u32; U1024::LIMBS]; + + let operand_a = BIGNUM_1.as_words(); + let operand_b = BIGNUM_2.as_words(); + + let mut rsamulti = RsaMultiplication::::new(&mut ctx.rsa, operand_a); + rsamulti.multiplication(operand_b, &mut outbuf).await; + + assert_eq!(EXPECTED_OUTPUT, outbuf) + } +}