Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf(connector): Optimize performance of switching TLS connector #406

Merged
merged 5 commits into from
Feb 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/client/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,11 +255,11 @@ impl ClientBuilder {
http.set_connect_timeout(config.connect_timeout);

let tls = BoringTlsConnector::new(config.tls_config)?;
let mut builder = ConnectorBuilder::new(http, tls, config.nodelay, config.tls_info);
builder.set_timeout(config.connect_timeout);
builder.set_verbose(config.connection_verbose);
builder.set_keepalive(config.tcp_keepalive);
builder.build(config.connector_layers)
ConnectorBuilder::new(http, tls, config.nodelay, config.tls_info)
.timeout(config.connect_timeout)
.keepalive(config.tcp_keepalive)
.verbose(config.connection_verbose)
.build(config.connector_layers)
};

Ok(Client {
Expand Down
97 changes: 53 additions & 44 deletions src/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use crate::util::client::connect::{Connected, Connection};
use crate::util::client::Dst;
use crate::util::rt::TokioIo;
use crate::util::{self, into_uri};
use antidote::RwLock;
use http::uri::Scheme;
use hyper2::rt::{Read, ReadBufCursor, Write};
use pin_project_lite::pin_project;
Expand All @@ -17,7 +16,6 @@ use tower_service::Service;
use std::future::Future;
use std::io::{self, IoSlice};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;

Expand Down Expand Up @@ -47,35 +45,37 @@ impl ConnectorBuilder {
// we have no user-provided layers, only use concrete types
let base_service = ConnectorService {
http: self.http,
tls: InnerTLS::Simple(self.tls),
tls: self.tls,
verbose: self.verbose,
nodelay: self.nodelay,
tls_info: self.tls_info,
timeout: self.timeout,
keepalive: None,
};
return Connector::Simple(base_service);
}

let inner_tls = InnerTLS::WithSharedState(Arc::new(RwLock::new(self.tls)));
let base_service = ConnectorService {
http: self.http,
tls: inner_tls.clone(),
tls: self.tls,
verbose: self.verbose,
nodelay: self.nodelay,
tls_info: self.tls_info,
timeout: None,
keepalive: None,
};

// otherwise we have user provided layers
// so we need type erasure all the way through
// as well as mapping the unnameable type of the layers back to Dst for the inner service
let unnameable_service = ServiceBuilder::new()
.layer(MapRequestLayer::new(|request: Unnameable| request.0))
.service(base_service);
let mut service = BoxCloneSyncService::new(unnameable_service);
for layer in layers {
service = ServiceBuilder::new().layer(layer).service(service);
}
let service = layers.iter().fold(
BoxCloneSyncService::new(
ServiceBuilder::new()
.layer(MapRequestLayer::new(|request: Unnameable| request.0))
.service(base_service.clone()),
),
|service, layer| ServiceBuilder::new().layer(layer).service(service),
);

// now we handle the concrete stuff - any `connect_timeout`,
// plus a final map_err layer we can use to cast default tower layer
Expand All @@ -89,7 +89,11 @@ impl ConnectorBuilder {
.map_err(cast_to_internal_error)
.service(service);
let service = BoxCloneSyncService::new(service);
Connector::WithLayers { inner_tls, service }
Connector::WithLayers {
layers,
base_service,
service,
}
}
None => {
// no timeout, but still map err
Expand All @@ -100,7 +104,11 @@ impl ConnectorBuilder {
.map_err(cast_to_internal_error)
.service(service);
let service = BoxCloneSyncService::new(service);
Connector::WithLayers { inner_tls, service }
Connector::WithLayers {
layers,
base_service,
service,
}
}
}
}
Expand All @@ -123,18 +131,21 @@ impl ConnectorBuilder {
}

#[inline]
pub(crate) fn set_keepalive(&mut self, dur: Option<Duration>) {
pub(crate) fn keepalive(mut self, dur: Option<Duration>) -> ConnectorBuilder {
self.http.set_keepalive(dur);
self
}

#[inline]
pub(crate) fn set_timeout(&mut self, timeout: Option<Duration>) {
pub(crate) fn timeout(mut self, timeout: Option<Duration>) -> ConnectorBuilder {
self.timeout = timeout;
self
}

#[inline]
pub(crate) fn set_verbose(&mut self, enabled: bool) {
pub(crate) fn verbose(mut self, enabled: bool) -> ConnectorBuilder {
self.verbose.0 = enabled;
self
}
}

Expand All @@ -145,22 +156,35 @@ pub(crate) enum Connector {
// at least one custom layer along with maybe an outer timeout layer
// from `builder.connect_timeout()`
WithLayers {
inner_tls: InnerTLS,
layers: Vec<BoxedConnectorLayer>,
service: BoxedConnectorService,
base_service: ConnectorService,
},
}

impl Connector {
#[inline]
pub(crate) fn set_connector(&mut self, mut connector: BoringTlsConnector) {
match self {
Connector::Simple(service) => {
std::mem::swap(&mut service.tls, &mut InnerTLS::Simple(connector));
std::mem::swap(&mut service.tls, &mut connector);
}
Connector::WithLayers { inner_tls, .. } => {
if let InnerTLS::WithSharedState(tls) = inner_tls {
std::mem::swap(&mut *tls.write(), &mut connector);
}
Connector::WithLayers {
layers,
base_service,
..
} => {
let mut connector = ConnectorBuilder::new(
base_service.http.clone(),
connector,
base_service.nodelay,
base_service.tls_info,
)
.timeout(base_service.timeout)
.keepalive(base_service.keepalive)
.verbose(base_service.verbose.0)
.build(std::mem::take(layers));

std::mem::swap(self, &mut connector);
}
}
}
Expand All @@ -186,32 +210,17 @@ impl Service<Dst> for Connector {
}
}

#[derive(Clone)]
pub(crate) enum InnerTLS {
Simple(BoringTlsConnector),
WithSharedState(Arc<RwLock<BoringTlsConnector>>),
}

impl InnerTLS {
#[inline(always)]
fn get_tls(&self) -> BoringTlsConnector {
match self {
InnerTLS::Simple(tls) => tls.clone(),
InnerTLS::WithSharedState(tls) => tls.read().clone(),
}
}
}

#[derive(Clone)]
pub(crate) struct ConnectorService {
http: HttpConnector,
tls: InnerTLS,
tls: BoringTlsConnector,
verbose: verbose::Wrapper,
/// When there is a single timeout layer and no other layers,
/// we embed it directly inside our base Service::call().
/// This lets us avoid an extra `Box::pin` indirection layer
/// since `tokio::time::Timeout` is `Unpin`
timeout: Option<Duration>,
keepalive: Option<Duration>,
nodelay: bool,
tls_info: bool,
}
Expand All @@ -237,7 +246,7 @@ impl ConnectorService {
.alpn_protos(dst.alpn_protos())
.interface(dst.take_interface())
.addresses(dst.take_addresses())
.build(self.tls.get_tls());
.build(self.tls.clone());

log::trace!("socks HTTPS over proxy");
let host = dst.host().ok_or(crate::error::uri_bad_host())?;
Expand Down Expand Up @@ -280,7 +289,7 @@ impl ConnectorService {
.alpn_protos(dst.alpn_protos())
.interface(dst.take_interface())
.addresses(dst.take_addresses())
.build(self.tls.get_tls());
.build(self.tls);
let io = http.call(dst.into()).await?;

if let MaybeHttpsStream::Https(stream) = io {
Expand Down Expand Up @@ -327,7 +336,7 @@ impl ConnectorService {
.alpn_protos(dst.alpn_protos())
.interface(dst.take_interface())
.addresses(dst.take_addresses())
.build(self.tls.get_tls());
.build(self.tls);

let host = dst.host().ok_or(crate::error::uri_bad_host())?;
let port = dst.port_u16().unwrap_or(443);
Expand Down
Loading