Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Making Soketto easier to use with http types (and hyper) #48

Merged
merged 13 commits into from
Sep 23, 2021
Merged
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: Rust

on:
push:
# Run jobs when commits are pushed to
# Run jobs when commits are pushed to
# develop or release-like branches:
branches:
- develop
Expand Down Expand Up @@ -40,7 +40,7 @@ jobs:
uses: actions-rs/cargo@v1.0.3
with:
command: check
args: --all-targets
args: --all-targets --all-features
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that we have a feature, make sure CI checks the code behind it


fmt:
name: Run rustfmt
Expand Down
7 changes: 7 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,17 @@ httparse = { default-features = false, features = ["std"], version = "1.3.4" }
log = { default-features = false, version = "0.4.8" }
rand = { default-features = false, features = ["std", "std_rng"], version = "0.8" }
sha-1 = { default-features = false, version = "0.9" }
http = { default-features = false, version = "0.2", optional = true }

[dev-dependencies]
quickcheck = "0.9"
tokio = { version = "1", features = ["full"] }
tokio-util = { version = "0.6", features = ["compat"] }
tokio-stream = { version = "0.1", features = ["net"] }
hyper = { version = "0.14.10", features = ["full"] }
env_logger = "0.9.0"

[[example]]
name = "hyper_server"
required-features = ["http"]
Copy link
Contributor Author

@jsdw jsdw Sep 14, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This just makes the error messages nicer when trying to run the example.


129 changes: 48 additions & 81 deletions examples/hyper_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,71 +23,79 @@

use futures::io::{BufReader, BufWriter};
use hyper::{Body, Request, Response};
use soketto::{handshake, BoxedError};
use soketto::{
handshake::http::{is_upgrade_request, Server},
BoxedError,
};
use tokio_util::compat::TokioAsyncReadCompatExt;

/// Start up a hyper server.
#[tokio::main]
async fn main() -> Result<(), BoxedError> {
env_logger::init();

let addr = ([127, 0, 0, 1], 3000).into();

let service =
hyper::service::make_service_fn(|_| async { Ok::<_, hyper::Error>(hyper::service::service_fn(handler)) });
let server = hyper::Server::bind(&addr).serve(service);

println!("Listening on http://{}", server.local_addr());
println!("Listening on http://{} — connect and I'll echo back anything you send!", server.local_addr());
server.await?;

Ok(())
}

jsdw marked this conversation as resolved.
Show resolved Hide resolved
/// Handle incoming HTTP Requests.
async fn handler(req: Request<Body>) -> Result<hyper::Response<Body>, BoxedError> {
// If the request is asking to be upgraded to a websocket connection, we do that,
// and we handle the websocket connection (in this case, by echoing messages back):
if is_upgrade_request(&req) {
let (res, on_upgrade) = upgrade_to_websocket(req)?;
tokio::spawn(async move {
if let Err(e) = websocket_echo_messages(on_upgrade).await {
eprintln!("Error upgrading to websocket connection: {}", e);
// Create a new handshake server.
let mut server = Server::new();

// Add any extensions that we want to use.
#[cfg(feature = "deflate")]
{
let deflate = soketto::extension::deflate::Deflate::new(soketto::Mode::Server);
server.add_extension(Box::new(deflate));
}

// Attempt the handshake.
match server.receive_request(&req) {
// The handshake has been successful so far; return the response we're given back
// and spawn a task to handle the long-running WebSocket server:
Ok(response) => {
tokio::spawn(async move {
if let Err(e) = websocket_echo_messages(server, req).await {
log::error!("Error upgrading to websocket connection: {}", e);
}
});
Ok(response.map(|()| Body::empty()))
}
});
Ok(res)
}
// Or, we can handle the request as a standard HTTP request:
else {
// We tried to upgrade and failed early on; tell the client about the failure however we like:
Err(e) => {
log::error!("Could not upgrade connection: {}", e);
Ok(Response::new(Body::from("Something went wrong upgrading!")))
}
}
} else {
// The request wasn't an upgrade request; let's treat it as a standard HTTP request:
Ok(Response::new(Body::from("Hello HTTP!")))
}
}

/// Return the response to the upgrade request, and a way to get hold of the underlying TCP stream
fn upgrade_to_websocket(req: Request<Body>) -> Result<(Response<Body>, hyper::upgrade::OnUpgrade), handshake::Error> {
let key = req.headers().get("Sec-WebSocket-Key").ok_or(handshake::Error::InvalidSecWebSocketAccept)?;
if req.headers().get("Sec-WebSocket-Version").map(|v| v.as_bytes()) != Some(b"13") {
return Err(handshake::Error::HeaderNotFound("Sec-WebSocket-Version".into()));
}

// Just a little ceremony we need to go through to return the correct response key:
let mut accept_key_buf = [0; 32];
let accept_key = generate_websocket_accept_key(key.as_bytes(), &mut accept_key_buf);

let response = Response::builder()
.status(hyper::StatusCode::SWITCHING_PROTOCOLS)
.header(hyper::header::CONNECTION, "upgrade")
.header(hyper::header::UPGRADE, "websocket")
.header("Sec-WebSocket-Accept", accept_key)
.body(Body::empty())
.expect("bug: failed to build response");

Ok((response, hyper::upgrade::on(req)))
}

/// Echo any messages we get from the client back to them
async fn websocket_echo_messages(on_upgrade: hyper::upgrade::OnUpgrade) -> Result<(), BoxedError> {
// Wait for the request to upgrade, and pass the stream we get back to Soketto to handle the WS connection:
let stream = on_upgrade.await?;
let server = handshake::Server::new(BufReader::new(BufWriter::new(stream.compat())));
let (mut sender, mut receiver) = server.into_builder().finish();
async fn websocket_echo_messages(server: Server, req: Request<Body>) -> Result<(), BoxedError> {
// The negotiation to upgrade to a WebSocket connection has been successful so far. Next, we get back the underlying
// stream using `hyper::upgrade::on`, and hand this to a Soketto server to use to handle the WebSocket communication
// on this socket.
//
// Note: awaiting this won't succeed until the handshake response has been returned to the client, so this must be
// spawned on a separate task so as not to block that response being handed back.
Comment on lines +92 to +93
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is good, without this comment I'd have missed how this works.

let stream = hyper::upgrade::on(req).await?;
let stream = BufReader::new(BufWriter::new(stream.compat()));

// Get back a reader and writer that we can use to send and receive websocket messages.
let (mut sender, mut receiver) = server.into_builder(stream).finish();

// Echo any received messages back to the client:
let mut message = Vec::new();
Expand Down Expand Up @@ -118,44 +126,3 @@ async fn websocket_echo_messages(on_upgrade: hyper::upgrade::OnUpgrade) -> Resul

Ok(())
}

/// Defined in RFC 6455. this is how we convert the Sec-WebSocket-Key in a request into a
/// Sec-WebSocket-Accept that we return in the response.
fn generate_websocket_accept_key<'a>(key: &[u8], buf: &'a mut [u8; 32]) -> &'a [u8] {
// Defined in RFC 6455, we append this to the key to generate the response:
const KEY: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

use sha1::{Digest, Sha1};
let mut digest = Sha1::new();
digest.update(key);
digest.update(KEY);
let d = digest.finalize();

let n = base64::encode_config_slice(&d, base64::STANDARD, buf);
&buf[..n]
}

/// Check if a request is a websocket upgrade request.
pub fn is_upgrade_request<B>(request: &hyper::Request<B>) -> bool {
header_contains_value(request.headers(), hyper::header::CONNECTION, b"upgrade")
&& header_contains_value(request.headers(), hyper::header::UPGRADE, b"websocket")
}

/// Check if there is a header of the given name containing the wanted value.
fn header_contains_value(headers: &hyper::HeaderMap, header: hyper::header::HeaderName, value: &[u8]) -> bool {
pub fn trim(x: &[u8]) -> &[u8] {
let from = match x.iter().position(|x| !x.is_ascii_whitespace()) {
Some(i) => i,
None => return &[],
};
let to = x.iter().rposition(|x| !x.is_ascii_whitespace()).unwrap();
&x[from..=to]
}

for header in headers.get_all(header) {
if header.as_bytes().split(|&c| c == b',').any(|x| trim(x).eq_ignore_ascii_case(value)) {
return true;
}
}
false
}
30 changes: 28 additions & 2 deletions src/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@
//! [handshake]: https://tools.ietf.org/html/rfc6455#section-4

pub mod client;
#[cfg(feature = "http")]
pub mod http;
pub mod server;

use crate::extension::{Extension, Param};
use bytes::BytesMut;
use sha1::{Digest, Sha1};
use std::{fmt, io, str};

pub use client::{Client, ServerResponse};
Expand Down Expand Up @@ -105,7 +108,15 @@ where
bytes.extend_from_slice(b"\r\nSec-WebSocket-Extensions: ")
}

while let Some(e) = iter.next() {
append_extension_header_value(iter, bytes)
}

// Write the extension header value to the given buffer.
fn append_extension_header_value<'a, I>(mut extensions_iter: std::iter::Peekable<I>, bytes: &mut BytesMut)
where
I: Iterator<Item = &'a Box<dyn Extension + Send>>,
{
while let Some(e) = extensions_iter.next() {
bytes.extend_from_slice(e.name().as_bytes());
for p in e.params() {
bytes.extend_from_slice(b"; ");
Expand All @@ -115,12 +126,27 @@ where
bytes.extend_from_slice(v.as_bytes())
}
}
if iter.peek().is_some() {
if extensions_iter.peek().is_some() {
bytes.extend_from_slice(b", ")
}
}
}

// This function takes a 16 byte key (base64 encoded, and so 24 bytes of input) that is expected via
// the `Sec-WebSocket-Key` header during a websocket handshake, and a 32 byte output buffer, and
// writes the response that's expected to be handed back in the response header `Sec-WebSocket-Accept`.
//
// See https://datatracker.ietf.org/doc/html/rfc6455#section-1.3 for more information on this.
fn generate_accept_key<'k>(key_base64: &[u8; 24], output_buf: &'k mut [u8; 32]) -> &'k [u8] {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI: there is a WebSocketKey type alias defined here ([u8; 24]): the hash is always 16 bytes, and base64 encoding of it is always 24 bytes. You could skip the output_buf and just return WebSocketKey by value here.

Copy link
Contributor Author

@jsdw jsdw Sep 21, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the pointer! I am now returning a 28 (not 32) byte array rather than taking in a buffer. Unless I'm mistaken, base64 encoding a 160bit hash will always use exactly 28 bytes!

let mut digest = Sha1::new();
digest.update(key_base64);
digest.update(KEY);
let d = digest.finalize();

let n = base64::encode_config_slice(&d, base64::STANDARD, output_buf);
&output_buf[..n]
}

/// Enumeration of possible handshake errors.
#[non_exhaustive]
#[derive(Debug)]
Expand Down
Loading