Skip to content

Commit

Permalink
client: set authorization header from the URL (#1384)
Browse files Browse the repository at this point in the history
* client: set basic auth based on the URL

* address grumbles

* Update client/transport/src/ws/mod.rs

* Update client/http-client/src/transport.rs
  • Loading branch information
niklasad1 authored May 31, 2024
1 parent d36d3a8 commit a3307b8
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 19 deletions.
1 change: 1 addition & 0 deletions client/http-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ publish = true

[dependencies]
async-trait = "0.1"
base64 = { version = "0.22", default-features = false, features = ["alloc"] }
hyper = { version = "1.3", features = ["client", "http1", "http2"] }
hyper-rustls = { version = "0.27.1", default-features = false, features = ["http1", "http2", "tls12", "logging", "ring"], optional = true }
hyper-util = { version = "0.1.1", features = ["client", "client-legacy", "tokio", "http1", "http2"] }
Expand Down
13 changes: 13 additions & 0 deletions client/http-client/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
// that we need to be guaranteed that hyper doesn't re-use an existing connection if we ever reset
// the JSON-RPC request id to a value that might have already been used.

use base64::Engine;
use hyper::body::Bytes;
use hyper::http::{HeaderMap, HeaderValue};
use hyper_util::client::legacy::connect::HttpConnector;
Expand Down Expand Up @@ -204,6 +205,7 @@ impl<L> HttpTransportClientBuilder<L> {
tcp_no_delay,
} = self;
let mut url = Url::parse(target.as_ref()).map_err(|e| Error::Url(format!("Invalid URL: {e}")))?;

if url.host_str().is_none() {
return Err(Error::Url("Invalid host".into()));
}
Expand Down Expand Up @@ -258,6 +260,17 @@ impl<L> HttpTransportClientBuilder<L> {
}
}

if let Some(pwd) = url.password() {
if !cached_headers.contains_key(hyper::header::AUTHORIZATION) {
let digest = base64::engine::general_purpose::STANDARD.encode(format!("{}:{pwd}", url.username()));
cached_headers.insert(
hyper::header::AUTHORIZATION,
HeaderValue::from_str(&format!("Basic {digest}"))
.map_err(|_| Error::Url("Header value `authorization basic user:pwd` invalid".into()))?,
);
}
}

Ok(HttpTransportClient {
target: url.as_str().to_owned(),
client: service_builder.service(client),
Expand Down
2 changes: 2 additions & 0 deletions client/transport/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ tokio-util = { version = "0.7", features = ["compat"], optional = true }
tokio = { version = "1.16", features = ["net", "time", "macros"], optional = true }
pin-project = { version = "1", optional = true }
url = { version = "2.4.0", optional = true }
base64 = { version = "0.22", default-features = false, features = ["alloc"], optional = true }

# tls
tokio-rustls = { version = "0.26", default-features = false, optional = true, features = ["logging", "tls12", "ring"] }
Expand All @@ -43,6 +44,7 @@ gloo-net = { version = "0.5.0", default-features = false, features = ["json", "w
tls = ["tokio-rustls", "rustls-pki-types", "rustls-platform-verifier", "rustls"]

ws = [
"base64",
"futures-util",
"http",
"tokio",
Expand Down
82 changes: 63 additions & 19 deletions client/transport/src/ws/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use std::io;
use std::net::SocketAddr;
use std::time::Duration;

use base64::Engine;
use futures_util::io::{BufReader, BufWriter};
use jsonrpsee_core::client::{MaybeSend, ReceivedMessage, TransportReceiverT, TransportSenderT};
use jsonrpsee_core::TEN_MB_SIZE_BYTES;
Expand Down Expand Up @@ -438,8 +439,22 @@ impl WsTransportClientBuilder {
&target.path_and_query,
);

let headers: Vec<_> =
self.headers.iter().map(|(key, value)| Header { name: key.as_str(), value: value.as_bytes() }).collect();
let headers: Vec<_> = match &target.basic_auth {
Some(basic_auth) if !self.headers.contains_key(http::header::AUTHORIZATION) => {
let it1 =
self.headers.iter().map(|(key, value)| Header { name: key.as_str(), value: value.as_bytes() });
let it2 = std::iter::once(Header {
name: http::header::AUTHORIZATION.as_str(),
value: basic_auth.as_bytes(),
});

it1.chain(it2).collect()
}
_ => {
self.headers.iter().map(|(key, value)| Header { name: key.as_str(), value: value.as_bytes() }).collect()
}
};

client.set_headers(&headers);

// Perform the initial handshake.
Expand Down Expand Up @@ -531,8 +546,8 @@ impl From<soketto::connection::Error> for WsError {
}

/// Represents a verified remote WebSocket address.
#[derive(Debug, Clone)]
pub struct Target {
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct Target {
/// Socket addresses resolved the host name.
sockaddrs: Vec<SocketAddr>,
/// The host name (domain or IP address).
Expand All @@ -543,6 +558,8 @@ pub struct Target {
_mode: Mode,
/// The path and query parts from an URL.
path_and_query: String,
/// Optional <username:password> from an URL.
basic_auth: Option<HeaderValue>,
}

impl TryFrom<url::Url> for Target {
Expand All @@ -569,14 +586,20 @@ impl TryFrom<url::Url> for Target {
path_and_query.push_str(query);
}

let basic_auth = if let Some(pwd) = url.password() {
let digest = base64::engine::general_purpose::STANDARD.encode(format!("{}:{}", url.username(), pwd));
let val = HeaderValue::from_str(&format!("Basic {digest}"))
.map_err(|_| WsHandshakeError::Url("Header value `authorization basic user:pwd` invalid".into()))?;

Some(val)
} else {
None
};

let host_header = if let Some(port) = url.port() { format!("{host}:{port}") } else { host.to_string() };

let sockaddrs = url.socket_addrs(|| None).map_err(WsHandshakeError::ResolutionFailed)?;
Ok(Self {
sockaddrs,
host,
host_header: url.authority().to_string(),
_mode,
path_and_query: path_and_query.to_string(),
})
Ok(Self { sockaddrs, host, host_header, _mode, path_and_query: path_and_query.to_string(), basic_auth })
}
}

Expand All @@ -593,13 +616,23 @@ fn build_tls_config(cert_store: &CertificateStore) -> Result<tokio_rustls::TlsCo

#[cfg(test)]
mod tests {
use http::HeaderValue;

use super::{Mode, Target, Url, WsHandshakeError};

fn assert_ws_target(target: Target, host: &str, host_header: &str, mode: Mode, path_and_query: &str) {
fn assert_ws_target(
target: Target,
host: &str,
host_header: &str,
mode: Mode,
path_and_query: &str,
basic_auth: Option<HeaderValue>,
) {
assert_eq!(&target.host, host);
assert_eq!(&target.host_header, host_header);
assert_eq!(target._mode, mode);
assert_eq!(&target.path_and_query, path_and_query);
assert_eq!(target.basic_auth, basic_auth);
}

fn parse_target(uri: &str) -> Result<Target, WsHandshakeError> {
Expand All @@ -609,14 +642,14 @@ mod tests {
#[test]
fn ws_works_with_port() {
let target = parse_target("ws://127.0.0.1:9933").unwrap();
assert_ws_target(target, "127.0.0.1", "127.0.0.1:9933", Mode::Plain, "/");
assert_ws_target(target, "127.0.0.1", "127.0.0.1:9933", Mode::Plain, "/", None);
}

#[cfg(feature = "tls")]
#[test]
fn wss_works_with_port() {
let target = parse_target("wss://kusama-rpc.polkadot.io:9999").unwrap();
assert_ws_target(target, "kusama-rpc.polkadot.io", "kusama-rpc.polkadot.io:9999", Mode::Tls, "/");
assert_ws_target(target, "kusama-rpc.polkadot.io", "kusama-rpc.polkadot.io:9999", Mode::Tls, "/", None);
}

#[cfg(not(feature = "tls"))]
Expand All @@ -643,31 +676,42 @@ mod tests {
#[test]
fn url_with_path_works() {
let target = parse_target("ws://127.0.0.1/my-special-path").unwrap();
assert_ws_target(target, "127.0.0.1", "127.0.0.1", Mode::Plain, "/my-special-path");
assert_ws_target(target, "127.0.0.1", "127.0.0.1", Mode::Plain, "/my-special-path", None);
}

#[test]
fn url_with_query_works() {
let target = parse_target("ws://127.0.0.1/my?name1=value1&name2=value2").unwrap();
assert_ws_target(target, "127.0.0.1", "127.0.0.1", Mode::Plain, "/my?name1=value1&name2=value2");
assert_ws_target(target, "127.0.0.1", "127.0.0.1", Mode::Plain, "/my?name1=value1&name2=value2", None);
}

#[test]
fn url_with_fragment_is_ignored() {
let target = parse_target("ws://127.0.0.1:/my.htm#ignore").unwrap();
assert_ws_target(target, "127.0.0.1", "127.0.0.1", Mode::Plain, "/my.htm");
assert_ws_target(target, "127.0.0.1", "127.0.0.1", Mode::Plain, "/my.htm", None);
}

#[cfg(feature = "tls")]
#[test]
fn wss_default_port_is_omitted() {
let target = parse_target("wss://127.0.0.1:443").unwrap();
assert_ws_target(target, "127.0.0.1", "127.0.0.1", Mode::Tls, "/");
assert_ws_target(target, "127.0.0.1", "127.0.0.1", Mode::Tls, "/", None);
}

#[test]
fn ws_default_port_is_omitted() {
let target = parse_target("ws://127.0.0.1:80").unwrap();
assert_ws_target(target, "127.0.0.1", "127.0.0.1", Mode::Plain, "/");
assert_ws_target(target, "127.0.0.1", "127.0.0.1", Mode::Plain, "/", None);
}

#[test]
fn ws_with_username_and_password() {
use base64::Engine;

let target = parse_target("ws://user:pwd@127.0.0.1").unwrap();
let digest = base64::engine::general_purpose::STANDARD.encode("user:pwd");
let basic_auth = HeaderValue::from_str(&format!("Basic {digest}")).unwrap();

assert_ws_target(target, "127.0.0.1", "127.0.0.1", Mode::Plain, "/", Some(basic_auth));
}
}

0 comments on commit a3307b8

Please sign in to comment.