-
Notifications
You must be signed in to change notification settings - Fork 26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Making Soketto easier to use with http
types (and hyper
)
#48
Changes from 9 commits
3f0fbda
6c98072
8695672
17987b7
a0593b2
41d7678
54ea9ce
5e80e21
e606700
c851479
719a782
57dc84c
555ac9e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,10 +26,17 @@ httparse = { default-features = false, features = ["std"], version = "1.3.4" } | |
log = { default-features = false, version = "0.4.8" } | ||
rand = { default-features = false, features = ["std", "std_rng"], version = "0.8" } | ||
sha-1 = { default-features = false, version = "0.9" } | ||
http = { default-features = false, version = "0.2", optional = true } | ||
|
||
[dev-dependencies] | ||
quickcheck = "0.9" | ||
tokio = { version = "1", features = ["full"] } | ||
tokio-util = { version = "0.6", features = ["compat"] } | ||
tokio-stream = { version = "0.1", features = ["net"] } | ||
hyper = { version = "0.14.10", features = ["full"] } | ||
env_logger = "0.9.0" | ||
|
||
[[example]] | ||
name = "hyper_server" | ||
required-features = ["http"] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This just makes the error messages nicer when trying to run the example. |
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,71 +23,79 @@ | |
|
||
use futures::io::{BufReader, BufWriter}; | ||
use hyper::{Body, Request, Response}; | ||
use soketto::{handshake, BoxedError}; | ||
use soketto::{ | ||
handshake::http::{is_upgrade_request, Server}, | ||
BoxedError, | ||
}; | ||
use tokio_util::compat::TokioAsyncReadCompatExt; | ||
|
||
/// Start up a hyper server. | ||
#[tokio::main] | ||
async fn main() -> Result<(), BoxedError> { | ||
env_logger::init(); | ||
|
||
let addr = ([127, 0, 0, 1], 3000).into(); | ||
|
||
let service = | ||
hyper::service::make_service_fn(|_| async { Ok::<_, hyper::Error>(hyper::service::service_fn(handler)) }); | ||
let server = hyper::Server::bind(&addr).serve(service); | ||
|
||
println!("Listening on http://{}", server.local_addr()); | ||
println!("Listening on http://{} — connect and I'll echo back anything you send!", server.local_addr()); | ||
server.await?; | ||
|
||
Ok(()) | ||
} | ||
|
||
jsdw marked this conversation as resolved.
Show resolved
Hide resolved
|
||
/// Handle incoming HTTP Requests. | ||
async fn handler(req: Request<Body>) -> Result<hyper::Response<Body>, BoxedError> { | ||
// If the request is asking to be upgraded to a websocket connection, we do that, | ||
// and we handle the websocket connection (in this case, by echoing messages back): | ||
if is_upgrade_request(&req) { | ||
let (res, on_upgrade) = upgrade_to_websocket(req)?; | ||
tokio::spawn(async move { | ||
if let Err(e) = websocket_echo_messages(on_upgrade).await { | ||
eprintln!("Error upgrading to websocket connection: {}", e); | ||
// Create a new handshake server. | ||
let mut server = Server::new(); | ||
|
||
// Add any extensions that we want to use. | ||
#[cfg(feature = "deflate")] | ||
{ | ||
let deflate = soketto::extension::deflate::Deflate::new(soketto::Mode::Server); | ||
server.add_extension(Box::new(deflate)); | ||
} | ||
|
||
// Attempt the handshake. | ||
match server.receive_request(&req) { | ||
// The handshake has been successful so far; return the response we're given back | ||
// and spawn a task to handle the long-running WebSocket server: | ||
Ok(response) => { | ||
tokio::spawn(async move { | ||
if let Err(e) = websocket_echo_messages(server, req).await { | ||
log::error!("Error upgrading to websocket connection: {}", e); | ||
} | ||
}); | ||
Ok(response.map(|()| Body::empty())) | ||
} | ||
}); | ||
Ok(res) | ||
} | ||
// Or, we can handle the request as a standard HTTP request: | ||
else { | ||
// We tried to upgrade and failed early on; tell the client about the failure however we like: | ||
Err(e) => { | ||
log::error!("Could not upgrade connection: {}", e); | ||
Ok(Response::new(Body::from("Something went wrong upgrading!"))) | ||
} | ||
} | ||
} else { | ||
// The request wasn't an upgrade request; let's treat it as a standard HTTP request: | ||
Ok(Response::new(Body::from("Hello HTTP!"))) | ||
} | ||
} | ||
|
||
/// Return the response to the upgrade request, and a way to get hold of the underlying TCP stream | ||
fn upgrade_to_websocket(req: Request<Body>) -> Result<(Response<Body>, hyper::upgrade::OnUpgrade), handshake::Error> { | ||
let key = req.headers().get("Sec-WebSocket-Key").ok_or(handshake::Error::InvalidSecWebSocketAccept)?; | ||
if req.headers().get("Sec-WebSocket-Version").map(|v| v.as_bytes()) != Some(b"13") { | ||
return Err(handshake::Error::HeaderNotFound("Sec-WebSocket-Version".into())); | ||
} | ||
|
||
// Just a little ceremony we need to go through to return the correct response key: | ||
let mut accept_key_buf = [0; 32]; | ||
let accept_key = generate_websocket_accept_key(key.as_bytes(), &mut accept_key_buf); | ||
|
||
let response = Response::builder() | ||
.status(hyper::StatusCode::SWITCHING_PROTOCOLS) | ||
.header(hyper::header::CONNECTION, "upgrade") | ||
.header(hyper::header::UPGRADE, "websocket") | ||
.header("Sec-WebSocket-Accept", accept_key) | ||
.body(Body::empty()) | ||
.expect("bug: failed to build response"); | ||
|
||
Ok((response, hyper::upgrade::on(req))) | ||
} | ||
|
||
/// Echo any messages we get from the client back to them | ||
async fn websocket_echo_messages(on_upgrade: hyper::upgrade::OnUpgrade) -> Result<(), BoxedError> { | ||
// Wait for the request to upgrade, and pass the stream we get back to Soketto to handle the WS connection: | ||
let stream = on_upgrade.await?; | ||
let server = handshake::Server::new(BufReader::new(BufWriter::new(stream.compat()))); | ||
let (mut sender, mut receiver) = server.into_builder().finish(); | ||
async fn websocket_echo_messages(server: Server, req: Request<Body>) -> Result<(), BoxedError> { | ||
// The negotiation to upgrade to a WebSocket connection has been successful so far. Next, we get back the underlying | ||
// stream using `hyper::upgrade::on`, and hand this to a Soketto server to use to handle the WebSocket communication | ||
// on this socket. | ||
// | ||
// Note: awaiting this won't succeed until the handshake response has been returned to the client, so this must be | ||
// spawned on a separate task so as not to block that response being handed back. | ||
Comment on lines
+92
to
+93
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is good, without this comment I'd have missed how this works. |
||
let stream = hyper::upgrade::on(req).await?; | ||
let stream = BufReader::new(BufWriter::new(stream.compat())); | ||
|
||
// Get back a reader and writer that we can use to send and receive websocket messages. | ||
let (mut sender, mut receiver) = server.into_builder(stream).finish(); | ||
|
||
// Echo any received messages back to the client: | ||
let mut message = Vec::new(); | ||
|
@@ -118,44 +126,3 @@ async fn websocket_echo_messages(on_upgrade: hyper::upgrade::OnUpgrade) -> Resul | |
|
||
Ok(()) | ||
} | ||
|
||
/// Defined in RFC 6455. this is how we convert the Sec-WebSocket-Key in a request into a | ||
/// Sec-WebSocket-Accept that we return in the response. | ||
fn generate_websocket_accept_key<'a>(key: &[u8], buf: &'a mut [u8; 32]) -> &'a [u8] { | ||
// Defined in RFC 6455, we append this to the key to generate the response: | ||
const KEY: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; | ||
|
||
use sha1::{Digest, Sha1}; | ||
let mut digest = Sha1::new(); | ||
digest.update(key); | ||
digest.update(KEY); | ||
let d = digest.finalize(); | ||
|
||
let n = base64::encode_config_slice(&d, base64::STANDARD, buf); | ||
&buf[..n] | ||
} | ||
|
||
/// Check if a request is a websocket upgrade request. | ||
pub fn is_upgrade_request<B>(request: &hyper::Request<B>) -> bool { | ||
header_contains_value(request.headers(), hyper::header::CONNECTION, b"upgrade") | ||
&& header_contains_value(request.headers(), hyper::header::UPGRADE, b"websocket") | ||
} | ||
|
||
/// Check if there is a header of the given name containing the wanted value. | ||
fn header_contains_value(headers: &hyper::HeaderMap, header: hyper::header::HeaderName, value: &[u8]) -> bool { | ||
pub fn trim(x: &[u8]) -> &[u8] { | ||
let from = match x.iter().position(|x| !x.is_ascii_whitespace()) { | ||
Some(i) => i, | ||
None => return &[], | ||
}; | ||
let to = x.iter().rposition(|x| !x.is_ascii_whitespace()).unwrap(); | ||
&x[from..=to] | ||
} | ||
|
||
for header in headers.get_all(header) { | ||
if header.as_bytes().split(|&c| c == b',').any(|x| trim(x).eq_ignore_ascii_case(value)) { | ||
return true; | ||
} | ||
} | ||
false | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,10 +11,13 @@ | |
//! [handshake]: https://tools.ietf.org/html/rfc6455#section-4 | ||
|
||
pub mod client; | ||
#[cfg(feature = "http")] | ||
pub mod http; | ||
pub mod server; | ||
|
||
use crate::extension::{Extension, Param}; | ||
use bytes::BytesMut; | ||
use sha1::{Digest, Sha1}; | ||
use std::{fmt, io, str}; | ||
|
||
pub use client::{Client, ServerResponse}; | ||
|
@@ -105,7 +108,15 @@ where | |
bytes.extend_from_slice(b"\r\nSec-WebSocket-Extensions: ") | ||
} | ||
|
||
while let Some(e) = iter.next() { | ||
append_extension_header_value(iter, bytes) | ||
} | ||
|
||
// Write the extension header value to the given buffer. | ||
fn append_extension_header_value<'a, I>(mut extensions_iter: std::iter::Peekable<I>, bytes: &mut BytesMut) | ||
where | ||
I: Iterator<Item = &'a Box<dyn Extension + Send>>, | ||
{ | ||
while let Some(e) = extensions_iter.next() { | ||
bytes.extend_from_slice(e.name().as_bytes()); | ||
for p in e.params() { | ||
bytes.extend_from_slice(b"; "); | ||
|
@@ -115,12 +126,27 @@ where | |
bytes.extend_from_slice(v.as_bytes()) | ||
} | ||
} | ||
if iter.peek().is_some() { | ||
if extensions_iter.peek().is_some() { | ||
bytes.extend_from_slice(b", ") | ||
} | ||
} | ||
} | ||
|
||
// This function takes a 16 byte key (base64 encoded, and so 24 bytes of input) that is expected via | ||
// the `Sec-WebSocket-Key` header during a websocket handshake, and a 32 byte output buffer, and | ||
// writes the response that's expected to be handed back in the response header `Sec-WebSocket-Accept`. | ||
// | ||
// See https://datatracker.ietf.org/doc/html/rfc6455#section-1.3 for more information on this. | ||
fn generate_accept_key<'k>(key_base64: &[u8; 24], output_buf: &'k mut [u8; 32]) -> &'k [u8] { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FYI: there is a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the pointer! I am now returning a 28 (not 32) byte array rather than taking in a buffer. Unless I'm mistaken, base64 encoding a 160bit hash will always use exactly 28 bytes! |
||
let mut digest = Sha1::new(); | ||
digest.update(key_base64); | ||
digest.update(KEY); | ||
let d = digest.finalize(); | ||
|
||
let n = base64::encode_config_slice(&d, base64::STANDARD, output_buf); | ||
&output_buf[..n] | ||
} | ||
|
||
/// Enumeration of possible handshake errors. | ||
#[non_exhaustive] | ||
#[derive(Debug)] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now that we have a feature, make sure CI checks the code behind it