Skip to content

Commit

Permalink
Revert "use chainrw"
Browse files Browse the repository at this point in the history
This reverts commit 2dd9d1a.
  • Loading branch information
conradludgate committed Jul 9, 2024
1 parent fdceee4 commit aa023f8
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 29 deletions.
8 changes: 1 addition & 7 deletions proxy/src/protocol2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,7 @@ pin_project! {
pub struct ChainRW<T> {
#[pin]
pub inner: T,
pub buf: BytesMut,
}
}

impl<T> ChainRW<T> {
pub fn with_buf(inner: T, buf: BytesMut) -> Self {
Self { inner, buf }
buf: BytesMut,
}
}

Expand Down
4 changes: 1 addition & 3 deletions proxy/src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ pub mod wake_compute;
pub use copy_bidirectional::copy_bidirectional_client_compute;
pub use copy_bidirectional::ErrorSource;

use crate::protocol2::ChainRW;
use crate::{
auth,
cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal},
Expand Down Expand Up @@ -247,8 +246,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
conn_gauge: NumClientConnectionsGuard<'static>,
) -> Result<Option<ProxyPassthrough<CancellationHandlerMainInternal, ChainRW<S>>>, ClientRequestError>
{
) -> Result<Option<ProxyPassthrough<CancellationHandlerMainInternal, S>>, ClientRequestError> {
info!(
protocol = %ctx.protocol,
"handling interactive connection from client"
Expand Down
47 changes: 28 additions & 19 deletions proxy/src/proxy/handshake.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use bytes::BytesMut;
use bytes::Buf;
use pq_proto::{
framed::Framed, BeMessage as Be, CancelKeyData, FeStartupPacket, ProtocolVersion,
StartupMessageParams,
Expand All @@ -12,7 +12,6 @@ use crate::{
config::{TlsConfig, PG_ALPN_PROTOCOL},
error::ReportableError,
metrics::Metrics,
protocol2::ChainRW,
proxy::ERR_INSECURE_CONNECTION,
stream::{PqStream, Stream, StreamUpgradeError},
};
Expand Down Expand Up @@ -71,17 +70,14 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
mut tls: Option<&TlsConfig>,
record_handshake_error: bool,
) -> Result<HandshakeData<ChainRW<S>>, HandshakeError> {
) -> Result<HandshakeData<S>, HandshakeError> {
// Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false);

const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0);
const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0);

let mut stream = PqStream::new(Stream::from_raw(ChainRW::with_buf(
stream,
BytesMut::default(),
)));
let mut stream = PqStream::new(Stream::from_raw(stream));
loop {
let msg = stream.read_startup_packet().await?;
info!("received {msg:?}");
Expand Down Expand Up @@ -112,29 +108,42 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
write_buf,
} = stream.framed;

let Stream::Raw { mut raw } = raw else {
let Stream::Raw { raw } = raw else {
return Err(HandshakeError::StreamUpgradeError(
StreamUpgradeError::AlreadyTls,
));
};

// read_buf might contain the TLS ClientHello, so make sure we include it.
let empty_buf = std::mem::replace(&mut raw.buf, read_buf);
let mut read_buf = read_buf.reader();
let mut res = Ok(());
let accept = tokio_rustls::TlsAcceptor::from(tls.to_server_config())
.accept_with(raw, |session| {
// push the early data to the tls session
while !read_buf.get_ref().is_empty() {
match session.read_tls(&mut read_buf) {
Ok(_) => {}
Err(e) => {
res = Err(e);
break;
}
}
}
});

res?;

let read_buf = read_buf.into_inner();
if !read_buf.is_empty() {
return Err(HandshakeError::EarlyData);
}

let acceptor = tokio_rustls::TlsAcceptor::from(tls.to_server_config());
let mut tls_stream = acceptor.accept(raw).await.inspect_err(|_| {
let tls_stream = accept.await.inspect_err(|_| {
if record_handshake_error {
Metrics::get().proxy.tls_handshake_failures.inc()
}
})?;

let (io, conn_info) = tls_stream.get_mut();

// The read_buf should not contain any more application data sent before the TLS handshake.
let read_buf = std::mem::replace(&mut io.buf, empty_buf);
if !read_buf.is_empty() {
return Err(HandshakeError::EarlyData);
}
let conn_info = tls_stream.get_ref().1;

// check the ALPN, if exists, as required.
match conn_info.alpn_protocol() {
Expand Down

0 comments on commit aa023f8

Please sign in to comment.