diff --git a/async-nats/src/client.rs b/async-nats/src/client.rs index 882f54e40..a21512f72 100644 --- a/async-nats/src/client.rs +++ b/async-nats/src/client.rs @@ -17,17 +17,18 @@ use crate::ServerInfo; use super::{header::HeaderMap, status::StatusCode, Command, Message, Subscriber}; use crate::error::Error; use bytes::Bytes; -use futures::future::TryFutureExt; use futures::stream::StreamExt; +use futures::{Future, TryFutureExt}; use once_cell::sync::Lazy; use regex::Regex; use std::fmt::Display; +use std::future::IntoFuture; +use std::pin::Pin; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use std::time::Duration; use thiserror::Error; use tokio::sync::mpsc; -use tracing::trace; static VERSION_RE: Lazy = Lazy::new(|| Regex::new(r#"\Av?([0-9]+)\.?([0-9]+)?\.?([0-9]+)?"#).unwrap()); @@ -44,6 +45,63 @@ impl From> for PublishError { } } +#[must_use] +pub struct Publish { + sender: mpsc::Sender, + subject: String, + payload: Bytes, + headers: Option, + respond: Option, +} + +impl Publish { + pub(crate) fn new(sender: mpsc::Sender, subject: String, payload: Bytes) -> Publish { + Publish { + sender, + subject, + payload, + headers: None, + respond: None, + } + } + + pub fn headers(mut self, headers: HeaderMap) -> Publish { + self.headers = Some(headers); + self + } + + pub fn reply(mut self, subject: String) -> Publish { + self.respond = Some(subject); + self + } +} + +impl IntoFuture for Publish { + type Output = Result<(), PublishError>; + type IntoFuture = Pin> + Send>>; + + fn into_future(self) -> Self::IntoFuture { + let sender = self.sender; + let subject = self.subject; + let payload = self.payload; + let respond = self.respond; + let headers = self.headers; + + Box::pin(async move { + sender + .send(Command::Publish { + subject, + payload, + respond, + headers, + }) + .await?; + + Ok(()) + }) + } +} + /// Client is a `Cloneable` handle to NATS connection. /// Client should not be created directly. Instead, one of two methods can be used: /// [crate::connect] and [crate::ConnectOptions::connect] @@ -149,16 +207,8 @@ impl Client { /// # Ok(()) /// # } /// ``` - pub async fn publish(&self, subject: String, payload: Bytes) -> Result<(), PublishError> { - self.sender - .send(Command::Publish { - subject, - payload, - respond: None, - headers: None, - }) - .await?; - Ok(()) + pub fn publish(&self, subject: String, payload: Bytes) -> Publish { + Publish::new(self.sender.clone(), subject, payload) } /// Publish a [Message] with headers to a given subject. @@ -186,14 +236,7 @@ impl Client { headers: HeaderMap, payload: Bytes, ) -> Result<(), PublishError> { - self.sender - .send(Command::Publish { - subject, - payload, - respond: None, - headers: Some(headers), - }) - .await?; + self.publish(subject, payload).headers(headers).await?; Ok(()) } @@ -223,14 +266,7 @@ impl Client { reply: String, payload: Bytes, ) -> Result<(), PublishError> { - self.sender - .send(Command::Publish { - subject, - payload, - respond: Some(reply), - headers: None, - }) - .await?; + self.publish(subject, payload).reply(reply).await?; Ok(()) } @@ -264,13 +300,9 @@ impl Client { headers: HeaderMap, payload: Bytes, ) -> Result<(), PublishError> { - self.sender - .send(Command::Publish { - subject, - payload, - respond: Some(reply), - headers: Some(headers), - }) + self.publish(subject, payload) + .headers(headers) + .reply(reply) .await?; Ok(()) } @@ -286,10 +318,8 @@ impl Client { /// # Ok(()) /// # } /// ``` - pub async fn request(&self, subject: String, payload: Bytes) -> Result { - trace!("request sent to subject: {} ({})", subject, payload.len()); - let request = Request::new().payload(payload); - self.send_request(subject, request).await + pub fn request(&self, subject: String, payload: Bytes) -> Request { + Request::new(self.clone(), subject, payload) } /// Sends the request with headers. @@ -313,65 +343,11 @@ impl Client { headers: HeaderMap, payload: Bytes, ) -> Result { - let request = Request::new().headers(headers).payload(payload); - self.send_request(subject, request).await - } + let message = Request::new(self.clone(), subject, payload) + .headers(headers) + .await?; - /// Sends the request created by the [Request]. - /// - /// # Examples - /// - /// ```no_run - /// # #[tokio::main] - /// # async fn main() -> Result<(), async_nats::Error> { - /// let client = async_nats::connect("demo.nats.io").await?; - /// let request = async_nats::Request::new().payload("data".into()); - /// let response = client.send_request("service".into(), request).await?; - /// # Ok(()) - /// # } - /// ``` - pub async fn send_request( - &self, - subject: String, - request: Request, - ) -> Result { - let inbox = request.inbox.unwrap_or_else(|| self.new_inbox()); - let timeout = request.timeout.unwrap_or(self.request_timeout); - let mut sub = self.subscribe(inbox.clone()).await?; - let payload: Bytes = request.payload.unwrap_or_else(Bytes::new); - match request.headers { - Some(headers) => { - self.publish_with_reply_and_headers(subject, inbox, headers, payload) - .await? - } - None => self.publish_with_reply(subject, inbox, payload).await?, - } - self.flush() - .await - .map_err(|err| RequestError::with_source(RequestErrorKind::Other, err))?; - let request = match timeout { - Some(timeout) => { - tokio::time::timeout(timeout, sub.next()) - .map_err(|err| RequestError::with_source(RequestErrorKind::TimedOut, err)) - .await? - } - None => sub.next().await, - }; - match request { - Some(message) => { - if message.status == Some(StatusCode::NO_RESPONDERS) { - return Err(RequestError::with_source( - RequestErrorKind::NoResponders, - "no responders", - )); - } - Ok(message) - } - None => Err(RequestError::with_source( - RequestErrorKind::Other, - "broken pipe", - )), - } + Ok(message) } /// Create a new globally unique inbox which can be used for replies. @@ -503,8 +479,10 @@ impl Client { } /// Used for building customized requests. -#[derive(Default)] +#[derive(Debug)] pub struct Request { + client: Client, + subject: String, payload: Option, headers: Option, timeout: Option>, @@ -512,8 +490,15 @@ pub struct Request { } impl Request { - pub fn new() -> Request { - Default::default() + pub fn new(client: Client, subject: String, payload: Bytes) -> Request { + Request { + client, + subject, + payload: Some(payload), + headers: None, + timeout: None, + inbox: None, + } } /// Sets the payload of the request. If not used, empty payload will be sent. @@ -523,8 +508,7 @@ impl Request { /// # #[tokio::main] /// # async fn main() -> Result<(), async_nats::Error> { /// let client = async_nats::connect("demo.nats.io").await?; - /// let request = async_nats::Request::new().payload("data".into()); - /// client.send_request("service".into(), request).await?; + /// client.request("service".into(), "data".into()).await?; /// # Ok(()) /// # } /// ``` @@ -546,10 +530,11 @@ impl Request { /// "X-Example", /// async_nats::HeaderValue::from_str("Value").unwrap(), /// ); - /// let request = async_nats::Request::new() + /// client + /// .request("subject".into(), "data".into()) /// .headers(headers) - /// .payload("data".into()); - /// client.send_request("service".into(), request).await?; + /// .await?; + /// /// # Ok(()) /// # } /// ``` @@ -567,10 +552,11 @@ impl Request { /// # #[tokio::main] /// # async fn main() -> Result<(), async_nats::Error> { /// let client = async_nats::connect("demo.nats.io").await?; - /// let request = async_nats::Request::new() + /// client + /// .request("service".into(), "data".into()) /// .timeout(Some(std::time::Duration::from_secs(15))) - /// .payload("data".into()); - /// client.send_request("service".into(), request).await?; + /// .await?; + /// /// # Ok(()) /// # } /// ``` @@ -587,10 +573,11 @@ impl Request { /// # async fn main() -> Result<(), async_nats::Error> { /// use std::str::FromStr; /// let client = async_nats::connect("demo.nats.io").await?; - /// let request = async_nats::Request::new() + /// client + /// .request("subject".into(), "data".into()) /// .inbox("custom_inbox".into()) - /// .payload("data".into()); - /// client.send_request("service".into(), request).await?; + /// .await?; + /// /// # Ok(()) /// # } /// ``` @@ -598,6 +585,55 @@ impl Request { self.inbox = Some(inbox); self } + + async fn send(self) -> Result { + let inbox = self.inbox.unwrap_or_else(|| self.client.new_inbox()); + let mut subscriber = self.client.subscribe(inbox.clone()).await?; + let mut publish = self + .client + .publish(self.subject, self.payload.unwrap_or_else(Bytes::new)); + + if let Some(headers) = self.headers { + publish = publish.headers(headers); + } + + publish = publish.reply(inbox); + publish.into_future().await?; + + self.client + .flush() + .map_err(|err| RequestError::with_source(RequestErrorKind::Other, err)) + .await?; + + let period = self.timeout.unwrap_or(self.client.request_timeout); + let message = match period { + Some(period) => { + tokio::time::timeout(period, subscriber.next()) + .map_err(|_| RequestError::new(RequestErrorKind::TimedOut)) + .await? + } + None => subscriber.next().await, + }; + + match message { + Some(message) => { + if message.status == Some(StatusCode::NO_RESPONDERS) { + return Err(RequestError::new(RequestErrorKind::NoResponders)); + } + Ok(message) + } + None => Err(RequestError::new(RequestErrorKind::Other)), + } + } +} + +impl IntoFuture for Request { + type Output = Result; + type IntoFuture = Pin> + Send>>; + + fn into_future(self) -> Self::IntoFuture { + Box::pin(self.send()) + } } #[derive(Error, Debug)] diff --git a/async-nats/src/jetstream/consumer/mod.rs b/async-nats/src/jetstream/consumer/mod.rs index dc2ed42a2..b487dc6b2 100644 --- a/async-nats/src/jetstream/consumer/mod.rs +++ b/async-nats/src/jetstream/consumer/mod.rs @@ -17,13 +17,15 @@ pub mod pull; pub mod push; #[cfg(feature = "server_2_10")] use std::collections::HashMap; +use std::future::IntoFuture; use std::time::Duration; use serde::{Deserialize, Serialize}; use serde_json::json; use time::serde::rfc3339; -use super::context::RequestError; +use super::context::{RequestError, RequestErrorKind}; +use super::response::Response; use super::stream::ClusterInfo; use super::Context; use crate::error::Error; @@ -76,14 +78,38 @@ impl Consumer { pub async fn info(&mut self) -> Result<&consumer::Info, RequestError> { let subject = format!("CONSUMER.INFO.{}.{}", self.info.stream_name, self.info.name); - let info = self.context.request(subject, &json!({})).await?; - self.info = info; - Ok(&self.info) + let response: Response = self + .context + .request(subject, &json!({})) + .into_future() + .await?; + + match response { + Response::Ok(info) => { + self.info = info; + Ok(&self.info) + } + Response::Err { error } => { + Err(RequestError::with_source(RequestErrorKind::Other, error)) + } + } } async fn fetch_info(&self) -> Result { let subject = format!("CONSUMER.INFO.{}.{}", self.info.stream_name, self.info.name); - self.context.request(subject, &json!({})).await + + let response: Response = self + .context + .request(subject, &json!({})) + .into_future() + .await?; + + match response { + Response::Ok(info) => Ok(info), + Response::Err { error } => { + Err(RequestError::with_source(RequestErrorKind::Other, error)) + } + } } /// Returns cached [Info] for the [Consumer]. diff --git a/async-nats/src/jetstream/context.rs b/async-nats/src/jetstream/context.rs index 1fd009848..5d158a028 100644 --- a/async-nats/src/jetstream/context.rs +++ b/async-nats/src/jetstream/context.rs @@ -29,6 +29,7 @@ use std::borrow::Borrow; use std::fmt::Display; use std::future::IntoFuture; use std::io::ErrorKind; +use std::marker::PhantomData; use std::pin::Pin; use std::str::from_utf8; use std::task::Poll; @@ -127,13 +128,8 @@ impl Context { /// # Ok(()) /// # } /// ``` - pub async fn publish( - &self, - subject: String, - payload: Bytes, - ) -> Result { - self.send_publish(subject, Publish::build().payload(payload)) - .await + pub fn publish(&self, subject: String, payload: Bytes) -> Publish { + Publish::new(self.clone(), subject, payload) } /// Publish a message with headers to a given subject associated with a stream and returns an acknowledgment from @@ -163,67 +159,8 @@ impl Context { headers: crate::header::HeaderMap, payload: Bytes, ) -> Result { - self.send_publish(subject, Publish::build().payload(payload).headers(headers)) - .await - } - - /// Publish a message built by [Publish] and returns an acknowledgment future. - /// - /// If the stream does not exist, `no responders` error will be returned. - /// - /// # Examples - /// - /// ```no_run - /// # use async_nats::jetstream::context::Publish; - /// # #[tokio::main] - /// # async fn main() -> Result<(), async_nats::Error> { - /// let client = async_nats::connect("localhost:4222").await?; - /// let jetstream = async_nats::jetstream::new(client); - /// - /// let ack = jetstream - /// .send_publish( - /// "events".to_string(), - /// Publish::build().payload("data".into()).message_id("uuid"), - /// ) - /// .await?; - /// # Ok(()) - /// # } - /// ``` - pub async fn send_publish( - &self, - subject: String, - publish: Publish, - ) -> Result { - let inbox = self.client.new_inbox(); - let response = self - .client - .subscribe(inbox.clone()) - .await - .map_err(|err| PublishError::with_source(PublishErrorKind::Other, err))?; - tokio::time::timeout(self.timeout, async { - if let Some(headers) = publish.headers { - self.client - .publish_with_reply_and_headers( - subject, - inbox.clone(), - headers, - publish.payload, - ) - .await - } else { - self.client - .publish_with_reply(subject, inbox.clone(), publish.payload) - .await - } - }) - .map_err(|_| PublishError::new(PublishErrorKind::TimedOut)) - .await? - .map_err(|err| PublishError::with_source(PublishErrorKind::Other, err))?; - - Ok(PublishAckFuture { - timeout: self.timeout, - subscription: response, - }) + let ack_future = self.publish(subject, payload).headers(headers).await?; + Ok(ack_future) } /// Query the server for account information @@ -811,30 +748,12 @@ impl Context { /// # Ok(()) /// # } /// ``` - pub async fn request(&self, subject: String, payload: &T) -> Result + pub fn request(&self, subject: String, payload: T) -> Request where - T: ?Sized + Serialize, + T: Sized + Serialize, V: DeserializeOwned, { - let request = serde_json::to_vec(&payload) - .map(Bytes::from) - .map_err(|err| RequestError::with_source(RequestErrorKind::Other, err))?; - - debug!("JetStream request sent: {:?}", request); - - let message = self - .client - .request(format!("{}.{}", self.prefix, subject), request) - .await; - let message = message?; - debug!( - "JetStream request response: {:?}", - from_utf8(&message.payload) - ); - let response = serde_json::from_slice(message.payload.as_ref()) - .map_err(|err| RequestError::with_source(RequestErrorKind::Other, err))?; - - Ok(response) + Request::new(self.clone(), subject, payload) } /// Creates a new object store bucket. @@ -1186,15 +1105,23 @@ impl futures::Stream for Streams<'_> { } } /// Used for building customized `publish` message. -#[derive(Default, Clone, Debug)] +#[derive(Clone, Debug)] pub struct Publish { + context: Context, + subject: String, payload: Bytes, headers: Option, } + impl Publish { /// Creates a new custom Publish struct to be used with. - pub fn build() -> Self { - Default::default() + pub(crate) fn new(context: Context, subject: String, payload: Bytes) -> Self { + Publish { + context, + subject, + payload, + headers: None, + } } /// Sets the payload for the message. @@ -1252,6 +1179,106 @@ impl Publish { } } +impl IntoFuture for Publish { + type Output = Result; + type IntoFuture = Pin> + Send>>; + + fn into_future(self) -> Self::IntoFuture { + Box::pin(std::future::IntoFuture::into_future(async move { + let inbox = self.context.client.new_inbox(); + let subscription = self + .context + .client + .subscribe(inbox.clone()) + .map_err(|err| PublishError::with_source(PublishErrorKind::Other, err)) + .await?; + + let mut publish = self + .context + .client + .publish(self.subject, self.payload) + .reply(inbox); + + if let Some(headers) = self.headers { + publish = publish.headers(headers); + } + + let timeout = self.context.timeout; + + tokio::time::timeout(timeout, publish.into_future()) + .map_err(|_| PublishError::new(PublishErrorKind::TimedOut)) + .await? + .map_err(|_| PublishError::new(PublishErrorKind::TimedOut))?; + + Ok(PublishAckFuture { + timeout, + subscription, + }) + })) + } +} + +#[derive(Debug)] +pub struct Request { + context: Context, + subject: String, + payload: T, + timeout: Option, + response_type: PhantomData, +} + +impl Request { + pub fn new(context: Context, subject: String, payload: T) -> Self { + Self { + context, + subject, + payload, + timeout: None, + response_type: PhantomData, + } + } + + pub fn timeout(mut self, timeout: Duration) -> Self { + self.timeout = Some(timeout); + self + } +} + +impl IntoFuture for Request { + type Output = Result, RequestError>; + + type IntoFuture = Pin, RequestError>> + Send>>; + + fn into_future(self) -> Self::IntoFuture { + let payload_result = serde_json::to_vec(&self.payload).map(Bytes::from); + + let prefix = self.context.prefix; + let client = self.context.client; + let subject = self.subject; + let timeout = self.timeout; + + Box::pin(std::future::IntoFuture::into_future(async move { + let payload = payload_result + .map_err(|err| RequestError::with_source(RequestErrorKind::Other, err))?; + + debug!("JetStream request sent: {:?}", payload); + + let request = client.request(format!("{}.{}", prefix, subject), payload); + let request = request.timeout(timeout); + let message = request.await?; + + debug!( + "JetStream request response: {:?}", + from_utf8(&message.payload) + ); + let response = serde_json::from_slice(message.payload.as_ref()) + .map_err(|err| RequestError::with_source(RequestErrorKind::Other, err))?; + + Ok(response) + })) + } +} + #[derive(Clone, Copy, Debug, PartialEq)] pub enum RequestErrorKind { NoResponders, diff --git a/async-nats/src/jetstream/message.rs b/async-nats/src/jetstream/message.rs index 179a353ec..a1530a457 100644 --- a/async-nats/src/jetstream/message.rs +++ b/async-nats/src/jetstream/message.rs @@ -17,6 +17,7 @@ use crate::Error; use bytes::Bytes; use futures::future::TryFutureExt; use futures::StreamExt; +use std::future::IntoFuture; use std::time::Duration; use time::OffsetDateTime; @@ -88,8 +89,8 @@ impl Message { self.context .client .publish(reply.to_string(), "".into()) - .map_err(Error::from) .await + .map_err(Error::from) } else { Err(Box::new(std::io::Error::new( std::io::ErrorKind::Other, @@ -130,8 +131,8 @@ impl Message { self.context .client .publish(reply.to_string(), kind.into()) - .map_err(Error::from) .await + .map_err(Error::from) } else { Err(Box::new(std::io::Error::new( std::io::ErrorKind::Other, @@ -375,6 +376,7 @@ impl Acker { self.context .client .publish(reply.to_string(), "".into()) + .into_future() .map_err(Error::from) .await } else { @@ -423,6 +425,7 @@ impl Acker { self.context .client .publish(reply.to_string(), kind.into()) + .into_future() .map_err(Error::from) .await } else { diff --git a/async-nats/src/jetstream/stream.rs b/async-nats/src/jetstream/stream.rs index 0283753b0..cc152b4f9 100644 --- a/async-nats/src/jetstream/stream.rs +++ b/async-nats/src/jetstream/stream.rs @@ -577,6 +577,7 @@ impl Stream { let response: Response = self .context .request(subject, &payload) + .into_future() .map_err(|err| LastRawMessageError::with_source(LastRawMessageErrorKind::Other, err)) .await?; match response { @@ -627,6 +628,7 @@ impl Stream { let response: Response = self .context .request(subject, &payload) + .into_future() .map_err(|err| match err.kind() { RequestErrorKind::TimedOut => { DeleteMessageError::new(DeleteMessageErrorKind::TimedOut) @@ -1561,6 +1563,7 @@ where .stream .context .request(request_subject, &self.inner) + .into_future() .map_err(|err| match err.kind() { RequestErrorKind::TimedOut => PurgeError::new(PurgeErrorKind::TimedOut), _ => PurgeError::with_source(PurgeErrorKind::Request, err), diff --git a/async-nats/tests/client_tests.rs b/async-nats/tests/client_tests.rs index 5505e21d0..c9d1458d5 100644 --- a/async-nats/tests/client_tests.rs +++ b/async-nats/tests/client_tests.rs @@ -14,12 +14,11 @@ mod client { use async_nats::connection::State; use async_nats::header::HeaderValue; - use async_nats::{ - ConnectErrorKind, ConnectOptions, Event, Request, RequestErrorKind, ServerAddr, - }; + use async_nats::{ConnectErrorKind, ConnectOptions, Event, RequestErrorKind, ServerAddr}; use bytes::Bytes; use futures::future::join_all; use futures::stream::StreamExt; + use std::future::IntoFuture; use std::path::PathBuf; use std::str::FromStr; use std::time::Duration; @@ -125,6 +124,41 @@ mod client { assert_eq!(i, 10); } + #[tokio::test] + async fn publish_into_future_with_headers() { + let server = nats_server::run_basic_server(); + let client = async_nats::connect(server.client_url()).await.unwrap(); + + let mut subscriber = client.subscribe("test".into()).await.unwrap(); + + let mut headers = async_nats::HeaderMap::new(); + headers.insert("X-Test", HeaderValue::from_str("Test").unwrap()); + + client + .publish("test".into(), b"".as_ref().into()) + .headers(headers.clone()) + .await + .unwrap(); + + client.flush().await.unwrap(); + + let message = subscriber.next().await.unwrap(); + assert_eq!(message.headers.unwrap(), headers); + + let mut headers = async_nats::HeaderMap::new(); + headers.insert("X-Test", HeaderValue::from_str("Test").unwrap()); + headers.append("X-Test", "Second"); + + client + .publish("test".into(), b"".as_ref().into()) + .headers(headers.clone()) + .await + .unwrap(); + + let message = subscriber.next().await.unwrap(); + assert_eq!(message.headers.unwrap(), headers); + } + #[tokio::test] async fn publish_with_headers() { let server = nats_server::run_basic_server(); @@ -204,7 +238,9 @@ mod client { let resp = tokio::time::timeout( tokio::time::Duration::from_millis(500), - client.request("test".into(), "request".into()), + client + .request("test".into(), "request".into()) + .into_future(), ) .await .unwrap(); @@ -233,7 +269,9 @@ mod client { let err = tokio::time::timeout( tokio::time::Duration::from_millis(300), - client.request("test".into(), "request".into()), + client + .request("test".into(), "request".into()) + .into_future(), ) .await .unwrap() @@ -261,9 +299,9 @@ mod client { } }); - let request = Request::new().inbox(inbox.clone()); client - .send_request("service".into(), request) + .request("service".into(), "".into()) + .inbox(inbox) .await .unwrap(); } @@ -735,10 +773,7 @@ mod client { } }); - client - .request("request".into(), "data".into()) - .await - .unwrap(); + client.request("".into(), "data".into()).await.unwrap(); inbox_wildcard_subscription.next().await.unwrap(); } diff --git a/async-nats/tests/jetstream_tests.rs b/async-nats/tests/jetstream_tests.rs index 5797f7d7b..5e9b437a2 100644 --- a/async-nats/tests/jetstream_tests.rs +++ b/async-nats/tests/jetstream_tests.rs @@ -31,12 +31,12 @@ mod jetstream { use super::*; use async_nats::connection::State; - use async_nats::header::{self, HeaderMap, NATS_MESSAGE_ID}; + use async_nats::header::{self, HeaderMap}; use async_nats::jetstream::consumer::{ self, push, AckPolicy, DeliverPolicy, Info, OrderedPullConsumer, OrderedPushConsumer, PullConsumer, PushConsumer, ReplayPolicy, }; - use async_nats::jetstream::context::{Publish, PublishErrorKind}; + use async_nats::jetstream::context::PublishErrorKind; use async_nats::jetstream::response::Response; use async_nats::jetstream::stream::{self, DiscardPolicy, StorageType}; use async_nats::jetstream::AckKind; @@ -60,11 +60,10 @@ mod jetstream { } #[tokio::test] - async fn publish_with_headers() { + async fn publish_headers() { let server = nats_server::run_server("tests/configs/jetstream.conf"); let client = async_nats::connect(server.client_url()).await.unwrap(); let context = async_nats::jetstream::new(client); - let _stream = context .create_stream(stream::Config { name: "TEST".to_string(), @@ -73,28 +72,27 @@ mod jetstream { }) .await .unwrap(); - let headers = HeaderMap::new(); let payload = b"Hello JetStream"; let ack = context - .publish_with_headers("foo".into(), headers, payload.as_ref().into()) + .publish("foo".into(), payload.as_ref().into()) + .headers(headers) .await .unwrap() .await .unwrap(); - assert_eq!(ack.stream, "TEST"); assert_eq!(ack.sequence, 1); } #[tokio::test] - async fn publish_async() { + async fn publish_with_headers() { let server = nats_server::run_server("tests/configs/jetstream.conf"); let client = async_nats::connect(server.client_url()).await.unwrap(); let context = async_nats::jetstream::new(client); - context + let _stream = context .create_stream(stream::Config { name: "TEST".to_string(), subjects: vec!["foo".into(), "bar".into(), "baz".into()], @@ -103,103 +101,92 @@ mod jetstream { .await .unwrap(); + let headers = HeaderMap::new(); + let payload = b"Hello JetStream"; + let ack = context - .publish("foo".to_string(), "payload".into()) + .publish_with_headers("foo".into(), headers, payload.as_ref().into()) .await - .unwrap(); - assert!(ack.await.is_ok()); - let ack = context - .publish("not_stream".to_string(), "payload".into()) + .unwrap() .await .unwrap(); - assert!(ack.await.is_err()); + + assert_eq!(ack.stream, "TEST"); + assert_eq!(ack.sequence, 1); } #[tokio::test] - async fn send_publish() { + async fn publish_control() { let server = nats_server::run_server("tests/configs/jetstream.conf"); let client = async_nats::connect(server.client_url()).await.unwrap(); let context = async_nats::jetstream::new(client); - let mut stream = context .create_stream(stream::Config { name: "TEST".to_string(), subjects: vec!["foo".into(), "bar".into(), "baz".into()], - allow_direct: true, ..Default::default() }) .await .unwrap(); let id = "UUID".to_string(); - // Publish first message - context - .send_publish( - "foo".to_string(), - Publish::build() - .message_id(id.clone()) - .payload("data".into()), - ) - .await - .unwrap() - .await - .unwrap(); - // Publish second message, a duplicate. - context - .send_publish("foo".to_string(), Publish::build().message_id(id.clone())) - .await - .unwrap() - .await - .unwrap(); - // Check if we still have one message. + + // Publish duplicate messages + for _ in 0..3 { + context + .publish("foo".to_string(), "data".into()) + .message_id(id.clone()) + .await + .unwrap() + .await + .unwrap(); + } + let info = stream.info().await.unwrap(); assert_eq!(1, info.state.messages); - let message = stream - .direct_get_last_for_subject("foo".to_string()) - .await - .unwrap(); - assert_eq!(message.payload, bytes::Bytes::from("data")); - // Publish message with different ID and expect error. - let err = context - .send_publish( - "foo".to_string(), - Publish::build().expected_last_message_id("BAD_ID"), - ) - .await - .unwrap() - .await - .unwrap_err() - .kind(); - assert_eq!(err, PublishErrorKind::WrongLastMessageId); - // Publish a new message with expected ID. context - .send_publish( - "foo".to_string(), - Publish::build().expected_last_message_id(id.clone()), - ) + .publish("foo".to_string(), "data".into()) + .expected_last_message_id(id.clone()) .await .unwrap() .await .unwrap(); - // We should have now two messages. Check it. + let info = stream.info().await.unwrap(); + assert_eq!(2, info.state.messages); + + assert_eq!( + context + .publish("foo".to_string(), "data".into()) + .expected_last_message_id("invalid") + .await + .unwrap() + .await + .unwrap_err() + .kind(), + PublishErrorKind::WrongLastMessageId + ); + + let info = stream.info().await.unwrap(); + assert_eq!(2, info.state.messages); + context - .send_publish( - "foo".to_string(), - Publish::build().expected_last_sequence(2), - ) + .publish("foo".to_string(), "data".into()) + .expected_last_sequence(2) .await .unwrap() .await .unwrap(); - // 3 messages should be there, so this should error. + + let info = stream.info().await.unwrap(); + assert_eq!(3, info.state.messages); + + // 3 messages should be there, so this should error assert_eq!( context - .send_publish( - "foo".to_string(), - Publish::build().expected_last_sequence(2), - ) + .publish("foo".into(), "data".into()) + .expected_last_sequence(2) .await .unwrap() .await @@ -207,23 +194,23 @@ mod jetstream { .kind(), PublishErrorKind::WrongLastSequence ); - // 3 messages there, should be ok for this subject too. + context - .send_publish( - "foo".to_string(), - Publish::build().expected_last_subject_sequence(3), - ) + .publish("bar".to_string(), "data".into()) + .expected_last_sequence(3) .await .unwrap() .await .unwrap(); - // 4 messages there, should error. + + let info = stream.info().await.unwrap(); + assert_eq!(4, info.state.messages); + + // 4 messages should be there, so this should error assert_eq!( context - .send_publish( - "foo".to_string(), - Publish::build().expected_last_subject_sequence(3), - ) + .publish("foo".into(), "data".into()) + .expected_last_sequence(3) .await .unwrap() .await @@ -232,45 +219,66 @@ mod jetstream { PublishErrorKind::WrongLastSequence ); - // Check if it works for the other subjects in the stream. - context - .send_publish( - "bar".to_string(), - Publish::build().expected_last_subject_sequence(0), - ) - .await - .unwrap() - .await - .unwrap(); - // Sequence is now 1, so this should fail. - context - .send_publish( - "bar".to_string(), - Publish::build().expected_last_subject_sequence(0), - ) - .await - .unwrap() - .await - .unwrap_err(); - // test header shorthand - assert_eq!(stream.info().await.unwrap().state.messages, 5); - context - .send_publish( - "foo".to_string(), - Publish::build().header(NATS_MESSAGE_ID, id.as_str()), - ) - .await - .unwrap() - .await - .unwrap(); - // above message should be ignored. - assert_eq!(stream.info().await.unwrap().state.messages, 5); + // check if it works for the other subjects in the stream. context - .send_publish("bar".to_string(), Publish::build().expected_stream("TEST")) + .publish("baz".into(), "data".into()) + .expected_last_subject_sequence(0) .await .unwrap() .await .unwrap(); + + // sequence is now 1, so this should error + assert_eq!( + context + .publish("baz".into(), "data".into()) + .expected_last_subject_sequence(0) + .await + .unwrap() + .await + .unwrap_err() + .kind(), + PublishErrorKind::WrongLastSequence + ); + + let info = stream.info().await.unwrap(); + assert_eq!(5, info.state.messages); + + // 5 messages should be there, so this should error + assert_eq!( + context + .publish("foo".into(), "data".into()) + .expected_last_sequence(4) + .await + .unwrap() + .await + .unwrap_err() + .kind(), + PublishErrorKind::WrongLastSequence + ); + + let subjects = ["foo", "bar", "baz"]; + for subject in subjects { + context + .publish(subject.into(), "data".into()) + .expected_stream("TEST") + .await + .unwrap() + .await + .unwrap(); + + assert_eq!( + context + .publish(subject.into(), "data".into()) + .expected_stream("INVALID") + .await + .unwrap() + .await + .unwrap_err() + .kind(), + PublishErrorKind::Other + ); + } } #[tokio::test] @@ -342,6 +350,21 @@ mod jetstream { assert!(matches!(response, Response::Err { .. })); } + #[tokio::test] + async fn request_timeout() { + let server = nats_server::run_server("tests/configs/jetstream.conf"); + let client = async_nats::connect(server.client_url()).await.unwrap(); + let context = async_nats::jetstream::new(client); + + let response: Response = context + .request("INFO".to_string(), &()) + .timeout(Duration::from_secs(1)) + .await + .unwrap(); + + assert!(matches!(response, Response::Ok { .. })); + } + #[tokio::test] async fn create_stream() { let server = nats_server::run_server("tests/configs/jetstream.conf");