Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] Add backpressure to publish #1286

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions async-nats/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ serde = { version = "1.0.184", features = ["derive"] }
serde_json = "1.0.104"
serde_repr = "0.1.16"
tokio = { version = "1.36", features = ["macros", "rt", "fs", "net", "sync", "time", "io-util"] }
tokio-stream = "0.1"
url = { version = "2"}
tokio-rustls = { version = "0.26", default-features = false }
rustls-pemfile = "2"
Expand Down
236 changes: 215 additions & 21 deletions async-nats/src/jetstream/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,21 @@ use crate::subject::ToSubject;
use crate::{header, Client, Command, HeaderMap, HeaderValue, Message, StatusCode};
use bytes::Bytes;
use futures::future::BoxFuture;
use futures::{Future, TryFutureExt};
use futures::{Future, StreamExt, TryFutureExt};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use serde_json::{self, json};
use std::borrow::Borrow;
use std::fmt::Debug;
use std::fmt::Display;
use std::future::IntoFuture;
use std::pin::Pin;
use std::str::from_utf8;
use std::sync::Arc;
use std::task::Poll;
use std::time::Duration;
use tokio::sync::oneshot;
use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit, TryAcquireError};
use tokio::time::Duration;
use tokio_stream::wrappers::ReceiverStream;
use tracing::debug;

use super::consumer::{self, Consumer, FromConsumer, IntoConsumerConfig};
Expand All @@ -54,36 +57,200 @@ pub struct Context {
pub(crate) client: Client,
pub(crate) prefix: String,
pub(crate) timeout: Duration,
pub(crate) max_ack_semaphore: Arc<tokio::sync::Semaphore>,
pub(crate) acker_task: Arc<tokio::task::JoinHandle<()>>,
pub(crate) ack_sender:
tokio::sync::mpsc::Sender<(oneshot::Receiver<Message>, OwnedSemaphorePermit)>,
}

impl Context {
pub(crate) fn new(client: Client) -> Context {
Context {
client,
fn spawn_acker(
rx: tokio::sync::mpsc::Receiver<(oneshot::Receiver<Message>, OwnedSemaphorePermit)>,
ack_timeout: Duration,
) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
let stream = ReceiverStream::new(rx);
stream
.for_each_concurrent(None, |(subscription, permit)| async move {
tokio::time::timeout(ack_timeout, subscription).await.ok();
drop(permit);
})
.await;
})
}

impl Drop for Context {
fn drop(&mut self) {
self.acker_task.abort();
}
}

use std::marker::PhantomData;

#[derive(Debug, Default)]
pub struct Yes;
#[derive(Debug, Default)]
pub struct No;

pub trait ToAssign: Debug {}

impl ToAssign for Yes {}
impl ToAssign for No {}

/// A builder for [Context]. Beyond what can be set by standard constructor, it allows tweaking
/// pending publish ack backpressure settings.
/// # Examples
/// ```no_run
/// # use async_nats::jetstream::context::ContextBuilder;
/// # use async_nats::Client;
/// # use std::time::Duration;
/// # #[tokio::main]
/// # async fn main() -> Result<(), async_nats::Error> {
/// let client = async_nats::connect("demo.nats.io").await?;
/// let context = ContextBuilder::new()
/// .timeout(Duration::from_secs(5))
/// .api_prefix("MY.JS.API")
/// .max_ack_inflight(1000)
/// .build(client);
/// # Ok(())
/// # }
/// ```
///
pub struct ContextBuilder<PREFIX: ToAssign> {
prefix: String,
timeout: Duration,
semaphore_capacity: usize,
ack_timeout: Duration,
_phantom: PhantomData<PREFIX>,
}

impl Default for ContextBuilder<Yes> {
fn default() -> Self {
ContextBuilder {
prefix: "$JS.API".to_string(),
timeout: Duration::from_secs(5),
semaphore_capacity: 50_000,
ack_timeout: Duration::from_secs(30),
_phantom: PhantomData {},
}
}
}

pub fn set_timeout(&mut self, timeout: Duration) {
self.timeout = timeout
impl ContextBuilder<Yes> {
/// Create a new [ContextBuilder] with default settings.
pub fn new() -> ContextBuilder<Yes> {
ContextBuilder::default()
}
}

pub(crate) fn with_prefix<T: ToString>(client: Client, prefix: T) -> Context {
Context {
client,
prefix: prefix.to_string(),
timeout: Duration::from_secs(5),
impl ContextBuilder<Yes> {
/// Set the prefix for the JetStream API.
pub fn api_prefix<T: Into<String>>(self, prefix: T) -> ContextBuilder<No> {
ContextBuilder {
prefix: prefix.into(),
timeout: self.timeout,
semaphore_capacity: self.semaphore_capacity,
ack_timeout: self.ack_timeout,
_phantom: PhantomData,
}
}

pub(crate) fn with_domain<T: AsRef<str>>(client: Client, domain: T) -> Context {
/// Set the domain for the JetStream API. Domain is the middle part of standard API prefix:
/// $JS.{domain}.API.
pub fn domain<T: Into<String>>(self, domain: T) -> ContextBuilder<No> {
ContextBuilder {
prefix: format!("$JS.{}.API", domain.into()),
timeout: self.timeout,
semaphore_capacity: self.semaphore_capacity,
ack_timeout: self.ack_timeout,
_phantom: PhantomData,
}
}
}

impl<PREFIX> ContextBuilder<PREFIX>
where
PREFIX: ToAssign,
{
/// Set the timeout for all JetStream API requests.
pub fn timeout(self, timeout: Duration) -> ContextBuilder<Yes>
where
Yes: ToAssign,
{
ContextBuilder {
prefix: self.prefix,
timeout,
semaphore_capacity: self.semaphore_capacity,
ack_timeout: self.ack_timeout,
_phantom: PhantomData,
}
}

/// Sets the maximum time client waits for acks from the server when default backpressure is
/// used.
pub fn ack_timeout(self, ack_timeout: Duration) -> ContextBuilder<Yes>
where
Yes: ToAssign,
{
ContextBuilder {
prefix: self.prefix,
timeout: self.timeout,
semaphore_capacity: self.semaphore_capacity,
ack_timeout,
_phantom: PhantomData,
}
}

/// Sets the maximum number of pending acks that can be in flight at any given time.
/// If limit is reached, `publish` throws an error.
pub fn max_ack_inflight(self, capacity: usize) -> ContextBuilder<Yes>
where
Yes: ToAssign,
{
ContextBuilder {
prefix: self.prefix,
timeout: self.timeout,
semaphore_capacity: capacity,
ack_timeout: self.ack_timeout,
_phantom: PhantomData,
}
}

/// Build the [Context] with the given settings.
pub fn build(self, client: Client) -> Context {
let (tx, rx) = tokio::sync::mpsc::channel::<(
oneshot::Receiver<Message>,
OwnedSemaphorePermit,
)>(self.semaphore_capacity);
let acker_task = Arc::new(spawn_acker(rx, self.ack_timeout));
Context {
client,
prefix: format!("$JS.{}.API", domain.as_ref()),
timeout: Duration::from_secs(5),
prefix: self.prefix,
timeout: self.timeout,
max_ack_semaphore: Arc::new(tokio::sync::Semaphore::new(self.semaphore_capacity)),
acker_task,
ack_sender: tx,
}
}
}

impl Context {
pub(crate) fn new(client: Client) -> Context {
ContextBuilder::default().build(client)
}

pub fn set_timeout(&mut self, timeout: Duration) {
self.timeout = timeout
}

pub(crate) fn with_prefix<T: ToString>(client: Client, prefix: T) -> Context {
ContextBuilder::new()
.api_prefix(prefix.to_string())
.build(client)
}

pub(crate) fn with_domain<T: AsRef<str>>(client: Client, domain: T) -> Context {
ContextBuilder::new().domain(domain.as_ref()).build(client)
}

/// Publishes [jetstream::Message][super::message::Message] to the [Stream] without waiting for
/// acknowledgment from the server that the message has been successfully delivered.
Expand Down Expand Up @@ -192,6 +359,16 @@ impl Context {
subject: S,
publish: Publish,
) -> Result<PublishAckFuture, PublishError> {
let permit =
self.max_ack_semaphore
.clone()
.try_acquire_owned()
.map_err(|err| match err {
TryAcquireError::NoPermits => {
PublishError::new(PublishErrorKind::MaxAckPending)
}
_ => PublishError::with_source(PublishErrorKind::Other, err),
})?;
let subject = subject.to_subject();
let (sender, receiver) = oneshot::channel();

Expand All @@ -215,7 +392,9 @@ impl Context {

Ok(PublishAckFuture {
timeout: self.timeout,
subscription: receiver,
subscription: Some(receiver),
permit: Some(permit),
tx: self.ack_sender.clone(),
})
}

Expand Down Expand Up @@ -1212,6 +1391,7 @@ pub enum PublishErrorKind {
WrongLastSequence,
TimedOut,
BrokenPipe,
MaxAckPending,
Other,
}

Expand All @@ -1224,6 +1404,7 @@ impl Display for PublishErrorKind {
Self::BrokenPipe => write!(f, "broken pipe"),
Self::WrongLastMessageId => write!(f, "wrong last message id"),
Self::WrongLastSequence => write!(f, "wrong last sequence"),
Self::MaxAckPending => write!(f, "max ack pending reached"),
}
}
}
Expand All @@ -1233,12 +1414,25 @@ pub type PublishError = Error<PublishErrorKind>;
#[derive(Debug)]
pub struct PublishAckFuture {
timeout: Duration,
subscription: oneshot::Receiver<Message>,
subscription: Option<oneshot::Receiver<Message>>,
permit: Option<OwnedSemaphorePermit>,
tx: mpsc::Sender<(oneshot::Receiver<Message>, OwnedSemaphorePermit)>,
}

impl Drop for PublishAckFuture {
fn drop(&mut self) {
match (self.subscription.take(), self.permit.take()) {
(Some(sub), Some(permit)) => {
self.tx.try_send((sub, permit)).ok();
}
_ => {}
}
}
}

impl PublishAckFuture {
async fn next_with_timeout(self) -> Result<PublishAck, PublishError> {
let next = tokio::time::timeout(self.timeout, self.subscription)
async fn next_with_timeout(mut self) -> Result<PublishAck, PublishError> {
let next = tokio::time::timeout(self.timeout, self.subscription.take().unwrap())
.await
.map_err(|_| PublishError::new(PublishErrorKind::TimedOut))?;
next.map_or_else(
Expand Down
24 changes: 24 additions & 0 deletions async-nats/tests/jetstream_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3654,4 +3654,28 @@ mod jetstream {
.await
.expect_err("should fail but not panic because of lack of server info");
}

#[tokio::test]
async fn test_async_publish_max_ack_pending() {
let server = nats_server::run_server("tests/configs/jetstream.conf");
let client = async_nats::connect(server.client_url()).await.unwrap();

let jetstream = async_nats::jetstream::new(client);

jetstream
.create_stream(stream::Config {
name: "events".to_string(),
subjects: vec!["events".to_string()],
..Default::default()
})
.await
.unwrap();

for i in 0..100_000 {
jetstream
.publish("events", format!("{i}").into())
.await
.unwrap();
}
}
}
Loading