Skip to content

Commit

Permalink
Resolve Comments
Browse files Browse the repository at this point in the history
    - Remove `hex` from root `Cargo.toml`
    - Make `hmac` crate optional
    - Clean up checking mechanisms for "SCRAM-SHA-256"
    - Use `str::from_utf8` instead of `String::from_utf8_lossyf
    - Update `Sasl*Response` structs be tuple structs
    - Factor out `len` in `SaslInitialResponse.encode()`
    - Use `protocol_err` instead of `expect` when constructing `Hmacf
      instances
    - Remove `it_connects_to_database_user` test as it was too fragile
    - Move `sasl_auth` function into `postgres/connection` as it more
      related to `Connection` rather than `protocl`
    - Return an error when decoding base64 salt rather than panicing
      in `Authentication::SaslContinue`
  • Loading branch information
janaakhterov committed Jan 11, 2020
1 parent 32250f5 commit 96c23a8
Show file tree
Hide file tree
Showing 7 changed files with 223 additions and 231 deletions.
12 changes: 10 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 3 additions & 4 deletions sqlx-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ authors = [
[features]
default = []
unstable = []
postgres = [ "md-5", "sha2", "base64", "sha-1", "rand" ]
postgres = [ "md-5", "sha2", "base64", "sha-1", "rand", "hmac" ]
mysql = [ "sha-1", "sha2", "generic-array", "num-bigint", "base64", "digest", "rand" ]

[dependencies]
Expand All @@ -28,7 +28,7 @@ chrono = { version = "0.4.10", default-features = false, features = [ "clock" ],
digest = { version = "0.8.1", default-features = false, optional = true, features = [ "std" ] }
futures-core = { version = "0.3.1", default-features = false }
futures-util = { version = "0.3.1", default-features = false }
generic-array = { version = "0.12.3", default-features = false, optional = true }
generic-array = { version = "0.13.2", default-features = false, optional = true }
log = { version = "0.4.8", default-features = false }
md-5 = { version = "0.8.0", default-features = false, optional = true }
memchr = { version = "2.2.1", default-features = false }
Expand All @@ -38,8 +38,7 @@ sha-1 = { version = "0.8.1", default-features = false, optional = true }
sha2 = { version = "0.8.0", default-features = false, optional = true }
url = { version = "2.1.0", default-features = false }
uuid = { version = "0.8.1", default-features = false, optional = true }
hex = "0.4.0"
hmac = "0.7.1"
hmac = { version = "0.7.1", default-features = false, optional = true }

[dev-dependencies]
matches = "0.1.8"
Expand Down
196 changes: 170 additions & 26 deletions sqlx-core/src/postgres/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@ use futures_core::future::BoxFuture;
use crate::cache::StatementCache;
use crate::connection::Connection;
use crate::io::{Buf, BufStream};
use crate::postgres::protocol::{self, Decode, Encode, Message, StatementId};
use crate::postgres::protocol::{self, Decode, Encode, Message, StatementId, SaslResponse, SaslInitialResponse, hi, Authentication};
use crate::postgres::PgError;
use crate::url::Url;
use std::ops::Deref;
use sha2::{Sha256, Digest};
use hmac::{Mac, Hmac};
use crate::Result;
use rand::Rng;

/// An asynchronous connection to a [Postgres] database.
///
Expand Down Expand Up @@ -38,7 +41,7 @@ pub struct PgConnection {

impl PgConnection {
// https://www.postgresql.org/docs/12/protocol-flow.html#id-1.10.5.7.3
async fn startup(&mut self, url: Url) -> crate::Result<()> {
async fn startup(&mut self, url: Url) -> Result<()> {
// Defaults to postgres@.../postgres
let username = url.username().unwrap_or("postgres");
let database = url.database().unwrap_or("postgres");
Expand Down Expand Up @@ -94,26 +97,21 @@ impl PgConnection {
}

protocol::Authentication::Sasl { mechanisms } => {
let mechanism = (*mechanisms)
.get(0)
.ok_or(protocol_err!(
match mechanisms.get(0).map(|m| &**m) {
Some("SCRAM-SHA-256") => {
sasl_auth(
self,
username,
url.password().unwrap_or_default(),
)
.await?;
}

_ => return Err(protocol_err!(
"Expected mechanisms SCRAM-SHA-256, but received {:?}",
mechanisms
))?
.deref();
if "SCRAM-SHA-256" == &*mechanism {
protocol::sasl_auth(
self,
username,
url.password().unwrap_or_default(),
)
.await
} else {
Err(protocol_err!(
"Expected mechanisms SCRAM-SHA-256, but received {:?}",
mechanisms
))?
}?;
).into()),
}
}

auth => {
Expand Down Expand Up @@ -146,7 +144,7 @@ impl PgConnection {
}

// https://www.postgresql.org/docs/devel/protocol-flow.html#id-1.10.5.7.10
async fn terminate(mut self) -> crate::Result<()> {
async fn terminate(mut self) -> Result<()> {
protocol::Terminate.encode(self.stream.buffer_mut());

self.stream.flush().await?;
Expand All @@ -156,7 +154,7 @@ impl PgConnection {
}

// Wait and return the next message to be received from Postgres.
pub(super) async fn receive(&mut self) -> crate::Result<Option<Message>> {
pub(super) async fn receive(&mut self) -> Result<Option<Message>> {
loop {
// Read the message header (id + len)
let mut header = ret_if_none!(self.stream.peek(5).await?);
Expand Down Expand Up @@ -222,7 +220,7 @@ impl PgConnection {
}

impl PgConnection {
pub(super) async fn open(url: crate::Result<Url>) -> crate::Result<Self> {
pub(super) async fn open(url: Result<Url>) -> Result<Self> {
let url = url?;
let stream = TcpStream::connect((url.host(), url.port(5432))).await?;
let mut self_ = Self {
Expand All @@ -242,15 +240,161 @@ impl PgConnection {
}

impl Connection for PgConnection {
fn open<T>(url: T) -> BoxFuture<'static, crate::Result<Self>>
fn open<T>(url: T) -> BoxFuture<'static, Result<Self>>
where
T: TryInto<Url, Error = crate::Error>,
Self: Sized,
{
Box::pin(PgConnection::open(url.try_into()))
}

fn close(self) -> BoxFuture<'static, crate::Result<()>> {
fn close(self) -> BoxFuture<'static, Result<()>> {
Box::pin(self.terminate())
}
}

static GS2_HEADER: &'static str = "n,,";
static CHANNEL_ATTR: &'static str = "c";
static USERNAME_ATTR: &'static str = "n";
static CLIENT_PROOF_ATTR: &'static str = "p";
static NONCE_ATTR: &'static str = "r";

// Nonce generator
// Nonce is a sequence of random printable bytes
fn nonce() -> String {
let mut rng = rand::thread_rng();
let count = rng.gen_range(64, 128);
// printable = %x21-2B / %x2D-7E
// ;; Printable ASCII except ",".
// ;; Note that any "printable" is also
// ;; a valid "value".
let nonce: String = std::iter::repeat(())
.map(|()| {
let mut c = rng.gen_range(0x21, 0x7F) as u8;

while c == 0x2C {
c = rng.gen_range(0x21, 0x7F) as u8;
}

c
})
.take(count)
.map(|c| c as char)
.collect();

rng.gen_range(32, 128);
format!("{}={}", NONCE_ATTR, nonce)
}

// Performs authenticiton using Simple Authentication Security Layer (SASL) which is what
// Postgres uses
async fn sasl_auth<T: AsRef<str>>(
conn: &mut PgConnection,
username: T,
password: T,
) -> Result<()> {
// channel-binding = "c=" base64
let channel_binding = format!("{}={}", CHANNEL_ATTR, base64::encode(GS2_HEADER));
// "n=" saslname ;; Usernames are prepared using SASLprep.
let username = format!("{}={}", USERNAME_ATTR, username.as_ref());
// nonce = "r=" c-nonce [s-nonce] ;; Second part provided by server.
let nonce = nonce();
let client_first_message_bare =
format!("{username},{nonce}", username = username, nonce = nonce);
// client-first-message-bare = [reserved-mext ","] username "," nonce ["," extensions]
let client_first_message = format!(
"{gs2_header}{client_first_message_bare}",
gs2_header = GS2_HEADER,
client_first_message_bare = client_first_message_bare
);

SaslInitialResponse(&client_first_message)
.encode(conn.stream.buffer_mut());
conn.stream.flush().await?;

let server_first_message = conn.receive().await?;

if let Some(Message::Authentication(auth)) = server_first_message {
if let Authentication::SaslContinue(sasl) = *auth {
let server_first_message = sasl.data;

// SaltedPassword := Hi(Normalize(password), salt, i)
let salted_password = hi(password.as_ref(), &sasl.salt, sasl.iter_count)?;

// ClientKey := HMAC(SaltedPassword, "Client Key")
let mut mac = Hmac::<Sha256>::new_varkey(&salted_password)
.map_err(|_| protocol_err!("HMAC can take key of any size"))?;
mac.input(b"Client Key");
let client_key = mac.result().code();

// StoredKey := H(ClientKey)
let mut hasher = Sha256::new();
hasher.input(client_key);
let stored_key = hasher.result();

// String::from_utf8_lossy should never fail because Postgres requires
// the nonce to be all printable characters except ','
let client_final_message_wo_proof = format!(
"{channel_binding},r={nonce}",
channel_binding = channel_binding,
nonce = String::from_utf8_lossy(&sasl.nonce)
);

// AuthMessage := client-first-message-bare + "," + server-first-message + "," + client-final-message-without-proof
let auth_message = format!("{client_first_message_bare},{server_first_message},{client_final_message_wo_proof}",
client_first_message_bare = client_first_message_bare,
server_first_message = server_first_message,
client_final_message_wo_proof = client_final_message_wo_proof);

// ClientSignature := HMAC(StoredKey, AuthMessage)
let mut mac =
Hmac::<Sha256>::new_varkey(&stored_key).expect("HMAC can take key of any size");
mac.input(&auth_message.as_bytes());
let client_signature = mac.result().code();

// ClientProof := ClientKey XOR ClientSignature
let client_proof: Vec<u8> = client_key
.iter()
.zip(client_signature.iter())
.map(|(&a, &b)| a ^ b)
.collect();

// ServerKey := HMAC(SaltedPassword, "Server Key")
let mut mac = Hmac::<Sha256>::new_varkey(&salted_password)
.map_err(|_| protocol_err!("HMAC can take key of any size"))?;
mac.input(b"Server Key");
let server_key = mac.result().code();

// ServerSignature := HMAC(ServerKey, AuthMessage)
let mut mac =
Hmac::<Sha256>::new_varkey(&server_key).expect("HMAC can take key of any size");
mac.input(&auth_message.as_bytes());
let _server_signature = mac.result().code();

// client-final-message = client-final-message-without-proof "," proof
let client_final_message = format!(
"{client_final_message_wo_proof},{client_proof_attr}={client_proof}",
client_final_message_wo_proof = client_final_message_wo_proof,
client_proof_attr = CLIENT_PROOF_ATTR,
client_proof = base64::encode(&client_proof)
);

SaslResponse(&client_final_message)
.encode(conn.stream.buffer_mut());
conn.stream.flush().await?;
let _server_final_response = conn.receive().await?;

Ok(())
} else {
Err(protocol_err!(
"Expected Authentication::SaslContinue, but received {:?}",
auth
))?
}
} else {
Err(protocol_err!(
"Expected Message::Authentication, but received {:?}",
server_first_message
))?
}
}
30 changes: 23 additions & 7 deletions sqlx-core/src/postgres/protocol/authentication.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::postgres::protocol::Decode;
use byteorder::NetworkEndian;
use std::borrow::Cow;
use std::io;
use std::str;

#[derive(Debug)]
pub enum Authentication {
Expand Down Expand Up @@ -99,29 +100,44 @@ impl Decode for Authentication {
let mut nonce: Vec<u8> = Vec::new();
let mut iter_count: u32 = 0;

buf.split(|byte| *byte == b',')
let key_value: Vec<(char, &[u8])> = buf
.split(|byte| *byte == b',')
.map(|s| {
let (key, value) = s.split_at(1);
let value = value.split_at(1).1;

(key[0] as char, value)
})
.for_each(|(key, value)| match key {
.collect();

for (key, value) in key_value.iter() {
match key {
's' => salt = value.to_vec(),
'r' => nonce = value.to_vec(),
'i' => {
iter_count = u32::from_str_radix(&String::from_utf8_lossy(&value), 10)
.unwrap_or(0);
let s = str::from_utf8(&value).map_err(|_| {
protocol_err!(
"iteration count in sasl response was not a valid utf8 string"
)
})?;
iter_count = u32::from_str_radix(&s, 10).unwrap_or(0);
}

_ => {}
});
}
}

Authentication::SaslContinue(SaslContinue {
salt: base64::decode(&salt).unwrap(),
salt: base64::decode(&salt).map_err(|_| {
protocol_err!("salt value response from postgres was not base64 encoded")
})?,
nonce,
iter_count,
data: String::from_utf8_lossy(buf).into_owned(),
data: str::from_utf8(buf)
.map_err(|_| {
protocol_err!("SaslContinue response was not a valid utf8 string")
})?
.to_string(),
})
}

Expand Down
2 changes: 1 addition & 1 deletion sqlx-core/src/postgres/protocol/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub use flush::Flush;
pub use parse::Parse;
pub use password_message::PasswordMessage;
pub use query::Query;
pub use sasl::{sasl_auth, SaslInitialResponse, SaslResponse};
pub use sasl::{hi, SaslInitialResponse, SaslResponse};
pub use startup_message::StartupMessage;
pub use statement::StatementId;
pub use sync::Sync;
Expand Down
Loading

0 comments on commit 96c23a8

Please sign in to comment.