Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

proxy: pg17 fixes #8321

Merged
merged 6 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions libs/postgres_backend/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -672,11 +672,17 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
assert!(self.state < ProtoState::Authentication);
let have_tls = self.tls_config.is_some();
match msg {
FeStartupPacket::SslRequest => {
FeStartupPacket::SslRequest { direct } => {
debug!("SSL requested");

self.write_message(&BeMessage::EncryptionResponse(have_tls))
.await?;
if !direct {
self.write_message(&BeMessage::EncryptionResponse(have_tls))
.await?;
} else if !have_tls {
return Err(QueryError::Other(anyhow::anyhow!(
"direct SSL negotiation but no TLS support"
)));
}

if have_tls {
self.start_tls().await?;
Expand Down
6 changes: 3 additions & 3 deletions libs/pq_proto/src/framed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ impl ConnectionError {
/// Wraps async io `stream`, providing messages to write/flush + read Postgres
/// messages.
pub struct Framed<S> {
stream: S,
read_buf: BytesMut,
write_buf: BytesMut,
pub stream: S,
pub read_buf: BytesMut,
pub write_buf: BytesMut,
}

impl<S> Framed<S> {
Expand Down
91 changes: 71 additions & 20 deletions libs/pq_proto/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,39 @@ pub enum FeMessage {
PasswordMessage(Bytes),
}

#[derive(Clone, Copy, PartialEq, PartialOrd)]
pub struct ProtocolVersion(u32);

impl ProtocolVersion {
pub const fn new(major: u16, minor: u16) -> Self {
Self((major as u32) << 16 | minor as u32)
}
pub const fn minor(self) -> u16 {
self.0 as u16
}
pub const fn major(self) -> u16 {
(self.0 >> 16) as u16
}
}

impl fmt::Debug for ProtocolVersion {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_list()
.entry(&self.major())
.entry(&self.minor())
.finish()
}
}

#[derive(Debug)]
pub enum FeStartupPacket {
CancelRequest(CancelKeyData),
SslRequest,
SslRequest {
direct: bool,
},
GssEncRequest,
StartupMessage {
major_version: u32,
minor_version: u32,
version: ProtocolVersion,
params: StartupMessageParams,
},
}
Expand Down Expand Up @@ -301,11 +326,23 @@ impl FeStartupPacket {
/// different from [`FeMessage::parse`] because startup messages don't have
/// message type byte; otherwise, its comments apply.
pub fn parse(buf: &mut BytesMut) -> Result<Option<FeStartupPacket>, ProtocolError> {
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L118>
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
const CANCEL_REQUEST_CODE: u32 = 5678;
const NEGOTIATE_SSL_CODE: u32 = 5679;
const NEGOTIATE_GSS_CODE: u32 = 5680;
const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234;
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L132>
const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678);
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L166>
const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679);
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L167>
const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680);

// <https://github.com/postgres/postgres/blob/04bcf9e19a4261fe9c7df37c777592c2e10c32a7/src/backend/tcop/backend_startup.c#L378-L382>
// First byte indicates standard SSL handshake message
// (It can't be a Postgres startup length because in network byte order
// that would be a startup packet hundreds of megabytes long)
if buf.first() == Some(&0x16) {
return Ok(Some(FeStartupPacket::SslRequest { direct: true }));
}

// need at least 4 bytes with packet len
if buf.len() < 4 {
Expand Down Expand Up @@ -338,12 +375,10 @@ impl FeStartupPacket {
let mut msg = buf.split_to(len).freeze();
msg.advance(4); // consume len

let request_code = msg.get_u32();
let req_hi = request_code >> 16;
let req_lo = request_code & ((1 << 16) - 1);
let request_code = ProtocolVersion(msg.get_u32());
// StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code.
let message = match (req_hi, req_lo) {
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
let message = match request_code {
CANCEL_REQUEST_CODE => {
if msg.remaining() != 8 {
return Err(ProtocolError::BadMessage(
"CancelRequest message is malformed, backend PID / secret key missing"
Expand All @@ -355,21 +390,22 @@ impl FeStartupPacket {
cancel_key: msg.get_i32(),
})
}
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
NEGOTIATE_SSL_CODE => {
// Requested upgrade to SSL (aka TLS)
FeStartupPacket::SslRequest
FeStartupPacket::SslRequest { direct: false }
}
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
NEGOTIATE_GSS_CODE => {
// Requested upgrade to GSSAPI
FeStartupPacket::GssEncRequest
}
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
version if version.major() == RESERVED_INVALID_MAJOR_VERSION => {
return Err(ProtocolError::Protocol(format!(
"Unrecognized request code {unrecognized_code}"
"Unrecognized request code {}",
version.minor()
)));
}
// TODO bail if protocol major_version is not 3?
(major_version, minor_version) => {
version => {
// StartupMessage

let s = str::from_utf8(&msg).map_err(|_e| {
Expand All @@ -382,8 +418,7 @@ impl FeStartupPacket {
})?;

FeStartupPacket::StartupMessage {
major_version,
minor_version,
version,
params: StartupMessageParams {
params: msg.slice_ref(s.as_bytes()),
},
Expand Down Expand Up @@ -522,6 +557,10 @@ pub enum BeMessage<'a> {
RowDescription(&'a [RowDescriptor<'a>]),
XLogData(XLogDataBody<'a>),
NoticeResponse(&'a str),
NegotiateProtocolVersion {
version: ProtocolVersion,
options: &'a [&'a str],
},
KeepAlive(WalSndKeepAlive),
}

Expand Down Expand Up @@ -945,6 +984,18 @@ impl<'a> BeMessage<'a> {
buf.put_u8(u8::from(req.request_reply));
});
}

BeMessage::NegotiateProtocolVersion { version, options } => {
buf.put_u8(b'v');
write_body(buf, |buf| {
buf.put_u32(version.0);
buf.put_u32(options.len() as u32);
for option in options.iter() {
write_cstr(option, buf)?;
}
Ok(())
})?
}
}
Ok(())
}
Expand Down
3 changes: 2 additions & 1 deletion proxy/src/bin/pg_sni_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,11 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
use pq_proto::FeStartupPacket::*;

match msg {
SslRequest => {
SslRequest { direct: false } => {
stream
.write_message(&pq_proto::BeMessage::EncryptionResponse(true))
.await?;

// Upgrade raw stream into a secure TLS-backed stream.
// NOTE: We've consumed `tls`; this fact will be used later.

Expand Down
12 changes: 8 additions & 4 deletions proxy/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ impl TlsConfig {
}
}

/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L159>
pub const PG_ALPN_PROTOCOL: &[u8] = b"postgresql";

/// Configure TLS for the main endpoint.
pub fn configure_tls(
key_path: &str,
Expand Down Expand Up @@ -111,16 +114,17 @@ pub fn configure_tls(
let cert_resolver = Arc::new(cert_resolver);

// allow TLS 1.2 to be compatible with older client libraries
let config = rustls::ServerConfig::builder_with_protocol_versions(&[
let mut config = rustls::ServerConfig::builder_with_protocol_versions(&[
&rustls::version::TLS13,
&rustls::version::TLS12,
])
.with_no_client_auth()
.with_cert_resolver(cert_resolver.clone())
.into();
.with_cert_resolver(cert_resolver.clone());

config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()];

Ok(TlsConfig {
config,
config: Arc::new(config),
common_names,
cert_resolver,
})
Expand Down
Loading
Loading