Skip to content

Commit

Permalink
super_stream_consumer new approach
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielePalaia committed Oct 22, 2024
1 parent e590d11 commit 83b3d5a
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 41 deletions.
15 changes: 9 additions & 6 deletions src/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{
AtomicBool,
Ordering::{Relaxed, SeqCst},
},
Arc, Mutex,
Arc,
},
task::{Context, Poll},
};
Expand All @@ -32,11 +32,10 @@ use rand::{seq::SliceRandom, SeedableRng};
type FilterPredicate = Option<Arc<dyn Fn(&Message) -> bool + Send + Sync>>;

/// API for consuming RabbitMQ stream messages
#[derive(Clone)]
pub struct Consumer {
// Mandatory in case of manual offset tracking
name: Option<String>,
receiver: Arc<Mutex<Receiver<Result<Delivery, ConsumerDeliveryError>>>>,
receiver: Receiver<Result<Delivery, ConsumerDeliveryError>>,
internal: Arc<ConsumerInternal>,
}

Expand Down Expand Up @@ -181,7 +180,7 @@ impl ConsumerBuilder {
if response.is_ok() {
Ok(Consumer {
name: self.consumer_name,
receiver: Arc::new(Mutex::new(rx)),
receiver: rx,
internal: consumer,
})
} else {
Expand Down Expand Up @@ -235,6 +234,10 @@ impl Consumer {
}
}

pub fn get_receiver(&mut self) -> &Receiver<Result<Delivery, ConsumerDeliveryError>> {
return &self.receiver;
}

pub async fn query_offset(&self) -> Result<u64, ConsumerStoreOffsetError> {
if let Some(name) = &self.name {
self.internal
Expand All @@ -251,9 +254,9 @@ impl Consumer {
impl Stream for Consumer {
type Item = Result<Delivery, ConsumerDeliveryError>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.internal.waker.register(cx.waker());
let poll = Pin::new(&mut self.receiver.lock().unwrap()).poll_recv(cx);
let poll = Pin::new(&mut self.receiver).poll_recv(cx);
match (self.is_closed(), poll.is_ready()) {
(true, false) => Poll::Ready(None),
_ => poll,
Expand Down
2 changes: 0 additions & 2 deletions src/superstream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ pub struct DefaultSuperStreamMetadata {
impl DefaultSuperStreamMetadata {
pub async fn partitions(&mut self) -> Vec<String> {
if self.partitions.is_empty() {
println!("partition len is 0");
let response = self.client.partitions(self.super_stream.clone()).await;

self.partitions = response.unwrap().streams;
Expand Down Expand Up @@ -64,7 +63,6 @@ impl HashRoutingMurmurStrategy {
message: &Message,
metadata: &mut DefaultSuperStreamMetadata,
) -> Vec<String> {
println!("im in routes");
let mut streams: Vec<String> = Vec::new();

let key = (self.routing_extractor)(message);
Expand Down
42 changes: 28 additions & 14 deletions src/superstream_consumer.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
use crate::consumer::Delivery;
use crate::error::ConsumerDeliveryError;
use crate::superstream::DefaultSuperStreamMetadata;
use crate::{error::ConsumerCreateError, Consumer, Environment};
use futures::{Stream, StreamExt};
use rabbitmq_stream_protocol::commands::subscribe::OffsetSpecification;
use std::sync::Arc;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::sync::mpsc::{channel, Receiver};
use tokio::task;

//type FilterPredicate = Option<Arc<dyn Fn(&Message) -> bool + Send + Sync>>;

/// API for consuming RabbitMQ stream messages
#[derive(Clone)]
pub struct SuperStreamConsumer {
internal: Arc<SuperStreamConsumerInternal>,
internal: SuperStreamConsumerInternal,
}

struct SuperStreamConsumerInternal {
consumers: Vec<Consumer>,
receiver: Receiver<Result<Delivery, ConsumerDeliveryError>>,
}

/// Builder for [`Consumer`]
Expand All @@ -28,6 +34,7 @@ impl SuperStreamConsumerBuilder {
// Connect to the user specified node first, then look for a random replica to connect to instead.
// This is recommended for load balancing purposes.
let client = self.environment.create_client().await?;
let (tx, rx) = channel(10000);

let mut super_stream_metadata = DefaultSuperStreamMetadata {
super_stream: super_stream.to_string(),
Expand All @@ -36,21 +43,28 @@ impl SuperStreamConsumerBuilder {
routes: Vec::new(),
};
let partitions = super_stream_metadata.partitions().await;
let mut consumers: Vec<Consumer> = Vec::new();

for partition in partitions.into_iter() {
let consumer = self
println!("inside partition");
let tx_cloned = tx.clone();
let mut consumer = self
.environment
.consumer()
.offset(self.offset_specification.clone())
.build(partition.as_str())
.await
.unwrap();

consumers.push(consumer);
task::spawn(async move {
println!("inside consumer thread");
while let d = consumer.next().await.unwrap() {
println!("receving messages");
_ = tx_cloned.send(d).await;
}
});
}

let super_stream_consumer_internal = Arc::new(SuperStreamConsumerInternal { consumers });
let super_stream_consumer_internal = SuperStreamConsumerInternal { receiver: rx };

Ok(SuperStreamConsumer {
internal: super_stream_consumer_internal,
Expand All @@ -63,12 +77,12 @@ impl SuperStreamConsumerBuilder {
}
}

impl SuperStreamConsumer {
pub async fn get_consumer(&self, i: usize) -> &Consumer {
return self.internal.consumers.get(i).unwrap();
}
impl Stream for SuperStreamConsumer {
type Item = Result<Delivery, ConsumerDeliveryError>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
println!("inside poll!");
Pin::new(&mut self.internal.receiver).poll_recv(cx)

pub async fn get_consumers(&mut self) -> Vec<Consumer> {
self.internal.consumers.clone()
}
}
34 changes: 15 additions & 19 deletions tests/integration/consumer_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ async fn super_stream_consumer_test() {
let mut super_stream_consumer: SuperStreamConsumer = env
.env
.super_stream_consumer()
//.offset(OffsetSpecification::Next)
.offset(OffsetSpecification::First)
.build(&env.super_stream)
.await
.unwrap();
Expand All @@ -97,27 +97,23 @@ async fn super_stream_consumer_test() {
.unwrap();
}

let received_messages = Arc::new(AtomicU32::new(0));

for mut consumer in super_stream_consumer.get_consumers().await.into_iter() {
let received_messages_outer = received_messages.clone();

task::spawn(async move {
let mut inner_received_messages = received_messages_outer.clone();
while let _ = consumer.next().await.unwrap() {
let value = inner_received_messages.fetch_add(1, Ordering::Relaxed);
if value == message_count {
let handle = consumer.handle();
_ = handle.close().await;
break;
}
}
});
//let received_messages = Arc::new(AtomicU32::new(0));

println!("before looping");
while let delivery = super_stream_consumer.next().await.unwrap() {
println!("inside while delivery loop");
let d = delivery.unwrap();
println!("Got message: {:#?} from stream: {} with offset: {}",
d.message().data().map(|data| String::from_utf8(data.to_vec()).unwrap()), d.stream(),
d.offset());

//let _ = received_messages.fetch_add(1, Ordering::Relaxed);

}

sleep(Duration::from_millis(1000)).await;
//sleep(Duration::from_millis(1000)).await;

assert!(received_messages.fetch_add(1, Ordering::Relaxed) == message_count);
// assert!(received_messages.fetch_add(1, Ordering::Relaxed) == message_count);

super_stream_producer.close().await.unwrap();
}
Expand Down

0 comments on commit 83b3d5a

Please sign in to comment.