Skip to content

Commit

Permalink
fix(proxy): forward notifications from authentication (#9948)
Browse files Browse the repository at this point in the history
Fixes neondatabase/cloud#20973. 

This refactors `connect_raw` in order to return direct access to the
delayed notices.

I cannot find a way to test this with psycopg2 unfortunately, although
testing it with psql does return the expected results.
  • Loading branch information
conradludgate authored and awarus committed Dec 5, 2024
1 parent 9ce0dd4 commit a750c14
Show file tree
Hide file tree
Showing 11 changed files with 117 additions and 58 deletions.
6 changes: 6 additions & 0 deletions libs/pq_proto/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,8 @@ pub enum BeMessage<'a> {
/// Batch of interpreted, shard filtered WAL records,
/// ready for the pageserver to ingest
InterpretedWalRecords(InterpretedWalRecordsBody<'a>),

Raw(u8, &'a [u8]),
}

/// Common shorthands.
Expand Down Expand Up @@ -754,6 +756,10 @@ impl BeMessage<'_> {
/// one more buffer.
pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> {
match message {
BeMessage::Raw(code, data) => {
buf.put_u8(*code);
write_body(buf, |b| b.put_slice(data))
}
BeMessage::AuthenticationOk => {
buf.put_u8(b'R');
write_body(buf, |buf| {
Expand Down
4 changes: 4 additions & 0 deletions libs/proxy/postgres-protocol2/src/message/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,10 @@ impl NoticeResponseBody {
pub fn fields(&self) -> ErrorFields<'_> {
ErrorFields { buf: &self.storage }
}

pub fn as_bytes(&self) -> &[u8] {
&self.storage
}
}

pub struct NotificationResponseBody {
Expand Down
8 changes: 4 additions & 4 deletions libs/proxy/tokio-postgres2/src/cancel_token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ use tokio::net::TcpStream;
/// connection.
#[derive(Clone)]
pub struct CancelToken {
pub(crate) socket_config: Option<SocketConfig>,
pub(crate) ssl_mode: SslMode,
pub(crate) process_id: i32,
pub(crate) secret_key: i32,
pub socket_config: Option<SocketConfig>,
pub ssl_mode: SslMode,
pub process_id: i32,
pub secret_key: i32,
}

impl CancelToken {
Expand Down
13 changes: 5 additions & 8 deletions libs/proxy/tokio-postgres2/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ impl InnerClient {
}

#[derive(Clone)]
pub(crate) struct SocketConfig {
pub struct SocketConfig {
pub host: Host,
pub port: u16,
pub connect_timeout: Option<Duration>,
Expand All @@ -152,7 +152,7 @@ pub(crate) struct SocketConfig {
pub struct Client {
inner: Arc<InnerClient>,

socket_config: Option<SocketConfig>,
socket_config: SocketConfig,
ssl_mode: SslMode,
process_id: i32,
secret_key: i32,
Expand All @@ -161,6 +161,7 @@ pub struct Client {
impl Client {
pub(crate) fn new(
sender: mpsc::UnboundedSender<Request>,
socket_config: SocketConfig,
ssl_mode: SslMode,
process_id: i32,
secret_key: i32,
Expand All @@ -172,7 +173,7 @@ impl Client {
buffer: Default::default(),
}),

socket_config: None,
socket_config,
ssl_mode,
process_id,
secret_key,
Expand All @@ -188,10 +189,6 @@ impl Client {
&self.inner
}

pub(crate) fn set_socket_config(&mut self, socket_config: SocketConfig) {
self.socket_config = Some(socket_config);
}

/// Creates a new prepared statement.
///
/// Prepared statements can be executed repeatedly, and may contain query parameters (indicated by `$1`, `$2`, etc),
Expand Down Expand Up @@ -412,7 +409,7 @@ impl Client {
/// connection associated with this client.
pub fn cancel_token(&self) -> CancelToken {
CancelToken {
socket_config: self.socket_config.clone(),
socket_config: Some(self.socket_config.clone()),
ssl_mode: self.ssl_mode,
process_id: self.process_id,
secret_key: self.secret_key,
Expand Down
6 changes: 2 additions & 4 deletions libs/proxy/tokio-postgres2/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
use crate::connect::connect;
use crate::connect_raw::connect_raw;
use crate::connect_raw::RawConnection;
use crate::tls::MakeTlsConnect;
use crate::tls::TlsConnect;
use crate::{Client, Connection, Error};
Expand Down Expand Up @@ -485,14 +486,11 @@ impl Config {
connect(tls, self).await
}

/// Connects to a PostgreSQL database over an arbitrary stream.
///
/// All of the settings other than `user`, `password`, `dbname`, `options`, and `application_name` name are ignored.
pub async fn connect_raw<S, T>(
&self,
stream: S,
tls: T,
) -> Result<(Client, Connection<S, T::Stream>), Error>
) -> Result<RawConnection<S, T::Stream>, Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
Expand Down
42 changes: 34 additions & 8 deletions libs/proxy/tokio-postgres2/src/connect.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
use crate::client::SocketConfig;
use crate::codec::BackendMessage;
use crate::config::{Host, TargetSessionAttrs};
use crate::connect_raw::connect_raw;
use crate::connect_socket::connect_socket;
use crate::tls::{MakeTlsConnect, TlsConnect};
use crate::{Client, Config, Connection, Error, SimpleQueryMessage};
use crate::{Client, Config, Connection, Error, RawConnection, SimpleQueryMessage};
use futures_util::{future, pin_mut, Future, FutureExt, Stream};
use postgres_protocol2::message::backend::Message;
use std::io;
use std::task::Poll;
use tokio::net::TcpStream;
use tokio::sync::mpsc;

pub async fn connect<T>(
mut tls: T,
Expand Down Expand Up @@ -60,7 +63,36 @@ where
T: TlsConnect<TcpStream>,
{
let socket = connect_socket(host, port, config.connect_timeout).await?;
let (mut client, mut connection) = connect_raw(socket, tls, config).await?;
let RawConnection {
stream,
parameters,
delayed_notice,
process_id,
secret_key,
} = connect_raw(socket, tls, config).await?;

let socket_config = SocketConfig {
host: host.clone(),
port,
connect_timeout: config.connect_timeout,
};

let (sender, receiver) = mpsc::unbounded_channel();
let client = Client::new(
sender,
socket_config,
config.ssl_mode,
process_id,
secret_key,
);

// delayed notices are always sent as "Async" messages.
let delayed = delayed_notice
.into_iter()
.map(|m| BackendMessage::Async(Message::NoticeResponse(m)))
.collect();

let mut connection = Connection::new(stream, delayed, parameters, receiver);

if let TargetSessionAttrs::ReadWrite = config.target_session_attrs {
let rows = client.simple_query_raw("SHOW transaction_read_only");
Expand Down Expand Up @@ -102,11 +134,5 @@ where
}
}

client.set_socket_config(SocketConfig {
host: host.clone(),
port,
connect_timeout: config.connect_timeout,
});

Ok((client, connection))
}
37 changes: 22 additions & 15 deletions libs/proxy/tokio-postgres2/src/connect_raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,26 @@ use crate::config::{self, AuthKeys, Config, ReplicationMode};
use crate::connect_tls::connect_tls;
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::tls::{TlsConnect, TlsStream};
use crate::{Client, Connection, Error};
use crate::Error;
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use futures_util::{ready, Sink, SinkExt, Stream, TryStreamExt};
use postgres_protocol2::authentication;
use postgres_protocol2::authentication::sasl;
use postgres_protocol2::authentication::sasl::ScramSha256;
use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message};
use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message, NoticeResponseBody};
use postgres_protocol2::message::frontend;
use std::collections::{HashMap, VecDeque};
use std::collections::HashMap;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::mpsc;
use tokio_util::codec::Framed;

pub struct StartupStream<S, T> {
inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
buf: BackendMessages,
delayed: VecDeque<BackendMessage>,
delayed_notice: Vec<NoticeResponseBody>,
}

impl<S, T> Sink<FrontendMessage> for StartupStream<S, T>
Expand Down Expand Up @@ -78,11 +77,19 @@ where
}
}

pub struct RawConnection<S, T> {
pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
pub parameters: HashMap<String, String>,
pub delayed_notice: Vec<NoticeResponseBody>,
pub process_id: i32,
pub secret_key: i32,
}

pub async fn connect_raw<S, T>(
stream: S,
tls: T,
config: &Config,
) -> Result<(Client, Connection<S, T::Stream>), Error>
) -> Result<RawConnection<S, T::Stream>, Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
Expand All @@ -97,18 +104,20 @@ where
},
),
buf: BackendMessages::empty(),
delayed: VecDeque::new(),
delayed_notice: Vec::new(),
};

startup(&mut stream, config).await?;
authenticate(&mut stream, config).await?;
let (process_id, secret_key, parameters) = read_info(&mut stream).await?;

let (sender, receiver) = mpsc::unbounded_channel();
let client = Client::new(sender, config.ssl_mode, process_id, secret_key);
let connection = Connection::new(stream.inner, stream.delayed, parameters, receiver);

Ok((client, connection))
Ok(RawConnection {
stream: stream.inner,
parameters,
delayed_notice: stream.delayed_notice,
process_id,
secret_key,
})
}

async fn startup<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
Expand Down Expand Up @@ -347,9 +356,7 @@ where
body.value().map_err(Error::parse)?.to_string(),
);
}
Some(msg @ Message::NoticeResponse(_)) => {
stream.delayed.push_back(BackendMessage::Async(msg))
}
Some(Message::NoticeResponse(body)) => stream.delayed_notice.push(body),
Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
Some(_) => return Err(Error::unexpected_message()),
Expand Down
5 changes: 3 additions & 2 deletions libs/proxy/tokio-postgres2/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
//! An asynchronous, pipelined, PostgreSQL client.
#![warn(rust_2018_idioms, clippy::all, missing_docs)]
#![warn(rust_2018_idioms, clippy::all)]

pub use crate::cancel_token::CancelToken;
pub use crate::client::Client;
pub use crate::client::{Client, SocketConfig};
pub use crate::config::Config;
pub use crate::connect_raw::RawConnection;
pub use crate::connection::Connection;
use crate::error::DbError;
pub use crate::error::Error;
Expand Down
38 changes: 28 additions & 10 deletions proxy/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ use std::time::Duration;
use futures::{FutureExt, TryFutureExt};
use itertools::Itertools;
use once_cell::sync::OnceCell;
use postgres_protocol::message::backend::NoticeResponseBody;
use pq_proto::StartupMessageParams;
use rustls::client::danger::ServerCertVerifier;
use rustls::crypto::ring;
use rustls::pki_types::InvalidDnsNameError;
use thiserror::Error;
use tokio::net::TcpStream;
use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres::{CancelToken, RawConnection};
use tracing::{debug, error, info, warn};

use crate::auth::parse_endpoint_param;
Expand Down Expand Up @@ -277,6 +279,8 @@ pub(crate) struct PostgresConnection {
pub(crate) cancel_closure: CancelClosure,
/// Labels for proxy's metrics.
pub(crate) aux: MetricsAuxInfo,
/// Notices received from compute after authenticating
pub(crate) delayed_notice: Vec<NoticeResponseBody>,

_guage: NumDbConnectionsGuard<'static>,
}
Expand Down Expand Up @@ -322,10 +326,19 @@ impl ConnCfg {

// connect_raw() will not use TLS if sslmode is "disable"
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (client, connection) = self.0.connect_raw(stream, tls).await?;
let connection = self.0.connect_raw(stream, tls).await?;
drop(pause);
tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));
let stream = connection.stream.into_inner();

let RawConnection {
stream,
parameters,
delayed_notice,
process_id,
secret_key,
} = connection;

tracing::Span::current().record("pid", tracing::field::display(process_id));
let stream = stream.into_inner();

// TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
info!(
Expand All @@ -334,18 +347,23 @@ impl ConnCfg {
self.0.get_ssl_mode()
);

// This is very ugly but as of now there's no better way to
// extract the connection parameters from tokio-postgres' connection.
// TODO: solve this problem in a more elegant manner (e.g. the new library).
let params = connection.parameters;

// NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
// Yet another reason to rework the connection establishing code.
let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token(), vec![]);
let cancel_closure = CancelClosure::new(
socket_addr,
CancelToken {
socket_config: None,
ssl_mode: self.0.get_ssl_mode(),
process_id,
secret_key,
},
vec![],
);

let connection = PostgresConnection {
stream,
params,
params: parameters,
delayed_notice,
cancel_closure,
aux,
_guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
Expand Down
8 changes: 5 additions & 3 deletions proxy/src/proxy/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -384,11 +384,13 @@ pub(crate) async fn prepare_client_connection<P>(
// The new token (cancel_key_data) will be sent to the client.
let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());

// Forward all deferred notices to the client.
for notice in &node.delayed_notice {
stream.write_message_noflush(&Be::Raw(b'N', notice.as_bytes()))?;
}

// Forward all postgres connection params to the client.
// Right now the implementation is very hacky and inefficent (ideally,
// we don't need an intermediate hashmap), but at least it should be correct.
for (name, value) in &node.params {
// TODO: Theoretically, this could result in a big pile of params...
stream.write_message_noflush(&Be::ParameterStatus {
name: name.as_bytes(),
value: value.as_bytes(),
Expand Down
Loading

0 comments on commit a750c14

Please sign in to comment.