Skip to content

Commit

Permalink
Updates from review and from testing.
Browse files Browse the repository at this point in the history
  • Loading branch information
thedodd committed Jan 30, 2020
1 parent 2955d84 commit 267de3b
Showing 1 changed file with 55 additions and 33 deletions.
88 changes: 55 additions & 33 deletions sqlx-core/src/postgres/listen.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::ops::DerefMut;

use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;

Expand All @@ -15,9 +17,9 @@ impl PgConnection {
/// Register this connection as a listener on the specified channel.
///
/// If an error is returned here, the connection will be dropped.
pub async fn listen(mut self, channel: &impl AsRef<str>) -> Result<PgListener<Self>> {
let cmd = format!(r#"LISTEN "{}""#, channel.as_ref());
let _ = self.execute(cmd.as_str(), Default::default()).await?;
pub async fn listen(mut self, channel: impl AsRef<str>) -> Result<PgListener<Self>> {
let cmd = build_listen_all_query(&[channel]);
let _ = self.send(cmd.as_str()).await?;
Ok(PgListener::new(self))
}

Expand All @@ -28,25 +30,18 @@ impl PgConnection {
mut self,
channels: impl IntoIterator<Item = impl AsRef<str>>,
) -> Result<PgListener<Self>> {
for channel in channels {
let cmd = format!(r#"LISTEN "{}""#, channel.as_ref());
let _ = self.execute(cmd.as_str(), Default::default()).await?;
}
let cmd = build_listen_all_query(channels);
let _ = self.send(cmd.as_str()).await?;
Ok(PgListener::new(self))
}

/// Build a LISTEN query based on the given channel input.
fn build_listen_query(channel: &impl AsRef<str>) -> String {
format!(r#"LISTEN "{}";"#, channel.as_ref())
}
}

impl PgPool {
/// Fetch a new connection from the pool and register it as a listener on the specified channel.
pub async fn listen(&self, channel: &impl AsRef<str>) -> Result<PgListener<PgPoolConnection>> {
pub async fn listen(&self, channel: impl AsRef<str>) -> Result<PgListener<PgPoolConnection>> {
let mut conn = self.acquire().await?;
let cmd = PgConnection::build_listen_query(channel);
let _ = conn.execute(cmd.as_str(), Default::default()).await?;
let cmd = build_listen_all_query(&[channel]);
let _ = conn.send(cmd.as_str()).await?;
Ok(PgListener::new(conn))
}

Expand All @@ -56,31 +51,31 @@ impl PgPool {
channels: impl IntoIterator<Item = impl AsRef<str>>,
) -> Result<PgListener<PgPoolConnection>> {
let mut conn = self.acquire().await?;
for channel in channels {
let cmd = PgConnection::build_listen_query(&channel);
let _ = conn.execute(cmd.as_str(), Default::default()).await?;
}
let cmd = build_listen_all_query(channels);
let _ = conn.send(cmd.as_str()).await?;
Ok(PgListener::new(conn))
}
}

impl PgPoolConnection {
/// Fetch a new connection from the pool and register it as a listener on the specified channel.
pub async fn listen(mut self, channel: &impl AsRef<str>) -> Result<PgListener<Self>> {
let cmd = PgConnection::build_listen_query(channel);
let _ = self.execute(cmd.as_str(), Default::default()).await?;
/// Register this connection as a listener on the specified channel.
///
/// If an error is returned here, the connection will be dropped.
pub async fn listen(mut self, channel: impl AsRef<str>) -> Result<PgListener<Self>> {
let cmd = build_listen_all_query(&[channel]);
let _ = self.send(cmd.as_str()).await?;
Ok(PgListener::new(self))
}

/// Fetch a new connection from the pool and register it as a listener on the specified channels.
/// Register this connection as a listener on all of the specified channels.
///
/// If an error is returned here, the connection will be dropped.
pub async fn listen_all(
mut self,
channels: impl IntoIterator<Item = impl AsRef<str>>,
) -> Result<PgListener<Self>> {
for channel in channels {
let cmd = PgConnection::build_listen_query(&channel);
let _ = self.execute(cmd.as_str(), Default::default()).await?;
}
let cmd = build_listen_all_query(channels);
let _ = self.send(cmd.as_str()).await?;
Ok(PgListener::new(self))
}
}
Expand All @@ -99,16 +94,15 @@ impl<C> PgListener<C> {

impl<C> PgListener<C>
where
C: AsMut<PgConnection>,
C: DerefMut<Target=PgConnection>,
{
/// Get the next async notification from the database.
pub async fn next(&mut self) -> Result<NotifyMessage> {
loop {
match self.0.as_mut().receive().await? {
match (&mut self.0).receive().await? {
Some(Message::NotificationResponse(notification)) => return Ok(notification.into()),
// TODO: verify with team if this is correct. Looks like the connection being closed will cause an error
// to propagate up from `recevie`, but it would be good to verify with team.
Some(_) | None => continue,
Some(msg) => return Err(protocol_err!("unexpected message received from database {:?}", msg).into()),
None => continue,
}
}
}
Expand Down Expand Up @@ -170,6 +164,7 @@ impl<C: Connection<Database = Postgres>> crate::Executor for PgListener<C> {
}

/// An asynchronous message sent from the database.
#[derive(Debug)]
#[non_exhaustive]
pub struct NotifyMessage {
/// The channel of the notification, which can be thought of as a topic.
Expand All @@ -186,3 +181,30 @@ impl From<Box<NotificationResponse>> for NotifyMessage {
}
}
}

/// Build a query which issues a LISTEN command for each given channel.
fn build_listen_all_query(channels: impl IntoIterator<Item = impl AsRef<str>>) -> String {
channels.into_iter().fold(String::new(), |mut acc, chan| {
acc.push_str(r#"LISTEN ""#);
acc.push_str(chan.as_ref());
acc.push_str(r#"";"#);
acc
})
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn build_listen_all_query_with_single_channel() {
let output = build_listen_all_query(&["test"]);
assert_eq!(output.as_str(), r#"LISTEN "test";"#);
}

#[test]
fn build_listen_all_query_with_multiple_channels() {
let output = build_listen_all_query(&["channel.0", "channel.1"]);
assert_eq!(output.as_str(), r#"LISTEN "channel.0";LISTEN "channel.1";"#);
}
}

0 comments on commit 267de3b

Please sign in to comment.