Skip to content

Commit

Permalink
chore(proxy): enforce single host+port (#9995)
Browse files Browse the repository at this point in the history
proxy doesn't ever provide multiple hosts/ports, so this code adds a lot
of complexity of error handling for no good reason.

(stacked on #9990)
  • Loading branch information
conradludgate authored and awarus committed Dec 5, 2024
1 parent 5519e42 commit fbc8c36
Show file tree
Hide file tree
Showing 11 changed files with 58 additions and 141 deletions.
41 changes: 9 additions & 32 deletions libs/proxy/tokio-postgres2/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,41 +146,36 @@ pub enum AuthKeys {
/// ```
#[derive(Clone, PartialEq, Eq)]
pub struct Config {
pub(crate) host: Host,
pub(crate) port: u16,

pub(crate) user: Option<String>,
pub(crate) password: Option<Vec<u8>>,
pub(crate) auth_keys: Option<Box<AuthKeys>>,
pub(crate) dbname: Option<String>,
pub(crate) options: Option<String>,
pub(crate) application_name: Option<String>,
pub(crate) ssl_mode: SslMode,
pub(crate) host: Vec<Host>,
pub(crate) port: Vec<u16>,
pub(crate) connect_timeout: Option<Duration>,
pub(crate) target_session_attrs: TargetSessionAttrs,
pub(crate) channel_binding: ChannelBinding,
pub(crate) replication_mode: Option<ReplicationMode>,
pub(crate) max_backend_message_size: Option<usize>,
}

impl Default for Config {
fn default() -> Config {
Config::new()
}
}

impl Config {
/// Creates a new configuration.
pub fn new() -> Config {
pub fn new(host: String, port: u16) -> Config {
Config {
host: Host::Tcp(host),
port,
user: None,
password: None,
auth_keys: None,
dbname: None,
options: None,
application_name: None,
ssl_mode: SslMode::Prefer,
host: vec![],
port: vec![],
connect_timeout: None,
target_session_attrs: TargetSessionAttrs::Any,
channel_binding: ChannelBinding::Prefer,
Expand Down Expand Up @@ -283,32 +278,14 @@ impl Config {
self.ssl_mode
}

/// Adds a host to the configuration.
///
/// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order.
pub fn host(&mut self, host: &str) -> &mut Config {
self.host.push(Host::Tcp(host.to_string()));
self
}

/// Gets the hosts that have been added to the configuration with `host`.
pub fn get_hosts(&self) -> &[Host] {
pub fn get_host(&self) -> &Host {
&self.host
}

/// Adds a port to the configuration.
///
/// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which
/// case the default of 5432 is used, a single port, in which it is used for all hosts, or the same number of ports
/// as hosts.
pub fn port(&mut self, port: u16) -> &mut Config {
self.port.push(port);
self
}

/// Gets the ports that have been added to the configuration with `port`.
pub fn get_ports(&self) -> &[u16] {
&self.port
pub fn get_port(&self) -> u16 {
self.port
}

/// Sets the timeout applied to socket-level connection attempts.
Expand Down
38 changes: 9 additions & 29 deletions libs/proxy/tokio-postgres2/src/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,38 +19,18 @@ pub async fn connect<T>(
where
T: MakeTlsConnect<TcpStream>,
{
if config.host.is_empty() {
return Err(Error::config("host missing".into()));
}

if config.port.len() > 1 && config.port.len() != config.host.len() {
return Err(Error::config("invalid number of ports".into()));
}

let mut error = None;
for (i, host) in config.host.iter().enumerate() {
let port = config
.port
.get(i)
.or_else(|| config.port.first())
.copied()
.unwrap_or(5432);

let hostname = match host {
Host::Tcp(host) => host.as_str(),
};
let hostname = match &config.host {
Host::Tcp(host) => host.as_str(),
};

let tls = tls
.make_tls_connect(hostname)
.map_err(|e| Error::tls(e.into()))?;
let tls = tls
.make_tls_connect(hostname)
.map_err(|e| Error::tls(e.into()))?;

match connect_once(host, port, tls, config).await {
Ok((client, connection)) => return Ok((client, connection)),
Err(e) => error = Some(e),
}
match connect_once(&config.host, config.port, tls, config).await {
Ok((client, connection)) => Ok((client, connection)),
Err(e) => Err(e),
}

Err(error.unwrap())
}

async fn connect_once<T>(
Expand Down
8 changes: 2 additions & 6 deletions proxy/src/auth/backend/console_redirect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,12 +161,8 @@ async fn authenticate(

// This config should be self-contained, because we won't
// take username or dbname from client's startup message.
let mut config = compute::ConnCfg::new();
config
.host(&db_info.host)
.port(db_info.port)
.dbname(&db_info.dbname)
.user(&db_info.user);
let mut config = compute::ConnCfg::new(db_info.host.to_string(), db_info.port);
config.dbname(&db_info.dbname).user(&db_info.user);

ctx.set_dbname(db_info.dbname.into());
ctx.set_user(db_info.user.into());
Expand Down
7 changes: 1 addition & 6 deletions proxy/src/auth/backend/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,7 @@ impl LocalBackend {
api: http::Endpoint::new(compute_ctl, http::new_client()),
},
node_info: NodeInfo {
config: {
let mut cfg = ConnCfg::new();
cfg.host(&postgres_addr.ip().to_string());
cfg.port(postgres_addr.port());
cfg
},
config: ConnCfg::new(postgres_addr.ip().to_string(), postgres_addr.port()),
// TODO(conrad): make this better reflect compute info rather than endpoint info.
aux: MetricsAuxInfo {
endpoint_id: EndpointIdTag::get_interner().get_or_intern("local"),
Expand Down
61 changes: 17 additions & 44 deletions proxy/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,13 @@ pub(crate) type ScramKeys = postgres_client::config::ScramKeys<32>;
/// A config for establishing a connection to compute node.
/// Eventually, `postgres_client` will be replaced with something better.
/// Newtype allows us to implement methods on top of it.
#[derive(Clone, Default)]
#[derive(Clone)]
pub(crate) struct ConnCfg(Box<postgres_client::Config>);

/// Creation and initialization routines.
impl ConnCfg {
pub(crate) fn new() -> Self {
Self::default()
pub(crate) fn new(host: String, port: u16) -> Self {
Self(Box::new(postgres_client::Config::new(host, port)))
}

/// Reuse password or auth keys from the other config.
Expand All @@ -124,13 +124,9 @@ impl ConnCfg {
}
}

pub(crate) fn get_host(&self) -> Result<Host, WakeComputeError> {
match self.0.get_hosts() {
[postgres_client::config::Host::Tcp(s)] => Ok(s.into()),
// we should not have multiple address or unix addresses.
_ => Err(WakeComputeError::BadComputeAddress(
"invalid compute address".into(),
)),
pub(crate) fn get_host(&self) -> Host {
match self.0.get_host() {
postgres_client::config::Host::Tcp(s) => s.into(),
}
}

Expand Down Expand Up @@ -227,43 +223,20 @@ impl ConnCfg {
// We can't reuse connection establishing logic from `postgres_client` here,
// because it has no means for extracting the underlying socket which we
// require for our business.
let mut connection_error = None;
let ports = self.0.get_ports();
let hosts = self.0.get_hosts();
// the ports array is supposed to have 0 entries, 1 entry, or as many entries as in the hosts array
if ports.len() > 1 && ports.len() != hosts.len() {
return Err(io::Error::new(
io::ErrorKind::Other,
format!(
"bad compute config, \
ports and hosts entries' count does not match: {:?}",
self.0
),
));
}
let port = self.0.get_port();
let host = self.0.get_host();

for (i, host) in hosts.iter().enumerate() {
let port = ports.get(i).or_else(|| ports.first()).unwrap_or(&5432);
let host = match host {
Host::Tcp(host) => host.as_str(),
};

match connect_once(host, *port).await {
Ok((sockaddr, stream)) => return Ok((sockaddr, stream, host)),
Err(err) => {
// We can't throw an error here, as there might be more hosts to try.
warn!("couldn't connect to compute node at {host}:{port}: {err}");
connection_error = Some(err);
}
let host = match host {
Host::Tcp(host) => host.as_str(),
};

match connect_once(host, port).await {
Ok((sockaddr, stream)) => Ok((sockaddr, stream, host)),
Err(err) => {
warn!("couldn't connect to compute node at {host}:{port}: {err}");
Err(err)
}
}

Err(connection_error.unwrap_or_else(|| {
io::Error::new(
io::ErrorKind::Other,
format!("bad compute config: {:?}", self.0),
)
}))
}
}

Expand Down
10 changes: 5 additions & 5 deletions proxy/src/control_plane/client/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,11 @@ impl MockControlPlane {
}

async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
let mut config = compute::ConnCfg::new();
config
.host(self.endpoint.host_str().unwrap_or("localhost"))
.port(self.endpoint.port().unwrap_or(5432))
.ssl_mode(postgres_client::config::SslMode::Disable);
let mut config = compute::ConnCfg::new(
self.endpoint.host_str().unwrap_or("localhost").to_owned(),
self.endpoint.port().unwrap_or(5432),
);
config.ssl_mode(postgres_client::config::SslMode::Disable);

let node = NodeInfo {
config,
Expand Down
4 changes: 2 additions & 2 deletions proxy/src/control_plane/client/neon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,8 @@ impl NeonControlPlaneClient {
// Don't set anything but host and port! This config will be cached.
// We'll set username and such later using the startup message.
// TODO: add more type safety (in progress).
let mut config = compute::ConnCfg::new();
config.host(host).port(port).ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.
let mut config = compute::ConnCfg::new(host.to_owned(), port);
config.ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.

let node = NodeInfo {
config,
Expand Down
2 changes: 1 addition & 1 deletion proxy/src/proxy/connect_compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ impl ConnectMechanism for TcpMechanism<'_> {
node_info: &control_plane::CachedNodeInfo,
timeout: time::Duration,
) -> Result<PostgresConnection, Self::Error> {
let host = node_info.config.get_host()?;
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
permit.release_result(node_info.connect(ctx, timeout).await)
}
Expand Down
4 changes: 2 additions & 2 deletions proxy/src/proxy/tests/mitm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
Scram::new("password").await?,
));

let _client_err = postgres_client::Config::new()
let _client_err = postgres_client::Config::new("test".to_owned(), 5432)
.channel_binding(postgres_client::config::ChannelBinding::Disable)
.user("user")
.dbname("db")
Expand Down Expand Up @@ -241,7 +241,7 @@ async fn connect_failure(
Scram::new("password").await?,
));

let _client_err = postgres_client::Config::new()
let _client_err = postgres_client::Config::new("test".to_owned(), 5432)
.channel_binding(channel_binding)
.user("user")
.dbname("db")
Expand Down
14 changes: 7 additions & 7 deletions proxy/src/proxy/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> {
let (_, server_config) = generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));

let client_err = postgres_client::Config::new()
let client_err = postgres_client::Config::new("test".to_owned(), 5432)
.user("john_doe")
.dbname("earth")
.ssl_mode(SslMode::Disable)
Expand Down Expand Up @@ -233,7 +233,7 @@ async fn handshake_tls() -> anyhow::Result<()> {
generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));

let _conn = postgres_client::Config::new()
let _conn = postgres_client::Config::new("test".to_owned(), 5432)
.user("john_doe")
.dbname("earth")
.ssl_mode(SslMode::Require)
Expand All @@ -249,7 +249,7 @@ async fn handshake_raw() -> anyhow::Result<()> {

let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth));

let _conn = postgres_client::Config::new()
let _conn = postgres_client::Config::new("test".to_owned(), 5432)
.user("john_doe")
.dbname("earth")
.options("project=generic-project-name")
Expand Down Expand Up @@ -296,7 +296,7 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
Scram::new(password).await?,
));

let _conn = postgres_client::Config::new()
let _conn = postgres_client::Config::new("test".to_owned(), 5432)
.channel_binding(postgres_client::config::ChannelBinding::Require)
.user("user")
.dbname("db")
Expand All @@ -320,7 +320,7 @@ async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
Scram::new("password").await?,
));

let _conn = postgres_client::Config::new()
let _conn = postgres_client::Config::new("test".to_owned(), 5432)
.channel_binding(postgres_client::config::ChannelBinding::Disable)
.user("user")
.dbname("db")
Expand Down Expand Up @@ -348,7 +348,7 @@ async fn scram_auth_mock() -> anyhow::Result<()> {
.map(char::from)
.collect();

let _client_err = postgres_client::Config::new()
let _client_err = postgres_client::Config::new("test".to_owned(), 5432)
.user("user")
.dbname("db")
.password(&password) // no password will match the mocked secret
Expand Down Expand Up @@ -546,7 +546,7 @@ impl TestControlPlaneClient for TestConnectMechanism {

fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
let node = NodeInfo {
config: compute::ConnCfg::new(),
config: compute::ConnCfg::new("test".to_owned(), 5432),
aux: MetricsAuxInfo {
endpoint_id: (&EndpointId::from("endpoint")).into(),
project_id: (&ProjectId::from("project")).into(),
Expand Down
10 changes: 3 additions & 7 deletions proxy/src/serverless/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ impl ConnectMechanism for TokioMechanism {
node_info: &CachedNodeInfo,
timeout: Duration,
) -> Result<Self::Connection, Self::ConnectError> {
let host = node_info.config.get_host()?;
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;

let mut config = (*node_info.config).clone();
Expand Down Expand Up @@ -549,16 +549,12 @@ impl ConnectMechanism for HyperMechanism {
node_info: &CachedNodeInfo,
timeout: Duration,
) -> Result<Self::Connection, Self::ConnectError> {
let host = node_info.config.get_host()?;
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;

let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);

let port = *node_info.config.get_ports().first().ok_or_else(|| {
HttpConnError::WakeCompute(WakeComputeError::BadComputeAddress(
"local-proxy port missing on compute address".into(),
))
})?;
let port = node_info.config.get_port();
let res = connect_http2(&host, port, timeout).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;
Expand Down

0 comments on commit fbc8c36

Please sign in to comment.