From e1e32322e39dbeee33532480f320b59f9e4d9901 Mon Sep 17 00:00:00 2001 From: Programatik Date: Fri, 20 May 2022 18:30:10 +0300 Subject: [PATCH] change implementation and add size hint Co-authored-by: neoeinstein --- src/limited.rs | 48 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/src/limited.rs b/src/limited.rs index 1546c6f..e461b5b 100644 --- a/src/limited.rs +++ b/src/limited.rs @@ -1,6 +1,6 @@ //! Body types. -use crate::Body; +use crate::{Body, SizeHint}; use bytes::Buf; use pin_project_lite::pin_project; use std::{ @@ -16,8 +16,7 @@ pin_project! { pub struct Limited { #[pin] inner: B, - limit: usize, - read: usize, + remaining: usize, } } @@ -26,8 +25,7 @@ impl Limited { pub fn new(inner: B, limit: usize) -> Self { Self { inner, - limit, - read: 0, + remaining: limit, } } } @@ -46,20 +44,22 @@ where ) -> Poll>> { let this = self.project(); - match this.inner.poll_data(cx) { + let res = match this.inner.poll_data(cx) { Poll::Ready(Some(Ok(data))) => { - *this.read += data.remaining(); - - if this.read <= this.limit { - Poll::Ready(Some(Ok(data))) + if data.remaining() > *this.remaining { + *this.remaining = 0; + Some(Err("length limit exceeded".into())) } else { - Poll::Ready(Some(Err("body limit exceeded".into()))) + *this.remaining -= data.remaining(); + Some(Ok(data)) } } - Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))), - Poll::Ready(None) => Poll::Ready(None), - Poll::Pending => Poll::Pending, - } + Poll::Ready(Some(Err(e))) => Some(Err(e.into())), + Poll::Ready(None) => None, + Poll::Pending => return Poll::Pending, + }; + + Poll::Ready(res) } fn poll_trailers( @@ -72,6 +72,24 @@ where fn is_end_stream(&self) -> bool { self.inner.is_end_stream() } + + fn size_hint(&self) -> SizeHint { + use std::convert::TryFrom; + match u64::try_from(self.remaining) { + Ok(n) => { + let mut hint = self.inner.size_hint(); + if hint.lower() >= n { + hint.set_exact(n) + } else if let Some(max) = hint.upper() { + hint.set_upper(n.min(max)) + } else { + hint.set_upper(n) + } + hint + } + Err(_) => self.inner.size_hint(), + } + } } #[cfg(test)]