diff --git a/ssh-encoding/src/base64/reader.rs b/ssh-encoding/src/base64/reader.rs index 15b6c85..5fcee08 100644 --- a/ssh-encoding/src/base64/reader.rs +++ b/ssh-encoding/src/base64/reader.rs @@ -1,13 +1,17 @@ //! Base64 reader support (constant-time). -use crate::{Reader, Result}; +use crate::{Decode, Error, Reader, Result}; /// Inner constant-time Base64 reader type from the `base64ct` crate. type Inner<'i> = base64ct::Decoder<'i, base64ct::Base64>; /// Constant-time Base64 reader implementation. pub struct Base64Reader<'i> { + /// Inner Base64 reader. inner: Inner<'i>, + + /// Custom length of remaining data, used for nested length-prefixed reading. + remaining_len: usize, } impl<'i> Base64Reader<'i> { @@ -18,18 +22,52 @@ impl<'i> Base64Reader<'i> { /// - `Ok(reader)` on success. /// - `Err(Error::Base64)` if the input buffer is empty. pub fn new(input: &'i [u8]) -> Result { + let inner = Inner::new(input)?; + let remaining_len = inner.remaining_len(); + Ok(Self { - inner: Inner::new(input)?, + inner, + remaining_len, }) } } impl Reader for Base64Reader<'_> { fn read<'o>(&mut self, out: &'o mut [u8]) -> Result<&'o [u8]> { - Ok(self.inner.decode(out)?) + if out.is_empty() { + return Ok(out); + } + + let remaining_len = self + .remaining_len + .checked_sub(out.len()) + .ok_or(Error::Length)?; + + let ret = self.inner.decode(out)?; + self.remaining_len = remaining_len; + Ok(ret) + } + + fn read_prefixed(&mut self, f: F) -> core::result::Result + where + E: From, + F: FnOnce(&mut Self) -> core::result::Result, + { + let prefix_len = usize::decode(self)?; + let new_remaining_len = self + .remaining_len + .checked_sub(prefix_len) + .ok_or(Error::Length)?; + + self.remaining_len = prefix_len; + let ret = f(self)?; + self.ensure_finished()?; + + self.remaining_len = new_remaining_len; + Ok(ret) } fn remaining_len(&self) -> usize { - self.inner.remaining_len() + self.remaining_len } } diff --git a/ssh-encoding/src/lib.rs b/ssh-encoding/src/lib.rs index a8adbbb..a865d67 100644 --- a/ssh-encoding/src/lib.rs +++ b/ssh-encoding/src/lib.rs @@ -42,7 +42,7 @@ pub use crate::{ encode::Encode, error::{Error, Result}, label::{Label, LabelError}, - reader::{NestedReader, Reader}, + reader::Reader, writer::Writer, }; diff --git a/ssh-encoding/src/pem/reader.rs b/ssh-encoding/src/pem/reader.rs index 628ccc0..8e14616 100644 --- a/ssh-encoding/src/pem/reader.rs +++ b/ssh-encoding/src/pem/reader.rs @@ -1,19 +1,29 @@ use super::LINE_WIDTH; -use crate::{Reader, Result}; +use crate::{Decode, Error, Reader, Result}; /// Inner PEM decoder. type Inner<'i> = pem_rfc7468::Decoder<'i>; /// Constant-time PEM reader. pub struct PemReader<'i> { + /// Inner PEM reader. inner: Inner<'i>, + + /// Custom length of remaining data, used for nested length-prefixed reading. + remaining_len: usize, } impl<'i> PemReader<'i> { - /// TODO + /// Create a new PEM reader. + /// + /// Uses [`LINE_WIDTH`] as the default line width (i.e. 70 chars). pub fn new(pem: &'i [u8]) -> Result { + let inner = Inner::new_wrapped(pem, LINE_WIDTH)?; + let remaining_len = inner.remaining_len(); + Ok(Self { - inner: Inner::new_wrapped(pem, LINE_WIDTH)?, + inner, + remaining_len, }) } @@ -25,10 +35,40 @@ impl<'i> PemReader<'i> { impl Reader for PemReader<'_> { fn read<'o>(&mut self, out: &'o mut [u8]) -> Result<&'o [u8]> { - Ok(self.inner.decode(out)?) + if out.is_empty() { + return Ok(out); + } + + let remaining_len = self + .remaining_len + .checked_sub(out.len()) + .ok_or(Error::Length)?; + + let ret = self.inner.decode(out)?; + self.remaining_len = remaining_len; + Ok(ret) + } + + fn read_prefixed(&mut self, f: F) -> core::result::Result + where + E: From, + F: FnOnce(&mut Self) -> core::result::Result, + { + let prefix_len = usize::decode(self)?; + let new_remaining_len = self + .remaining_len + .checked_sub(prefix_len) + .ok_or(Error::Length)?; + + self.remaining_len = prefix_len; + let ret = f(self)?; + self.ensure_finished()?; + + self.remaining_len = new_remaining_len; + Ok(ret) } fn remaining_len(&self) -> usize { - self.inner.remaining_len() + self.remaining_len } } diff --git a/ssh-encoding/src/reader.rs b/ssh-encoding/src/reader.rs index 5c15d47..e9b814e 100644 --- a/ssh-encoding/src/reader.rs +++ b/ssh-encoding/src/reader.rs @@ -29,18 +29,10 @@ pub trait Reader: Sized { /// Decodes a `uint32` which identifies the length of some encapsulated /// data, then calls the given reader function with the length of the /// remaining data. - fn read_prefixed<'r, T, E, F>(&'r mut self, f: F) -> core::result::Result + fn read_prefixed(&mut self, f: F) -> core::result::Result where E: From, - F: FnOnce(&mut NestedReader<'r, Self>) -> core::result::Result, - { - let len = usize::decode(self)?; - - f(&mut NestedReader { - inner: self, - remaining_len: len, - }) - } + F: FnOnce(&mut Self) -> core::result::Result; /// Decodes `[u8]` from `byte[n]` as described in [RFC4251 ยง 5]: /// @@ -111,17 +103,27 @@ pub trait Reader: Sized { }) } - /// Finish decoding, returning the given value if there is no remaining - /// data, or an error otherwise. - fn finish(self, value: T) -> Result { + /// Ensure that decoding is finished. + /// + /// # Errors + /// + /// - Returns `Error::TrailingData` if there is data remaining in the encoder. + fn ensure_finished(&self) -> Result<()> { if self.is_finished() { - Ok(value) + Ok(()) } else { Err(Error::TrailingData { remaining: self.remaining_len(), }) } } + + /// Finish decoding, returning the given value if there is no remaining + /// data, or an error otherwise. + fn finish(self, value: T) -> Result { + self.ensure_finished()?; + Ok(value) + } } impl Reader for &[u8] { @@ -136,37 +138,24 @@ impl Reader for &[u8] { } } - fn remaining_len(&self) -> usize { - self.len() - } -} - -/// Reader type used by [`Reader::read_prefixed`]. -pub struct NestedReader<'r, R: Reader> { - /// Inner reader type. - inner: &'r mut R, - - /// Remaining length in the prefixed reader. - remaining_len: usize, -} + fn read_prefixed(&mut self, f: F) -> core::result::Result + where + E: From, + F: FnOnce(&mut Self) -> core::result::Result, + { + let prefix_len = usize::decode(self)?; -impl<'r, R: Reader> Reader for NestedReader<'r, R> { - fn read<'o>(&mut self, out: &'o mut [u8]) -> Result<&'o [u8]> { - if out.is_empty() { - return Ok(out); + if self.len() < prefix_len { + return Err(Error::Length.into()); } - let remaining_len = self - .remaining_len - .checked_sub(out.len()) - .ok_or(Error::Length)?; - - let ret = self.inner.read(out)?; - self.remaining_len = remaining_len; + let (mut prefix, remaining) = self.split_at(prefix_len); + let ret = f(&mut prefix)?; + *self = remaining; Ok(ret) } fn remaining_len(&self) -> usize { - self.remaining_len + self.len() } }