Skip to content

Commit

Permalink
Initial import of PR seanmonstar#1623
Browse files Browse the repository at this point in the history
  • Loading branch information
get9 committed May 15, 2023
1 parent 7047669 commit 560302d
Show file tree
Hide file tree
Showing 8 changed files with 221 additions and 1 deletion.
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,11 @@ features = [
wasm-bindgen = { version = "0.2.68", features = ["serde-serialize"] }
wasm-bindgen-test = "0.3"

# to test unix domain sockets
[target.'cfg(unix)'.dev-dependencies]
hyperlocal = "0.8"
tempfile = "3.3"

[[example]]
name = "blocking"
path = "examples/blocking.rs"
Expand Down
32 changes: 32 additions & 0 deletions src/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
#[cfg(unix)]
use std::path::Path;

#[cfg(feature = "default-tls")]
use self::native_tls_conn::NativeTlsConn;
Expand Down Expand Up @@ -163,6 +165,15 @@ impl Connector {
self.verbose.0 = enabled;
}

#[cfg(unix)]
async fn connect_unix_socket<P: AsRef<Path>>(&self, socket: P) -> Result<Conn, BoxError> {
let tcp_stream = unix_socket_conn::connect(socket).await?;
Ok(Conn {
inner: self.verbose.wrap(tcp_stream),
is_proxy: false, // defaults to false to have the same behavior as curl's --unix-socket
})
}

#[cfg(feature = "socks")]
async fn connect_socks(&self, dst: Uri, proxy: ProxyScheme) -> Result<Conn, BoxError> {
let dns = match proxy {
Expand All @@ -175,6 +186,10 @@ impl Connector {
ProxyScheme::Http { .. } | ProxyScheme::Https { .. } => {
unreachable!("connect_socks is only called for socks proxies");
}
#[cfg(unix)]
ProxyScheme::UnixSocket { .. } => {
unreachable!("connect_socks is only called for socks proxies");
}
};

match &self.inner {
Expand Down Expand Up @@ -306,6 +321,8 @@ impl Connector {
ProxyScheme::Https { host, auth } => (into_uri(Scheme::HTTPS, host), auth),
#[cfg(feature = "socks")]
ProxyScheme::Socks5 { .. } => return self.connect_socks(dst, proxy_scheme).await,
#[cfg(unix)]
ProxyScheme::UnixSocket { socket } => return self.connect_unix_socket(socket).await,
};

#[cfg(feature = "__tls")]
Expand Down Expand Up @@ -829,6 +846,21 @@ mod socks {
}
}

#[cfg(unix)]
mod unix_socket_conn {
use std::os::unix::io::OwnedFd;
use std::path::Path;
use tokio::net::{TcpStream, UnixStream};
use crate::error::BoxError;

pub async fn connect<P: AsRef<Path>>(socket: P) -> Result<TcpStream, BoxError> {
let target_stream = UnixStream::connect(&socket).await?;
let owned_fd: OwnedFd = target_stream.into_std()?.into();
let stream = std::net::TcpStream::from(owned_fd);
Ok(TcpStream::from_std(stream)?)
}
}

mod verbose {
use hyper::client::connect::{Connected, Connection};
use std::cmp::min;
Expand Down
2 changes: 1 addition & 1 deletion src/into_url.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub trait IntoUrlSealed {

impl IntoUrlSealed for Url {
fn into_url(self) -> crate::Result<Url> {
if self.has_host() {
if self.scheme() == "unix" || self.has_host() {
Ok(self)
} else {
Err(crate::error::url_bad_scheme(self))
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@
//! export https_proxy=socks5://127.0.0.1:1086
//! ```
//!
//! You can aso configure a proxy to send requests through unix domain sockets (see [Proxy](Proxy) for details).
//!
//! ## TLS
//!
//! By default, a `Client` will make use of system-native transport layer
Expand Down
62 changes: 62 additions & 0 deletions src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ use std::collections::HashMap;
use std::env;
use std::error::Error;
use std::net::IpAddr;
#[cfg(unix)]
use std::path::PathBuf;
#[cfg(target_os = "windows")]
use winreg::enums::HKEY_CURRENT_USER;
#[cfg(target_os = "windows")]
Expand Down Expand Up @@ -50,6 +52,16 @@ use winreg::RegKey;
/// # Ok(())
/// # }
/// ```
///
/// On unix, it is also possible to send request to a unix socket via url or [Proxy::unix]:
/// ```rust
/// # fn run() -> Result<(), Box<std::error::Error>> {
/// let proxy = reqwest::Proxy::all("unix:///run/snapd.socket")?;
/// // equivalent to:
/// let proxy = reqwest::Proxy::unix("/run/snapd.socket");
/// # Ok(())
/// # }
/// ```
#[derive(Clone)]
pub struct Proxy {
intercept: Intercept,
Expand Down Expand Up @@ -99,6 +111,10 @@ pub enum ProxyScheme {
auth: Option<(String, String)>,
remote_dns: bool,
},
#[cfg(unix)]
UnixSocket {
socket: PathBuf,
},
}

impl ProxyScheme {
Expand All @@ -107,6 +123,8 @@ impl ProxyScheme {
ProxyScheme::Http { auth, .. } | ProxyScheme::Https { auth, .. } => auth.as_ref(),
#[cfg(feature = "socks")]
_ => None,
#[cfg(unix)]
ProxyScheme::UnixSocket { .. } => None,
}
}
}
Expand Down Expand Up @@ -235,6 +253,26 @@ impl Proxy {
)))
}

/// Proxy **all** traffic to the passed unix domain socket.
///
/// # Example
///
/// ```
/// # extern crate reqwest;
/// # fn run() -> Result<(), Box<std::error::Error>> {
/// let client = reqwest::Client::builder()
/// .proxy(reqwest::Proxy::unix("/run/snapd.socket"))
/// .build()?;
/// # Ok(())
/// # }
/// # fn main() {}
/// ```
pub fn unix<Path: Into<PathBuf>>(socket_path: Path) -> Proxy {
Proxy::new(Intercept::All(
ProxyScheme::unix_socket(socket_path),
))
}

/// Provide a custom function to determine what traffic to proxy to where.
///
/// # Example
Expand Down Expand Up @@ -577,6 +615,14 @@ impl ProxyScheme {
})
}

/// Proxy traffic via the specified URL over HTTPS
#[cfg(unix)]
fn unix_socket<Path: Into<PathBuf>>(path: Path) -> Self {
ProxyScheme::UnixSocket {
socket: path.into(),
}
}

/// Use a username and password when connecting to the proxy server
fn with_basic_auth<T: Into<String>, U: Into<String>>(
mut self,
Expand All @@ -601,6 +647,8 @@ impl ProxyScheme {
ProxyScheme::Socks5 { ref mut auth, .. } => {
*auth = Some((username.into(), password.into()));
}
#[cfg(unix)]
ProxyScheme::UnixSocket { .. } => (),
}
}

Expand All @@ -618,6 +666,8 @@ impl ProxyScheme {
}
#[cfg(feature = "socks")]
ProxyScheme::Socks5 { .. } => {}
#[cfg(unix)]
ProxyScheme::UnixSocket { .. } => {}
}

self
Expand Down Expand Up @@ -652,6 +702,8 @@ impl ProxyScheme {
"socks5" => Self::socks5(to_addr()?)?,
#[cfg(feature = "socks")]
"socks5h" => Self::socks5h(to_addr()?)?,
#[cfg(unix)]
"unix" => Self::unix_socket(url.path()),
_ => return Err(crate::error::builder("unknown proxy scheme")),
};

Expand All @@ -671,6 +723,8 @@ impl ProxyScheme {
ProxyScheme::Https { .. } => "https",
#[cfg(feature = "socks")]
ProxyScheme::Socks5 { .. } => "socks5",
#[cfg(unix)]
ProxyScheme::UnixSocket { .. } => "unix",
}
}

Expand All @@ -681,6 +735,8 @@ impl ProxyScheme {
ProxyScheme::Https { host, .. } => host.as_str(),
#[cfg(feature = "socks")]
ProxyScheme::Socks5 { .. } => panic!("socks5"),
#[cfg(unix)]
ProxyScheme::UnixSocket { .. } => panic!("unix"),
}
}
}
Expand All @@ -699,6 +755,10 @@ impl fmt::Debug for ProxyScheme {
let h = if *remote_dns { "h" } else { "" };
write!(f, "socks5{}://{}", h, addr)
}
#[cfg(unix)]
ProxyScheme::UnixSocket { socket } => {
write!(f, "unix://{}", socket.display())
}
}
}
}
Expand Down Expand Up @@ -991,6 +1051,8 @@ mod tests {
let (scheme, host) = match p.intercept(&url(s)).unwrap() {
ProxyScheme::Http { host, .. } => ("http", host),
ProxyScheme::Https { host, .. } => ("https", host),
#[cfg(unix)]
ProxyScheme::UnixSocket { .. } => panic!("intercepted as unix"),
#[cfg(feature = "socks")]
_ => panic!("intercepted as socks"),
};
Expand Down
32 changes: 32 additions & 0 deletions tests/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ mod support;
use support::*;

use std::env;
#[cfg(unix)]
use std::path::PathBuf;

#[tokio::test]
async fn http_proxy() {
Expand Down Expand Up @@ -221,3 +223,33 @@ async fn http_over_http() {
assert_eq!(res.url().as_str(), url);
assert_eq!(res.status(), reqwest::StatusCode::OK);
}

#[cfg(unix)]
#[tokio::test]
async fn http_over_unix() {
let proxy = tempfile::TempPath::from_path("/tmp/reqwest.socket");
let url = "http://hyper.rs/prox";

let _server = unix_server::http(PathBuf::from(&proxy), move |req| {
assert_eq!(req.method(), "GET");
assert_eq!(req.uri(), "/prox");
assert_eq!(req.headers()["host"], "hyper.rs");

async { http::Response::default() }
});


let proxy_url = format!("unix://{}", proxy.display());
let res = reqwest::Client::builder()
.proxy(reqwest::Proxy::http(&proxy_url).unwrap())
.connection_verbose(true)
.build()
.unwrap()
.get(url)
.send()
.await
.unwrap();

assert_eq!(res.url().as_str(), url);
assert_eq!(res.status(), reqwest::StatusCode::OK);
}
2 changes: 2 additions & 0 deletions tests/support/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
pub mod server;
#[cfg(unix)]
pub mod unix_server;

// TODO: remove once done converting to new support server?
#[allow(unused)]
Expand Down
85 changes: 85 additions & 0 deletions tests/support/unix_server.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#![cfg(not(target_arch = "wasm32"))]
use std::convert::Infallible;
use std::future::Future;
use std::path::PathBuf;
use std::sync::mpsc as std_mpsc;
use std::thread;
use std::time::Duration;

use tokio::sync::oneshot;

pub use http::Response;
use hyperlocal::UnixServerExt;
use tokio::runtime;

pub struct Server {
panic_rx: std_mpsc::Receiver<()>,
shutdown_tx: Option<oneshot::Sender<()>>,
}

impl Drop for Server {
fn drop(&mut self) {
if let Some(tx) = self.shutdown_tx.take() {
let _ = tx.send(());
}

if !::std::thread::panicking() {
self.panic_rx
.recv_timeout(Duration::from_secs(3))
.expect("test server should not panic");
}
}
}

#[allow(unused)]
pub fn http<F, Fut>(socket: PathBuf, func: F) -> Server
where
F: Fn(http::Request<hyper::Body>) -> Fut + Clone + Send + 'static,
Fut: Future<Output = http::Response<hyper::Body>> + Send + 'static,
{
//Spawn new runtime in thread to prevent reactor execution context conflict
thread::spawn(move || {
let rt = runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("new rt");
let srv = rt.block_on(async move {
hyper::Server::bind_unix(socket).unwrap().serve(hyper::service::make_service_fn(
move |_| {
let func = func.clone();
async move {
Ok::<_, Infallible>(hyper::service::service_fn(move |req| {
let fut = func(req);
async move { Ok::<_, Infallible>(fut.await) }
}))
}
},
))
});

let (shutdown_tx, shutdown_rx) = oneshot::channel();
let srv = srv.with_graceful_shutdown(async move {
let _ = shutdown_rx.await;
});

let (panic_tx, panic_rx) = std_mpsc::channel();
let tname = format!(
"test({})-support-server",
thread::current().name().unwrap_or("<unknown>")
);
thread::Builder::new()
.name(tname)
.spawn(move || {
rt.block_on(srv).unwrap();
let _ = panic_tx.send(());
})
.expect("thread spawn");

Server {
panic_rx,
shutdown_tx: Some(shutdown_tx),
}
})
.join()
.unwrap()
}

0 comments on commit 560302d

Please sign in to comment.