diff --git a/axum/Cargo.toml b/axum/Cargo.toml index 5d9b0edb6d..fb88399ec5 100644 --- a/axum/Cargo.toml +++ b/axum/Cargo.toml @@ -76,7 +76,7 @@ serde_path_to_error = { version = "0.1.8", optional = true } serde_urlencoded = { version = "0.7", optional = true } sha1 = { version = "0.10", optional = true } tokio = { package = "tokio", version = "1.25.0", features = ["time"], optional = true } -tokio-tungstenite = { version = "0.24.0", optional = true } +tokio-tungstenite = { version = "0.26.0", optional = true } tracing = { version = "0.1", default-features = false, optional = true } [dependencies.tower-http] @@ -127,7 +127,7 @@ serde_json = { version = "1.0", features = ["raw_value"] } time = { version = "0.3", features = ["serde-human-readable"] } tokio = { package = "tokio", version = "1.25.0", features = ["macros", "rt", "rt-multi-thread", "net", "test-util"] } tokio-stream = "0.1" -tokio-tungstenite = "0.24.0" +tokio-tungstenite = "0.26.0" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["json"] } uuid = { version = "1.0", features = ["serde", "v4"] } diff --git a/axum/src/extract/ws.rs b/axum/src/extract/ws.rs index fa06d249ad..3d96d89888 100644 --- a/axum/src/extract/ws.rs +++ b/axum/src/extract/ws.rs @@ -553,16 +553,131 @@ impl Sink for WebSocket { } } +/// UTF-8 wrapper for [Bytes]. +/// +/// An [Utf8Bytes] is always guaranteed to contain valid UTF-8. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct Utf8Bytes(ts::Utf8Bytes); + +impl Utf8Bytes { + /// Creates from a static str. + #[inline] + pub const fn from_static(str: &'static str) -> Self { + Self(ts::Utf8Bytes::from_static(str)) + } + + /// Returns as a string slice. + #[inline] + pub fn as_str(&self) -> &str { + self.0.as_str() + } + + fn into_tungstenite(self) -> ts::Utf8Bytes { + self.0 + } +} + +impl std::ops::Deref for Utf8Bytes { + type Target = str; + + /// ``` + /// /// Example fn that takes a str slice + /// fn a(s: &str) {} + /// + /// let data = axum::extract::ws::Utf8Bytes::from_static("foo123"); + /// + /// // auto-deref as arg + /// a(&data); + /// + /// // deref to str methods + /// assert_eq!(data.len(), 6); + /// ``` + #[inline] + fn deref(&self) -> &Self::Target { + self.as_str() + } +} + +impl std::fmt::Display for Utf8Bytes { + #[inline] + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } +} + +impl TryFrom for Utf8Bytes { + type Error = std::str::Utf8Error; + + #[inline] + fn try_from(bytes: Bytes) -> Result { + Ok(Self(bytes.try_into()?)) + } +} + +impl TryFrom> for Utf8Bytes { + type Error = std::str::Utf8Error; + + #[inline] + fn try_from(v: Vec) -> Result { + Ok(Self(v.try_into()?)) + } +} + +impl From for Utf8Bytes { + #[inline] + fn from(s: String) -> Self { + Self(s.into()) + } +} + +impl From<&str> for Utf8Bytes { + #[inline] + fn from(s: &str) -> Self { + Self(s.into()) + } +} + +impl From<&String> for Utf8Bytes { + #[inline] + fn from(s: &String) -> Self { + Self(s.into()) + } +} + +impl From for Bytes { + #[inline] + fn from(Utf8Bytes(bytes): Utf8Bytes) -> Self { + bytes.into() + } +} + +impl PartialEq for Utf8Bytes +where + for<'a> &'a str: PartialEq, +{ + /// ``` + /// let payload = axum::extract::ws::Utf8Bytes::from_static("foo123"); + /// assert_eq!(payload, "foo123"); + /// assert_eq!(payload, "foo123".to_string()); + /// assert_eq!(payload, &"foo123".to_string()); + /// assert_eq!(payload, std::borrow::Cow::from("foo123")); + /// ``` + #[inline] + fn eq(&self, other: &T) -> bool { + self.as_str() == *other + } +} + /// Status code used to indicate why an endpoint is closing the WebSocket connection. pub type CloseCode = u16; /// A struct representing the close command. #[derive(Debug, Clone, Eq, PartialEq)] -pub struct CloseFrame<'t> { +pub struct CloseFrame { /// The reason as a code. pub code: CloseCode, /// The reason as text string. - pub reason: Cow<'t, str>, + pub reason: Utf8Bytes, } /// A WebSocket message. @@ -591,16 +706,16 @@ pub struct CloseFrame<'t> { #[derive(Debug, Eq, PartialEq, Clone)] pub enum Message { /// A text WebSocket message - Text(String), + Text(Utf8Bytes), /// A binary WebSocket message - Binary(Vec), + Binary(Bytes), /// A ping message with the specified payload /// /// The payload here must have a length less than 125 bytes. /// /// Ping messages will be automatically responded to by the server, so you do not have to worry /// about dealing with them yourself. - Ping(Vec), + Ping(Bytes), /// A pong message with the specified payload /// /// The payload here must have a length less than 125 bytes. @@ -608,7 +723,7 @@ pub enum Message { /// Pong messages will be automatically sent to the client if a ping message is received, so /// you do not have to worry about constructing them yourself unless you want to implement a /// [unidirectional heartbeat](https://tools.ietf.org/html/rfc6455#section-5.5.3). - Pong(Vec), + Pong(Bytes), /// A close message with the optional close frame. /// /// You may "uncleanly" close a WebSocket connection at any time @@ -628,19 +743,19 @@ pub enum Message { /// Since no further messages will be received, /// you may either do nothing /// or explicitly drop the connection. - Close(Option>), + Close(Option), } impl Message { fn into_tungstenite(self) -> ts::Message { match self { - Self::Text(text) => ts::Message::Text(text), + Self::Text(text) => ts::Message::Text(text.into_tungstenite()), Self::Binary(binary) => ts::Message::Binary(binary), Self::Ping(ping) => ts::Message::Ping(ping), Self::Pong(pong) => ts::Message::Pong(pong), Self::Close(Some(close)) => ts::Message::Close(Some(ts::protocol::CloseFrame { code: ts::protocol::frame::coding::CloseCode::from(close.code), - reason: close.reason, + reason: close.reason.into_tungstenite(), })), Self::Close(None) => ts::Message::Close(None), } @@ -648,13 +763,13 @@ impl Message { fn from_tungstenite(message: ts::Message) -> Option { match message { - ts::Message::Text(text) => Some(Self::Text(text)), + ts::Message::Text(text) => Some(Self::Text(Utf8Bytes(text))), ts::Message::Binary(binary) => Some(Self::Binary(binary)), ts::Message::Ping(ping) => Some(Self::Ping(ping)), ts::Message::Pong(pong) => Some(Self::Pong(pong)), ts::Message::Close(Some(close)) => Some(Self::Close(Some(CloseFrame { code: close.code.into(), - reason: close.reason, + reason: Utf8Bytes(close.reason), }))), ts::Message::Close(None) => Some(Self::Close(None)), // we can ignore `Frame` frames as recommended by the tungstenite maintainers @@ -664,24 +779,24 @@ impl Message { } /// Consume the WebSocket and return it as binary data. - pub fn into_data(self) -> Vec { + pub fn into_data(self) -> Bytes { match self { - Self::Text(string) => string.into_bytes(), + Self::Text(string) => Bytes::from(string), Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => data, - Self::Close(None) => Vec::new(), - Self::Close(Some(frame)) => frame.reason.into_owned().into_bytes(), + Self::Close(None) => Bytes::new(), + Self::Close(Some(frame)) => Bytes::from(frame.reason), } } - /// Attempt to consume the WebSocket message and convert it to a String. - pub fn into_text(self) -> Result { + /// Attempt to consume the WebSocket message and convert it to a Utf8Bytes. + pub fn into_text(self) -> Result { match self { Self::Text(string) => Ok(string), - Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => Ok(String::from_utf8(data) - .map_err(|err| err.utf8_error()) - .map_err(Error::new)?), - Self::Close(None) => Ok(String::new()), - Self::Close(Some(frame)) => Ok(frame.reason.into_owned()), + Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => { + Ok(Utf8Bytes::try_from(data).map_err(Error::new)?) + } + Self::Close(None) => Ok(Utf8Bytes::default()), + Self::Close(Some(frame)) => Ok(frame.reason), } } @@ -689,7 +804,7 @@ impl Message { /// this will try to convert binary data to utf8. pub fn to_text(&self) -> Result<&str, Error> { match *self { - Self::Text(ref string) => Ok(string), + Self::Text(ref string) => Ok(string.as_str()), Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => { Ok(std::str::from_utf8(data).map_err(Error::new)?) } @@ -697,11 +812,27 @@ impl Message { Self::Close(Some(ref frame)) => Ok(&frame.reason), } } + + /// Create a new text WebSocket message from a stringable. + pub fn text(string: S) -> Message + where + S: Into, + { + Message::Text(string.into()) + } + + /// Create a new binary WebSocket message by converting to `Bytes`. + pub fn binary(bin: B) -> Message + where + B: Into, + { + Message::Binary(bin.into()) + } } impl From for Message { fn from(string: String) -> Self { - Message::Text(string) + Message::Text(string.into()) } } @@ -713,19 +844,19 @@ impl<'s> From<&'s str> for Message { impl<'b> From<&'b [u8]> for Message { fn from(data: &'b [u8]) -> Self { - Message::Binary(data.into()) + Message::Binary(Bytes::copy_from_slice(data)) } } impl From> for Message { fn from(data: Vec) -> Self { - Message::Binary(data) + Message::Binary(data.into()) } } impl From for Vec { fn from(msg: Message) -> Self { - msg.into_data() + msg.into_data().to_vec() } } @@ -1026,19 +1157,19 @@ mod tests { } async fn test_echo_app(mut socket: WebSocketStream) { - let input = tungstenite::Message::Text("foobar".to_owned()); + let input = tungstenite::Message::Text(tungstenite::Utf8Bytes::from_static("foobar")); socket.send(input.clone()).await.unwrap(); let output = socket.next().await.unwrap().unwrap(); assert_eq!(input, output); socket - .send(tungstenite::Message::Ping("ping".to_owned().into_bytes())) + .send(tungstenite::Message::Ping(Bytes::from_static(b"ping"))) .await .unwrap(); let output = socket.next().await.unwrap().unwrap(); assert_eq!( output, - tungstenite::Message::Pong("ping".to_owned().into_bytes()) + tungstenite::Message::Pong(Bytes::from_static(b"ping")) ); } } diff --git a/examples/chat/src/main.rs b/examples/chat/src/main.rs index 77baada1b5..1c07301ed8 100644 --- a/examples/chat/src/main.rs +++ b/examples/chat/src/main.rs @@ -8,7 +8,7 @@ use axum::{ extract::{ - ws::{Message, WebSocket, WebSocketUpgrade}, + ws::{Message, Utf8Bytes, WebSocket, WebSocketUpgrade}, State, }, response::{Html, IntoResponse}, @@ -79,7 +79,7 @@ async fn websocket(stream: WebSocket, state: Arc) { while let Some(Ok(message)) = receiver.next().await { if let Message::Text(name) = message { // If username that is sent by client is not taken, fill username string. - check_username(&state, &mut username, &name); + check_username(&state, &mut username, name.as_str()); // If not empty we want to quit the loop else we want to quit function. if !username.is_empty() { @@ -87,7 +87,9 @@ async fn websocket(stream: WebSocket, state: Arc) { } else { // Only send our client that username is taken. let _ = sender - .send(Message::Text(String::from("Username already taken."))) + .send(Message::Text(Utf8Bytes::from_static( + "Username already taken.", + ))) .await; return; @@ -109,7 +111,7 @@ async fn websocket(stream: WebSocket, state: Arc) { let mut send_task = tokio::spawn(async move { while let Ok(msg) = rx.recv().await { // In any websocket error, break loop. - if sender.send(Message::Text(msg)).await.is_err() { + if sender.send(Message::text(msg)).await.is_err() { break; } } diff --git a/examples/testing-websockets/Cargo.toml b/examples/testing-websockets/Cargo.toml index 31ed2601f0..8942f9e2a0 100644 --- a/examples/testing-websockets/Cargo.toml +++ b/examples/testing-websockets/Cargo.toml @@ -8,4 +8,4 @@ publish = false axum = { path = "../../axum", features = ["ws"] } futures = "0.3" tokio = { version = "1.0", features = ["full"] } -tokio-tungstenite = "0.24" +tokio-tungstenite = "0.26" diff --git a/examples/testing-websockets/src/main.rs b/examples/testing-websockets/src/main.rs index 384be35d53..7a0be11ce4 100644 --- a/examples/testing-websockets/src/main.rs +++ b/examples/testing-websockets/src/main.rs @@ -48,7 +48,7 @@ async fn integration_testable_handle_socket(mut socket: WebSocket) { while let Some(Ok(msg)) = socket.recv().await { if let Message::Text(msg) = msg { if socket - .send(Message::Text(format!("You said: {msg}"))) + .send(Message::Text(format!("You said: {msg}").into())) .await .is_err() { @@ -79,7 +79,7 @@ where while let Some(Ok(msg)) = read.next().await { if let Message::Text(msg) = msg { if write - .send(Message::Text(format!("You said: {msg}"))) + .send(Message::Text(format!("You said: {msg}").into())) .await .is_err() { @@ -123,7 +123,7 @@ mod tests { other => panic!("expected a text message but got {other:?}"), }; - assert_eq!(msg, "You said: foo"); + assert_eq!(msg.as_str(), "You said: foo"); } // We can unit test the other handler by creating channels to read and write from. @@ -136,16 +136,13 @@ mod tests { tokio::spawn(unit_testable_handle_socket(socket_write, socket_read)); - test_tx - .send(Ok(Message::Text("foo".to_owned()))) - .await - .unwrap(); + test_tx.send(Ok(Message::Text("foo".into()))).await.unwrap(); let msg = match test_rx.next().await.unwrap() { Message::Text(msg) => msg, other => panic!("expected a text message but got {other:?}"), }; - assert_eq!(msg, "You said: foo"); + assert_eq!(msg.as_str(), "You said: foo"); } } diff --git a/examples/websockets-http2/src/main.rs b/examples/websockets-http2/src/main.rs index dbc682c4d9..f3f33aacac 100644 --- a/examples/websockets-http2/src/main.rs +++ b/examples/websockets-http2/src/main.rs @@ -75,7 +75,7 @@ async fn ws_handler( res = ws.recv() => { match res { Some(Ok(ws::Message::Text(s))) => { - let _ = sender.send(s); + let _ = sender.send(s.to_string()); } Some(Ok(_)) => {} Some(Err(e)) => tracing::debug!("client disconnected abruptly: {e}"), @@ -85,7 +85,7 @@ async fn ws_handler( // Tokio guarantees that `broadcast::Receiver::recv` is cancel-safe. res = receiver.recv() => { match res { - Ok(msg) => if let Err(e) = ws.send(ws::Message::Text(msg)).await { + Ok(msg) => if let Err(e) = ws.send(ws::Message::Text(msg.into())).await { tracing::debug!("client disconnected abruptly: {e}"); } Err(_) => continue, diff --git a/examples/websockets/Cargo.toml b/examples/websockets/Cargo.toml index 541d82805a..0c1eb36a5b 100644 --- a/examples/websockets/Cargo.toml +++ b/examples/websockets/Cargo.toml @@ -11,7 +11,7 @@ futures = "0.3" futures-util = { version = "0.3", default-features = false, features = ["sink", "std"] } headers = "0.4" tokio = { version = "1.0", features = ["full"] } -tokio-tungstenite = "0.24.0" +tokio-tungstenite = "0.26.0" tower-http = { version = "0.6.1", features = ["fs", "trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/websockets/src/client.rs b/examples/websockets/src/client.rs index 5d0a670672..a30341315b 100644 --- a/examples/websockets/src/client.rs +++ b/examples/websockets/src/client.rs @@ -12,9 +12,9 @@ use futures_util::stream::FuturesUnordered; use futures_util::{SinkExt, StreamExt}; -use std::borrow::Cow; use std::ops::ControlFlow; use std::time::Instant; +use tokio_tungstenite::tungstenite::Utf8Bytes; // we will use tungstenite for websocket client impl (same library as what axum is using) use tokio_tungstenite::{ @@ -65,7 +65,9 @@ async fn spawn_client(who: usize) { //we can ping the server for start sender - .send(Message::Ping("Hello, Server!".into())) + .send(Message::Ping(axum::body::Bytes::from_static( + b"Hello, Server!", + ))) .await .expect("Can not send!"); @@ -74,7 +76,7 @@ async fn spawn_client(who: usize) { for i in 1..30 { // In any websocket error, break loop. if sender - .send(Message::Text(format!("Message number {i}..."))) + .send(Message::Text(format!("Message number {i}...").into())) .await .is_err() { @@ -90,7 +92,7 @@ async fn spawn_client(who: usize) { if let Err(e) = sender .send(Message::Close(Some(CloseFrame { code: CloseCode::Normal, - reason: Cow::from("Goodbye"), + reason: Utf8Bytes::from_static("Goodbye"), }))) .await { diff --git a/examples/websockets/src/main.rs b/examples/websockets/src/main.rs index 7c4a9801af..fbf0198617 100644 --- a/examples/websockets/src/main.rs +++ b/examples/websockets/src/main.rs @@ -17,14 +17,14 @@ //! ``` use axum::{ - extract::ws::{Message, WebSocket, WebSocketUpgrade}, + body::Bytes, + extract::ws::{Message, Utf8Bytes, WebSocket, WebSocketUpgrade}, response::IntoResponse, routing::any, Router, }; use axum_extra::TypedHeader; -use std::borrow::Cow; use std::ops::ControlFlow; use std::{net::SocketAddr, path::PathBuf}; use tower_http::{ @@ -101,7 +101,11 @@ async fn ws_handler( /// Actual websocket statemachine (one will be spawned per connection) async fn handle_socket(mut socket: WebSocket, who: SocketAddr) { // send a ping (unsupported by some browsers) just to kick things off and get a response - if socket.send(Message::Ping(vec![1, 2, 3])).await.is_ok() { + if socket + .send(Message::Ping(Bytes::from_static(&[1, 2, 3]))) + .await + .is_ok() + { println!("Pinged {who}..."); } else { println!("Could not send ping {who}!"); @@ -131,7 +135,7 @@ async fn handle_socket(mut socket: WebSocket, who: SocketAddr) { // connecting to server and receiving their greetings. for i in 1..5 { if socket - .send(Message::Text(format!("Hi {i} times!"))) + .send(Message::Text(format!("Hi {i} times!").into())) .await .is_err() { @@ -151,7 +155,7 @@ async fn handle_socket(mut socket: WebSocket, who: SocketAddr) { for i in 0..n_msg { // In case of any websocket error, we exit. if sender - .send(Message::Text(format!("Server message {i} ..."))) + .send(Message::Text(format!("Server message {i} ...").into())) .await .is_err() { @@ -165,7 +169,7 @@ async fn handle_socket(mut socket: WebSocket, who: SocketAddr) { if let Err(e) = sender .send(Message::Close(Some(CloseFrame { code: axum::extract::ws::close_code::NORMAL, - reason: Cow::from("Goodbye"), + reason: Utf8Bytes::from_static("Goodbye"), }))) .await {