From 6164e76405935065aeb912f94ba94230e0bac60f Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 24 Jun 2017 22:44:02 -0700 Subject: [PATCH] feat(server): Handle 100-continue cc #838 --- src/http/conn.rs | 75 ++++++++++++++++++++++++++++++++++++------------ src/http/io.rs | 4 +++ src/http/mod.rs | 29 ++++++++++++++++++- tests/server.rs | 30 +++++++++++++++++++ 4 files changed, 118 insertions(+), 20 deletions(-) diff --git a/src/http/conn.rs b/src/http/conn.rs index 9fc6816eba..85cbe94f39 100644 --- a/src/http/conn.rs +++ b/src/http/conn.rs @@ -62,6 +62,13 @@ where I: AsyncRead + AsyncWrite, } } + fn can_write_continue(&self) -> bool { + match self.state.writing { + Writing::Continue(..) => true, + _ => false, + } + } + fn can_read_body(&self) -> bool { match self.state.reading { Reading::Body(..) => true, @@ -105,6 +112,10 @@ where I: AsyncRead + AsyncWrite, } }; self.state.busy(); + if head.expecting_continue() { + let msg = b"HTTP/1.1 100 Continue\r\n\r\n"; + self.state.writing = Writing::Continue(Cursor::new(msg)); + } let wants_keep_alive = head.should_keep_alive(); self.state.keep_alive &= wants_keep_alive; let (body, reading) = if decoder.is_eof() { @@ -172,6 +183,7 @@ where I: AsyncRead + AsyncWrite, } match self.state.writing { + Writing::Continue(..) | Writing::Body(..) | Writing::Ending(..) => return, Writing::Init | @@ -191,7 +203,7 @@ where I: AsyncRead + AsyncWrite, fn can_write_head(&self) -> bool { match self.state.writing { - Writing::Init => true, + Writing::Continue(..) | Writing::Init => true, _ => false } } @@ -199,6 +211,7 @@ where I: AsyncRead + AsyncWrite, fn can_write_body(&self) -> bool { match self.state.writing { Writing::Body(..) => true, + Writing::Continue(..) | Writing::Init | Writing::Ending(..) | Writing::KeepAlive | @@ -227,6 +240,13 @@ where I: AsyncRead + AsyncWrite, let wants_keep_alive = head.should_keep_alive(); self.state.keep_alive &= wants_keep_alive; + // if a 100-continue has started but not finished sending, tack the + // remainder on to the start of the buffer. + if let Writing::Continue(ref pending) = self.state.writing { + if pending.has_started() { + self.io.write_buf_mut().extend_from_slice(pending.buf()); + } + } let encoder = T::encode(head, self.io.write_buf_mut()); self.state.writing = if body { Writing::Body(encoder, None) @@ -290,6 +310,15 @@ where I: AsyncRead + AsyncWrite, fn write_queued(&mut self) -> Poll<(), io::Error> { trace!("Conn::write_queued()"); let state = match self.state.writing { + Writing::Continue(ref mut queued) => { + let n = self.io.buffer(queued.buf()); + queued.consume(n); + if queued.is_written() { + Writing::Init + } else { + return Ok(Async::NotReady); + } + } Writing::Body(ref mut encoder, ref mut queued) => { let complete = if let Some(chunk) = queued.as_mut() { let n = try_nb!(encoder.encode(&mut self.io, chunk.buf())); @@ -349,24 +378,28 @@ where I: AsyncRead + AsyncWrite, trace!("Conn::poll()"); self.state.read_task.take(); - if self.is_read_closed() { - trace!("Conn::poll when closed"); - Ok(Async::Ready(None)) - } else if self.can_read_head() { - self.read_head() - } else if self.can_read_body() { - self.read_body() - .map(|async| async.map(|chunk| Some(Frame::Body { - chunk: chunk - }))) - .or_else(|err| { - self.state.close_read(); - Ok(Async::Ready(Some(Frame::Error { error: err.into() }))) - }) - } else { - trace!("poll when on keep-alive"); - self.maybe_park_read(); - Ok(Async::NotReady) + loop { + if self.is_read_closed() { + trace!("Conn::poll when closed"); + return Ok(Async::Ready(None)); + } else if self.can_read_head() { + return self.read_head(); + } else if self.can_write_continue() { + try_nb!(self.flush()); + } else if self.can_read_body() { + return self.read_body() + .map(|async| async.map(|chunk| Some(Frame::Body { + chunk: chunk + }))) + .or_else(|err| { + self.state.close_read(); + Ok(Async::Ready(Some(Frame::Error { error: err.into() }))) + }); + } else { + trace!("poll when on keep-alive"); + self.maybe_park_read(); + return Ok(Async::NotReady); + } } } } @@ -467,6 +500,7 @@ enum Reading { } enum Writing { + Continue(Cursor<&'static [u8]>), Init, Body(Encoder, Option>), Ending(Cursor<&'static [u8]>), @@ -488,6 +522,9 @@ impl, K: fmt::Debug> fmt::Debug for State { impl> fmt::Debug for Writing { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { + Writing::Continue(ref buf) => f.debug_tuple("Continue") + .field(buf) + .finish(), Writing::Init => f.write_str("Init"), Writing::Body(ref enc, ref queued) => f.debug_tuple("Body") .field(enc) diff --git a/src/http/io.rs b/src/http/io.rs index ab26a01337..0da9b9597c 100644 --- a/src/http/io.rs +++ b/src/http/io.rs @@ -181,6 +181,10 @@ impl> Cursor { } } + pub fn has_started(&self) -> bool { + self.pos != 0 + } + pub fn is_written(&self) -> bool { trace!("Cursor::is_written pos = {}, len = {}", self.pos, self.bytes.as_ref().len()); self.pos >= self.bytes.as_ref().len() diff --git a/src/http/mod.rs b/src/http/mod.rs index 6f8644ce9d..b678e97dcf 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -4,7 +4,7 @@ use std::fmt; use bytes::BytesMut; -use header::{Connection, ConnectionOption}; +use header::{Connection, ConnectionOption, Expect}; use header::Headers; use method::Method; use status::StatusCode; @@ -56,6 +56,10 @@ impl MessageHead { pub fn should_keep_alive(&self) -> bool { should_keep_alive(self.version, &self.headers) } + + pub fn expecting_continue(&self) -> bool { + expecting_continue(self.version, &self.headers) + } } impl ResponseHead { @@ -119,6 +123,17 @@ pub fn should_keep_alive(version: HttpVersion, headers: &Headers) -> bool { ret } +/// Checks if a connection is expecting a `100 Continue` before sending its body. +#[inline] +pub fn expecting_continue(version: HttpVersion, headers: &Headers) -> bool { + let ret = match (version, headers.get::()) { + (Http11, Some(&Expect::Continue)) => true, + _ => false + }; + trace!("expecting_continue(version={:?}, header={:?}) = {:?}", version, headers.get::(), ret); + ret +} + #[derive(Debug)] pub enum ServerTransaction {} @@ -168,3 +183,15 @@ fn test_should_keep_alive() { assert!(should_keep_alive(Http10, &headers)); assert!(should_keep_alive(Http11, &headers)); } + +#[test] +fn test_expecting_continue() { + let mut headers = Headers::new(); + + assert!(!expecting_continue(Http10, &headers)); + assert!(!expecting_continue(Http11, &headers)); + + headers.set(Expect::Continue); + assert!(!expecting_continue(Http10, &headers)); + assert!(expecting_continue(Http11, &headers)); +} diff --git a/tests/server.rs b/tests/server.rs index dae3fca340..e5c38289f0 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -523,3 +523,33 @@ fn test_server_disable_keep_alive() { } } } + +#[test] +fn expect_continue() { + let server = serve(); + let mut req = connect(server.addr()); + server.reply().status(hyper::Ok); + + req.write_all(b"\ + POST /foo HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Expect: 100-continue\r\n\ + Content-Length: 5\r\n\ + Connection: Close\r\n\ + \r\n\ + ").expect("write 1"); + + let msg = b"HTTP/1.1 100 Continue\r\n\r\n"; + let mut buf = vec![0; msg.len()]; + req.read_exact(&mut buf).expect("read 1"); + assert_eq!(buf, msg); + + let msg = b"hello"; + req.write_all(msg).expect("write 2"); + + let mut body = String::new(); + req.read_to_string(&mut body).expect("read 2"); + + let body = server.body(); + assert_eq!(body, msg); +}