From 83b3d5abb435ab7d28a2e23ba61548ae53088cbf Mon Sep 17 00:00:00 2001 From: Daniele Palaia Date: Tue, 22 Oct 2024 15:16:21 +0200 Subject: [PATCH] super_stream_consumer new approach --- src/consumer.rs | 15 ++++++----- src/superstream.rs | 2 -- src/superstream_consumer.rs | 42 ++++++++++++++++++++---------- tests/integration/consumer_test.rs | 34 +++++++++++------------- 4 files changed, 52 insertions(+), 41 deletions(-) diff --git a/src/consumer.rs b/src/consumer.rs index d33172b..61464c2 100644 --- a/src/consumer.rs +++ b/src/consumer.rs @@ -6,7 +6,7 @@ use std::{ AtomicBool, Ordering::{Relaxed, SeqCst}, }, - Arc, Mutex, + Arc, }, task::{Context, Poll}, }; @@ -32,11 +32,10 @@ use rand::{seq::SliceRandom, SeedableRng}; type FilterPredicate = Option bool + Send + Sync>>; /// API for consuming RabbitMQ stream messages -#[derive(Clone)] pub struct Consumer { // Mandatory in case of manual offset tracking name: Option, - receiver: Arc>>>, + receiver: Receiver>, internal: Arc, } @@ -181,7 +180,7 @@ impl ConsumerBuilder { if response.is_ok() { Ok(Consumer { name: self.consumer_name, - receiver: Arc::new(Mutex::new(rx)), + receiver: rx, internal: consumer, }) } else { @@ -235,6 +234,10 @@ impl Consumer { } } + pub fn get_receiver(&mut self) -> &Receiver> { + return &self.receiver; + } + pub async fn query_offset(&self) -> Result { if let Some(name) = &self.name { self.internal @@ -251,9 +254,9 @@ impl Consumer { impl Stream for Consumer { type Item = Result; - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.internal.waker.register(cx.waker()); - let poll = Pin::new(&mut self.receiver.lock().unwrap()).poll_recv(cx); + let poll = Pin::new(&mut self.receiver).poll_recv(cx); match (self.is_closed(), poll.is_ready()) { (true, false) => Poll::Ready(None), _ => poll, diff --git a/src/superstream.rs b/src/superstream.rs index ff64fd8..d140c47 100644 --- a/src/superstream.rs +++ b/src/superstream.rs @@ -15,7 +15,6 @@ pub struct DefaultSuperStreamMetadata { impl DefaultSuperStreamMetadata { pub async fn partitions(&mut self) -> Vec { if self.partitions.is_empty() { - println!("partition len is 0"); let response = self.client.partitions(self.super_stream.clone()).await; self.partitions = response.unwrap().streams; @@ -64,7 +63,6 @@ impl HashRoutingMurmurStrategy { message: &Message, metadata: &mut DefaultSuperStreamMetadata, ) -> Vec { - println!("im in routes"); let mut streams: Vec = Vec::new(); let key = (self.routing_extractor)(message); diff --git a/src/superstream_consumer.rs b/src/superstream_consumer.rs index 7785355..ad83d1d 100644 --- a/src/superstream_consumer.rs +++ b/src/superstream_consumer.rs @@ -1,17 +1,23 @@ +use crate::consumer::Delivery; +use crate::error::ConsumerDeliveryError; use crate::superstream::DefaultSuperStreamMetadata; use crate::{error::ConsumerCreateError, Consumer, Environment}; +use futures::{Stream, StreamExt}; use rabbitmq_stream_protocol::commands::subscribe::OffsetSpecification; -use std::sync::Arc; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::sync::mpsc::{channel, Receiver}; +use tokio::task; + //type FilterPredicate = Option bool + Send + Sync>>; /// API for consuming RabbitMQ stream messages -#[derive(Clone)] pub struct SuperStreamConsumer { - internal: Arc, + internal: SuperStreamConsumerInternal, } struct SuperStreamConsumerInternal { - consumers: Vec, + receiver: Receiver>, } /// Builder for [`Consumer`] @@ -28,6 +34,7 @@ impl SuperStreamConsumerBuilder { // Connect to the user specified node first, then look for a random replica to connect to instead. // This is recommended for load balancing purposes. let client = self.environment.create_client().await?; + let (tx, rx) = channel(10000); let mut super_stream_metadata = DefaultSuperStreamMetadata { super_stream: super_stream.to_string(), @@ -36,10 +43,11 @@ impl SuperStreamConsumerBuilder { routes: Vec::new(), }; let partitions = super_stream_metadata.partitions().await; - let mut consumers: Vec = Vec::new(); for partition in partitions.into_iter() { - let consumer = self + println!("inside partition"); + let tx_cloned = tx.clone(); + let mut consumer = self .environment .consumer() .offset(self.offset_specification.clone()) @@ -47,10 +55,16 @@ impl SuperStreamConsumerBuilder { .await .unwrap(); - consumers.push(consumer); + task::spawn(async move { + println!("inside consumer thread"); + while let d = consumer.next().await.unwrap() { + println!("receving messages"); + _ = tx_cloned.send(d).await; + } + }); } - let super_stream_consumer_internal = Arc::new(SuperStreamConsumerInternal { consumers }); + let super_stream_consumer_internal = SuperStreamConsumerInternal { receiver: rx }; Ok(SuperStreamConsumer { internal: super_stream_consumer_internal, @@ -63,12 +77,12 @@ impl SuperStreamConsumerBuilder { } } -impl SuperStreamConsumer { - pub async fn get_consumer(&self, i: usize) -> &Consumer { - return self.internal.consumers.get(i).unwrap(); - } +impl Stream for SuperStreamConsumer { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + println!("inside poll!"); + Pin::new(&mut self.internal.receiver).poll_recv(cx) - pub async fn get_consumers(&mut self) -> Vec { - self.internal.consumers.clone() } } diff --git a/tests/integration/consumer_test.rs b/tests/integration/consumer_test.rs index a4f6a8c..9344820 100644 --- a/tests/integration/consumer_test.rs +++ b/tests/integration/consumer_test.rs @@ -82,7 +82,7 @@ async fn super_stream_consumer_test() { let mut super_stream_consumer: SuperStreamConsumer = env .env .super_stream_consumer() - //.offset(OffsetSpecification::Next) + .offset(OffsetSpecification::First) .build(&env.super_stream) .await .unwrap(); @@ -97,27 +97,23 @@ async fn super_stream_consumer_test() { .unwrap(); } - let received_messages = Arc::new(AtomicU32::new(0)); - - for mut consumer in super_stream_consumer.get_consumers().await.into_iter() { - let received_messages_outer = received_messages.clone(); - - task::spawn(async move { - let mut inner_received_messages = received_messages_outer.clone(); - while let _ = consumer.next().await.unwrap() { - let value = inner_received_messages.fetch_add(1, Ordering::Relaxed); - if value == message_count { - let handle = consumer.handle(); - _ = handle.close().await; - break; - } - } - }); + //let received_messages = Arc::new(AtomicU32::new(0)); + + println!("before looping"); + while let delivery = super_stream_consumer.next().await.unwrap() { + println!("inside while delivery loop"); + let d = delivery.unwrap(); + println!("Got message: {:#?} from stream: {} with offset: {}", + d.message().data().map(|data| String::from_utf8(data.to_vec()).unwrap()), d.stream(), + d.offset()); + + //let _ = received_messages.fetch_add(1, Ordering::Relaxed); + } - sleep(Duration::from_millis(1000)).await; + //sleep(Duration::from_millis(1000)).await; - assert!(received_messages.fetch_add(1, Ordering::Relaxed) == message_count); + // assert!(received_messages.fetch_add(1, Ordering::Relaxed) == message_count); super_stream_producer.close().await.unwrap(); }