Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add server support for push #327

Merged
merged 11 commits into from
Sep 16, 2019
110 changes: 78 additions & 32 deletions src/frame/headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::{StreamDependency, StreamId};
use frame::{Error, Frame, Head, Kind};
use hpack;

use http::{uri, HeaderMap, Method, StatusCode, Uri};
use http::{uri, Request, HeaderMap, Method, StatusCode, Uri};
use http::header::{self, HeaderName, HeaderValue};

use byteorder::{BigEndian, ByteOrder};
Expand Down Expand Up @@ -283,9 +283,86 @@ impl fmt::Debug for Headers {
}
}

// ===== util =====

pub fn parse_u64(src: &[u8]) -> Result<u64, ()> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is used in multiple locations now (I believe), I would consider creating a private util module at the root and moving this there.

if src.len() > 19 {
// At danger for overflow...
return Err(());
}

let mut ret = 0;

for &d in src {
if d < b'0' || d > b'9' {
return Err(());
}

ret *= 10;
ret += (d - b'0') as u64;
}

Ok(ret)
}

// ===== impl PushPromise =====

impl PushPromise {
pub fn new(
stream_id: StreamId,
promised_id: StreamId,
pseudo: Pseudo,
fields: HeaderMap,
) -> Self {
PushPromise {
flags: PushPromiseFlag::default(),
header_block: HeaderBlock {
fields,
is_over_size: false,
pseudo,
},
promised_id,
stream_id,
}
}

pub fn validate_request(req: &Request<()>) -> bool {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This cleanup work is good standalone. It might be worth extracting it to a dedicated PR in order to simplify this PR (and make reviewing easier). Up to you though! Whatever makes it easier for you.

// The spec has some requirements for promised request headers
// [https://httpwg.org/specs/rfc7540.html#PushRequests]

// A promised request "that indicates the presence of a request body
// MUST reset the promised stream with a stream error"
if let Some(content_length) = req.headers().get(header::CONTENT_LENGTH) {
match parse_u64(content_length.as_bytes()) {
michaelbeaumont marked this conversation as resolved.
Show resolved Hide resolved
Ok(0) => {},
_ => return false,
}
}
// "The server MUST include a method in the :method pseudo-header field
// that is safe and cacheable"
if !Self::safe_and_cacheable(req.method()) {
return false;
}

true
}

fn safe_and_cacheable(method: &Method) -> bool {
// Cacheable: https://httpwg.org/specs/rfc7231.html#cacheable.methods
// Safe: https://httpwg.org/specs/rfc7231.html#safe.methods
return method == Method::GET || method == Method::HEAD;
}


pub fn fields(&self) -> &HeaderMap {
&self.header_block.fields
}

#[cfg(feature = "unstable")]
pub fn into_fields(self) -> HeaderMap {
self.header_block.fields
}

/// Loads the push promise frame but doesn't actually do HPACK decoding.
///
/// HPACK decoding is done in the `load_hpack` step.
Expand Down Expand Up @@ -378,44 +455,13 @@ impl PushPromise {
fn head(&self) -> Head {
Head::new(Kind::PushPromise, self.flags.into(), self.stream_id)
}
}

impl PushPromise {
/// Consume `self`, returning the parts of the frame
pub fn into_parts(self) -> (Pseudo, HeaderMap) {
(self.header_block.pseudo, self.header_block.fields)
}
}

#[cfg(feature = "unstable")]
impl PushPromise {
pub fn new(
stream_id: StreamId,
promised_id: StreamId,
pseudo: Pseudo,
fields: HeaderMap,
) -> Self {
PushPromise {
flags: PushPromiseFlag::default(),
header_block: HeaderBlock {
fields,
is_over_size: false,
pseudo,
},
promised_id,
stream_id,
}
}

pub fn fields(&self) -> &HeaderMap {
&self.header_block.fields
}

pub fn into_fields(self) -> HeaderMap {
self.header_block.fields
}
}

impl<T> From<PushPromise> for Frame<T> {
fn from(src: PushPromise) -> Self {
Frame::PushPromise(src)
Expand Down
2 changes: 1 addition & 1 deletion src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ mod window_update;
pub use self::data::Data;
pub use self::go_away::GoAway;
pub use self::head::{Head, Kind};
pub use self::headers::{Continuation, Headers, Pseudo, PushPromise};
pub use self::headers::{Continuation, Headers, Pseudo, PushPromise, parse_u64};
pub use self::ping::Ping;
pub use self::priority::{Priority, StreamDependency};
pub use self::reason::Reason;
Expand Down
55 changes: 5 additions & 50 deletions src/proto/streams/recv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use {frame, proto};
use codec::{RecvError, UserError};
use frame::{Reason, DEFAULT_INITIAL_WINDOW_SIZE};

use http::{HeaderMap, Response, Request, Method};
use http::{HeaderMap, Response, Request};

use std::io;
use std::time::{Duration, Instant};
Expand Down Expand Up @@ -176,7 +176,7 @@ impl Recv {
use http::header;

if let Some(content_length) = frame.fields().get(header::CONTENT_LENGTH) {
let content_length = match parse_u64(content_length.as_bytes()) {
let content_length = match frame::parse_u64(content_length.as_bytes()) {
Ok(v) => v,
Err(_) => {
return Err(RecvError::Stream {
Expand Down Expand Up @@ -592,45 +592,22 @@ impl Recv {
}

let promised_id = frame.promised_id();
use http::header;
let (pseudo, fields) = frame.into_parts();
let req = ::server::Peer::convert_poll_message(pseudo, fields, promised_id)?;
// The spec has some requirements for promised request headers
// [https://httpwg.org/specs/rfc7540.html#PushRequests]

// A promised request "that indicates the presence of a request body
// MUST reset the promised stream with a stream error"
if let Some(content_length) = req.headers().get(header::CONTENT_LENGTH) {
match parse_u64(content_length.as_bytes()) {
Ok(0) => {},
_ => {
return Err(RecvError::Stream {
id: promised_id,
reason: Reason::PROTOCOL_ERROR,
});
},
}
}
// "The server MUST include a method in the :method pseudo-header field
// that is safe and cacheable"
if !Self::safe_and_cacheable(req.method()) {

if !frame::PushPromise::validate_request(&req) {
return Err(RecvError::Stream {
id: promised_id,
reason: Reason::PROTOCOL_ERROR,
});
}

use super::peer::PollMessage::*;
stream.pending_recv.push_back(&mut self.buffer, Event::Headers(Server(req)));
stream.notify_recv();
Ok(())
}

fn safe_and_cacheable(method: &Method) -> bool {
// Cacheable: https://httpwg.org/specs/rfc7231.html#cacheable.methods
// Safe: https://httpwg.org/specs/rfc7231.html#safe.methods
return method == Method::GET || method == Method::HEAD;
}

/// Ensures that `id` is not in the `Idle` state.
pub fn ensure_not_idle(&self, id: StreamId) -> Result<(), Reason> {
if let Ok(next) = self.next_stream_id {
Expand Down Expand Up @@ -992,25 +969,3 @@ impl<T> From<RecvError> for RecvHeaderBlockError<T> {
RecvHeaderBlockError::State(err)
}
}

// ===== util =====

fn parse_u64(src: &[u8]) -> Result<u64, ()> {
if src.len() > 19 {
// At danger for overflow...
return Err(());
}

let mut ret = 0;

for &d in src {
if d < b'0' || d > b'9' {
return Err(());
}

ret *= 10;
ret += (d - b'0') as u64;
}

Ok(ret)
}
65 changes: 49 additions & 16 deletions src/proto/streams/send.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,54 @@ impl Send {
Ok(stream_id)
}

pub fn reserve_local(&mut self) -> Result<StreamId, UserError> {
let stream_id = self.ensure_next_stream_id()?;
self.next_stream_id = stream_id.next_id();
Ok(stream_id)
}

fn check_headers(
fields: &http::HeaderMap
) -> Result<(), UserError> {
// 8.1.2.2. Connection-Specific Header Fields
if fields.contains_key(http::header::CONNECTION)
|| fields.contains_key(http::header::TRANSFER_ENCODING)
|| fields.contains_key(http::header::UPGRADE)
|| fields.contains_key("keep-alive")
|| fields.contains_key("proxy-connection")
{
debug!("illegal connection-specific headers found");
return Err(UserError::MalformedHeaders);
} else if let Some(te) = fields.get(http::header::TE) {
if te != "trailers" {
debug!("illegal connection-specific headers found");
return Err(UserError::MalformedHeaders);
}
}
Ok(())
}

pub fn send_push_promise<B>(
&mut self,
frame: frame::PushPromise,
buffer: &mut Buffer<Frame<B>>,
stream: &mut store::Ptr,
task: &mut Option<Task>,
) -> Result<(), UserError> {
trace!(
"send_push_promise; frame={:?}; init_window={:?}",
frame,
self.init_window_sz
);

Self::check_headers(frame.fields())?;

// Queue the frame for sending
self.prioritize.queue_frame(frame.into(), buffer, stream, task);

Ok(())
}

pub fn send_headers<B>(
&mut self,
frame: frame::Headers,
Expand All @@ -68,22 +116,7 @@ impl Send {
self.init_window_sz
);

// 8.1.2.2. Connection-Specific Header Fields
if frame.fields().contains_key(http::header::CONNECTION)
|| frame.fields().contains_key(http::header::TRANSFER_ENCODING)
|| frame.fields().contains_key(http::header::UPGRADE)
|| frame.fields().contains_key("keep-alive")
|| frame.fields().contains_key("proxy-connection")
{
debug!("illegal connection-specific headers found");
return Err(UserError::MalformedHeaders);
} else if let Some(te) = frame.fields().get(http::header::TE) {
if te != "trailers" {
debug!("illegal connection-specific headers found");
return Err(UserError::MalformedHeaders);

}
}
Self::check_headers(frame.fields())?;

let end_stream = frame.is_end_stream();

Expand Down
19 changes: 15 additions & 4 deletions src/proto/streams/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ pub struct State {
enum Inner {
Idle,
// TODO: these states shouldn't count against concurrency limits:
//ReservedLocal,
ReservedLocal,
ReservedRemote,
Open { local: Peer, remote: Peer },
HalfClosedLocal(Peer), // TODO: explicitly name this value
Expand Down Expand Up @@ -113,7 +113,7 @@ impl State {
remote,
}
},
HalfClosedRemote(AwaitingHeaders) => if eos {
HalfClosedRemote(AwaitingHeaders) | ReservedLocal => if eos {
Closed(Cause::EndStream)
} else {
HalfClosedRemote(local)
Expand Down Expand Up @@ -192,6 +192,17 @@ impl State {
}
}

/// Transition from Idle -> ReservedLocal
pub fn reserve_local(&mut self) -> Result<(), UserError> {
match self.inner {
Idle => {
self.inner = ReservedLocal;
Ok(())
},
_ => Err(UserError::UnexpectedFrameType),
}
}

/// Indicates that the remote side will not send more data to the local.
pub fn recv_close(&mut self) -> Result<(), RecvError> {
match self.inner {
Expand Down Expand Up @@ -378,7 +389,7 @@ impl State {

pub fn is_recv_closed(&self) -> bool {
match self.inner {
Closed(..) | HalfClosedRemote(..) => true,
Closed(..) | HalfClosedRemote(..) | ReservedLocal => true,
_ => false,
}
}
Expand All @@ -405,7 +416,7 @@ impl State {
Closed(Cause::Scheduled(reason)) => Err(proto::Error::Proto(reason)),
Closed(Cause::Io) => Err(proto::Error::Io(io::ErrorKind::BrokenPipe.into())),
Closed(Cause::EndStream) |
HalfClosedRemote(..) => Ok(false),
HalfClosedRemote(..) | ReservedLocal => Ok(false),
_ => Ok(true),
}
}
Expand Down
Loading