diff --git a/src/admin.rs b/src/admin.rs
index bc642a30d..69dba537b 100644
--- a/src/admin.rs
+++ b/src/admin.rs
@@ -403,7 +403,7 @@ fn start_poll_thread(queue: Arc, should_stop: Arc) -> J
         .expect("Failed to start polling thread")
 }
 
-type NativeEvent = NativePtr<RDKafkaEvent>;
+pub(crate) type NativeEvent = NativePtr<RDKafkaEvent>;
 
 unsafe impl KafkaDrop for RDKafkaEvent {
     const TYPE: &'static str = "event";
diff --git a/src/client.rs b/src/client.rs
index c8c31ea3b..1b9f6bd1c 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -11,19 +11,19 @@
 //! [`consumer`]: crate::consumer
 //! [`producer`]: crate::producer
 
-use std::convert::TryFrom;
 use std::error::Error;
 use std::ffi::{CStr, CString};
 use std::mem::ManuallyDrop;
-use std::os::raw::{c_char, c_void};
+use std::os::raw::c_char;
 use std::ptr;
-use std::slice;
 use std::string::ToString;
 use std::sync::Arc;
 
+use libc::c_void;
 use rdkafka_sys as rdsys;
 use rdkafka_sys::types::*;
 
+use crate::admin::NativeEvent;
 use crate::config::{ClientConfig, NativeClientConfig, RDKafkaLogLevel};
 use crate::consumer::RebalanceProtocol;
 use crate::error::{IsError, KafkaError, KafkaResult};
@@ -239,21 +239,6 @@ impl Client {
                 Arc::as_ptr(&context) as *mut c_void,
             )
         };
-        unsafe { rdsys::rd_kafka_conf_set_log_cb(native_config.ptr(), Some(native_log_cb::<C>)) };
-        unsafe {
-            rdsys::rd_kafka_conf_set_stats_cb(native_config.ptr(), Some(native_stats_cb::<C>))
-        };
-        unsafe {
-            rdsys::rd_kafka_conf_set_error_cb(native_config.ptr(), Some(native_error_cb::<C>))
-        };
-        if C::ENABLE_REFRESH_OAUTH_TOKEN {
-            unsafe {
-                rdsys::rd_kafka_conf_set_oauthbearer_token_refresh_cb(
-                    native_config.ptr(),
-                    Some(native_oauth_refresh_cb::<C>),
-                )
-            };
-        }
 
         let client_ptr = unsafe {
             let native_config = ManuallyDrop::new(native_config);
@@ -293,6 +278,128 @@ impl Client {
         &self.context
     }
 
+    pub(crate) fn poll_event(&self, queue: &NativeQueue, timeout: Timeout) -> Option<NativeEvent> {
+        let event = unsafe { NativeEvent::from_ptr(queue.poll(timeout)) };
+        if let Some(ev) = event {
+            let evtype = unsafe { rdsys::rd_kafka_event_type(ev.ptr()) };
+            match evtype {
+                rdsys::RD_KAFKA_EVENT_LOG => self.handle_log_event(ev.ptr()),
+                rdsys::RD_KAFKA_EVENT_STATS => self.handle_stats_event(ev.ptr()),
+                rdsys::RD_KAFKA_EVENT_ERROR => {
+                    // rdkafka reports consumer errors via RD_KAFKA_EVENT_ERROR, but producer errors get
+                    // embedded in the ack returned via RD_KAFKA_EVENT_DR. Hence we need to return this event
+                    // in the consumer case so the error can be surfaced to the user.
+ self.handle_error_event(ev.ptr()); + return Some(ev); + } + rdsys::RD_KAFKA_EVENT_OAUTHBEARER_TOKEN_REFRESH => { + if C::ENABLE_REFRESH_OAUTH_TOKEN { + self.handle_oauth_refresh_event(ev.ptr()); + } + } + _ => { + return Some(ev); + } + } + } + None + } + + fn handle_log_event(&self, event: *mut RDKafkaEvent) { + let mut fac: *const c_char = std::ptr::null(); + let mut str_: *const c_char = std::ptr::null(); + let mut level: i32 = 0; + let result = unsafe { rdsys::rd_kafka_event_log(event, &mut fac, &mut str_, &mut level) }; + if result == 0 { + let fac = unsafe { CStr::from_ptr(fac).to_string_lossy() }; + let log_message = unsafe { CStr::from_ptr(str_).to_string_lossy() }; + self.context().log( + RDKafkaLogLevel::from_int(level), + fac.trim(), + log_message.trim(), + ); + } + } + + fn handle_stats_event(&self, event: *mut RDKafkaEvent) { + let json = unsafe { CStr::from_ptr(rdsys::rd_kafka_event_stats(event)) }; + self.context().stats_raw(json.to_bytes()); + } + + fn handle_error_event(&self, event: *mut RDKafkaEvent) { + let rdkafka_err = unsafe { rdsys::rd_kafka_event_error(event) }; + let error = KafkaError::Global(rdkafka_err.into()); + let reason = + unsafe { CStr::from_ptr(rdsys::rd_kafka_event_error_string(event)).to_string_lossy() }; + self.context().error(error, reason.trim()); + } + + fn handle_oauth_refresh_event(&self, event: *mut RDKafkaEvent) { + let oauthbearer_config = unsafe { rdsys::rd_kafka_event_config_string(event) }; + let res: Result<_, Box> = (|| { + let oauthbearer_config = match oauthbearer_config.is_null() { + true => None, + false => unsafe { Some(util::cstr_to_owned(oauthbearer_config)) }, + }; + let token_info = self + .context() + .generate_oauth_token(oauthbearer_config.as_deref())?; + let token = CString::new(token_info.token)?; + let principal_name = CString::new(token_info.principal_name)?; + Ok((token, principal_name, token_info.lifetime_ms)) + })(); + match res { + Ok((token, principal_name, lifetime_ms)) => { + let mut err_buf = ErrBuf::new(); + let code = unsafe { + rdkafka_sys::rd_kafka_oauthbearer_set_token( + self.native_ptr(), + token.as_ptr(), + lifetime_ms, + principal_name.as_ptr(), + ptr::null_mut(), + 0, + err_buf.as_mut_ptr(), + err_buf.capacity(), + ) + }; + if code == RDKafkaRespErr::RD_KAFKA_RESP_ERR_NO_ERROR { + debug!("successfully set refreshed OAuth token"); + } else { + debug!( + "failed to set refreshed OAuth token (code {:?}): {}", + code, err_buf + ); + unsafe { + rdkafka_sys::rd_kafka_oauthbearer_set_token_failure( + self.native_ptr(), + err_buf.as_mut_ptr(), + ) + }; + } + } + Err(e) => { + debug!("failed to refresh OAuth token: {}", e); + let message = match CString::new(e.to_string()) { + Ok(message) => message, + Err(e) => { + error!("error message generated while refreshing OAuth token has embedded null character: {}", e); + CString::new( + "error while refreshing OAuth token has embedded null character", + ) + .expect("known to be a valid CString") + } + }; + unsafe { + rdkafka_sys::rd_kafka_oauthbearer_set_token_failure( + self.native_ptr(), + message.as_ptr(), + ) + }; + } + } + } + /// Returns the metadata information for the specified topic, or for all topics in the cluster /// if no topic is specified. pub fn fetch_metadata>( @@ -442,6 +549,11 @@ impl Client { pub(crate) fn consumer_queue(&self) -> Option { unsafe { NativeQueue::from_ptr(rdsys::rd_kafka_queue_get_consumer(self.native_ptr())) } } + + /// Returns a NativeQueue for the main librdkafka event queue from the current client. 
+ pub(crate) fn main_queue(&self) -> NativeQueue { + unsafe { NativeQueue::from_ptr(rdsys::rd_kafka_queue_get_main(self.native_ptr())).unwrap() } + } } pub(crate) type NativeTopic = NativePtr; @@ -471,48 +583,6 @@ impl NativeQueue { } } -pub(crate) unsafe extern "C" fn native_log_cb( - client: *const RDKafka, - level: i32, - fac: *const c_char, - buf: *const c_char, -) { - let fac = CStr::from_ptr(fac).to_string_lossy(); - let log_message = CStr::from_ptr(buf).to_string_lossy(); - - let context = &mut *(rdsys::rd_kafka_opaque(client) as *mut C); - context.log( - RDKafkaLogLevel::from_int(level), - fac.trim(), - log_message.trim(), - ); -} - -pub(crate) unsafe extern "C" fn native_stats_cb( - _conf: *mut RDKafka, - json: *mut c_char, - json_len: usize, - opaque: *mut c_void, -) -> i32 { - let context = &mut *(opaque as *mut C); - context.stats_raw(slice::from_raw_parts(json as *mut u8, json_len)); - 0 // librdkafka will free the json buffer -} - -pub(crate) unsafe extern "C" fn native_error_cb( - _client: *mut RDKafka, - err: i32, - reason: *const c_char, - opaque: *mut c_void, -) { - let err = RDKafkaRespErr::try_from(err).expect("global error not an rd_kafka_resp_err_t"); - let error = KafkaError::Global(err.into()); - let reason = CStr::from_ptr(reason).to_string_lossy(); - - let context = &mut *(opaque as *mut C); - context.error(error, reason.trim()); -} - /// A generated OAuth token and its associated metadata. /// /// When using the `OAUTHBEARER` SASL authentication method, this type is @@ -529,60 +599,6 @@ pub struct OAuthToken { pub lifetime_ms: i64, } -pub(crate) unsafe extern "C" fn native_oauth_refresh_cb( - client: *mut RDKafka, - oauthbearer_config: *const c_char, - opaque: *mut c_void, -) { - let res: Result<_, Box> = (|| { - let context = &mut *(opaque as *mut C); - let oauthbearer_config = match oauthbearer_config.is_null() { - true => None, - false => Some(util::cstr_to_owned(oauthbearer_config)), - }; - let token_info = context.generate_oauth_token(oauthbearer_config.as_deref())?; - let token = CString::new(token_info.token)?; - let principal_name = CString::new(token_info.principal_name)?; - Ok((token, principal_name, token_info.lifetime_ms)) - })(); - match res { - Ok((token, principal_name, lifetime_ms)) => { - let mut err_buf = ErrBuf::new(); - let code = rdkafka_sys::rd_kafka_oauthbearer_set_token( - client, - token.as_ptr(), - lifetime_ms, - principal_name.as_ptr(), - ptr::null_mut(), - 0, - err_buf.as_mut_ptr(), - err_buf.capacity(), - ); - if code == RDKafkaRespErr::RD_KAFKA_RESP_ERR_NO_ERROR { - debug!("successfully set refreshed OAuth token"); - } else { - debug!( - "failed to set refreshed OAuth token (code {:?}): {}", - code, err_buf - ); - rdkafka_sys::rd_kafka_oauthbearer_set_token_failure(client, err_buf.as_mut_ptr()); - } - } - Err(e) => { - debug!("failed to refresh OAuth token: {}", e); - let message = match CString::new(e.to_string()) { - Ok(message) => message, - Err(e) => { - error!("error message generated while refreshing OAuth token has embedded null character: {}", e); - CString::new("error while refreshing OAuth token has embedded null character") - .expect("known to be a valid CString") - } - }; - rdkafka_sys::rd_kafka_oauthbearer_set_token_failure(client, message.as_ptr()); - } - } -} - #[cfg(test)] mod tests { // Just call everything to test there no panics by default, behavior diff --git a/src/consumer/base_consumer.rs b/src/consumer/base_consumer.rs index ee03b906b..08ec51b78 100644 --- a/src/consumer/base_consumer.rs +++ 
b/src/consumer/base_consumer.rs @@ -1,16 +1,17 @@ //! Low-level consumers. -use std::cmp; -use std::ffi::CString; +use std::ffi::{CStr, CString}; use std::mem::ManuallyDrop; use std::os::raw::c_void; use std::ptr; use std::sync::Arc; +use std::time::{Duration, Instant}; +use log::{error, warn}; use rdkafka_sys as rdsys; use rdkafka_sys::types::*; -use crate::client::{Client, NativeClient, NativeQueue}; +use crate::client::{Client, NativeQueue}; use crate::config::{ ClientConfig, FromClientConfig, FromClientConfigAndContext, NativeClientConfig, }; @@ -26,41 +27,6 @@ use crate::metadata::Metadata; use crate::topic_partition_list::{Offset, TopicPartitionList}; use crate::util::{cstr_to_owned, NativePtr, Timeout}; -pub(crate) unsafe extern "C" fn native_commit_cb( - _conf: *mut RDKafka, - err: RDKafkaRespErr, - offsets: *mut RDKafkaTopicPartitionList, - opaque_ptr: *mut c_void, -) { - let context = &mut *(opaque_ptr as *mut C); - let commit_error = if err.is_error() { - Err(KafkaError::ConsumerCommit(err.into())) - } else { - Ok(()) - }; - if offsets.is_null() { - let tpl = TopicPartitionList::new(); - context.commit_callback(commit_error, &tpl); - } else { - let tpl = ManuallyDrop::new(TopicPartitionList::from_ptr(offsets)); - context.commit_callback(commit_error, &tpl); - } -} - -/// Native rebalance callback. This callback will run on every rebalance, and it will call the -/// rebalance method defined in the current `Context`. -unsafe extern "C" fn native_rebalance_cb( - rk: *mut RDKafka, - err: RDKafkaRespErr, - native_tpl: *mut RDKafkaTopicPartitionList, - opaque_ptr: *mut c_void, -) { - let context = &mut *(opaque_ptr as *mut C); - let native_client = ManuallyDrop::new(NativeClient::from_ptr(rk)); - let mut tpl = ManuallyDrop::new(TopicPartitionList::from_ptr(native_tpl)); - context.rebalance(&native_client, err, &mut tpl); -} - /// A low-level consumer that requires manual polling. /// /// This consumer must be periodically polled to make progress on rebalancing, @@ -70,7 +36,8 @@ where C: ConsumerContext, { client: Client, - main_queue_min_poll_interval: Timeout, + queue: NativeQueue, + group_id: Option, } impl FromClientConfig for BaseConsumer { @@ -96,57 +63,51 @@ where context: C, ) -> KafkaResult> { unsafe { - rdsys::rd_kafka_conf_set_rebalance_cb( - native_config.ptr(), - Some(native_rebalance_cb::), - ); - rdsys::rd_kafka_conf_set_offset_commit_cb( + rdsys::rd_kafka_conf_set_events( native_config.ptr(), - Some(native_commit_cb::), - ); - } - let main_queue_min_poll_interval = context.main_queue_min_poll_interval(); + rdsys::RD_KAFKA_EVENT_REBALANCE + | rdsys::RD_KAFKA_EVENT_OFFSET_COMMIT + | rdsys::RD_KAFKA_EVENT_STATS + | rdsys::RD_KAFKA_EVENT_ERROR + | rdsys::RD_KAFKA_EVENT_OAUTHBEARER_TOKEN_REFRESH, + ) + }; let client = Client::new( config, native_config, RDKafkaType::RD_KAFKA_CONSUMER, context, )?; + + let group_id = config.get("group.id").map(|s| s.to_string()); + // If a group.id is not specified, we won't redirect the main queue to the consumer queue, + // allowing continued use of the consumer for fetching metadata and watermarks without the + // need to specify a group.id + let queue = if group_id.is_some() { + // Redirect rdkafka's main queue to the consumer queue so that we only need to listen + // to the consumer queue to observe events like rebalancings and stats. + unsafe { rdsys::rd_kafka_poll_set_consumer(client.native_ptr()) }; + client.consumer_queue().ok_or_else(|| { + KafkaError::ClientCreation("rdkafka consumer queue not available".to_string()) + })? 
+ } else { + client.main_queue() + }; + Ok(BaseConsumer { client, - main_queue_min_poll_interval, + queue, + group_id, }) } - /// Polls the consumer for messages and returns a pointer to the native rdkafka-sys struct. - /// This method is for internal use only. Use poll instead. - pub(crate) fn poll_raw(&self, mut timeout: Timeout) -> Option> { - loop { - unsafe { rdsys::rd_kafka_poll(self.client.native_ptr(), 0) }; - let op_timeout = cmp::min(timeout, self.main_queue_min_poll_interval); - let message_ptr = unsafe { - NativePtr::from_ptr(rdsys::rd_kafka_consumer_poll( - self.client.native_ptr(), - op_timeout.as_millis(), - )) - }; - if let Some(message_ptr) = message_ptr { - break Some(message_ptr); - } - if op_timeout >= timeout { - break None; - } - timeout -= op_timeout; - } - } - /// Polls the consumer for new messages. /// /// It won't block for more than the specified timeout. Use zero `Duration` for non-blocking /// call. With no timeout it blocks until an event is received. /// /// This method should be called at regular intervals, even if no message is expected, - /// to serve any queued callbacks waiting to be called. This is especially important for + /// to serve any queued events waiting to be handled. This is especially important for /// automatic consumer rebalance, as the rebalance function will be executed by the thread /// calling the poll() function. /// @@ -154,8 +115,130 @@ where /// /// The returned message lives in the memory of the consumer and cannot outlive it. pub fn poll>(&self, timeout: T) -> Option>> { - self.poll_raw(timeout.into()) - .map(|ptr| unsafe { BorrowedMessage::from_consumer(ptr, self) }) + self.poll_queue(self.get_queue(), timeout) + } + + pub(crate) fn poll_queue>( + &self, + queue: &NativeQueue, + timeout: T, + ) -> Option>> { + let now = Instant::now(); + let mut timeout = timeout.into(); + let min_poll_interval = self.context().main_queue_min_poll_interval(); + loop { + let op_timeout = std::cmp::min(timeout, min_poll_interval); + let maybe_event = self.client().poll_event(queue, op_timeout); + if let Some(event) = maybe_event { + let evtype = unsafe { rdsys::rd_kafka_event_type(event.ptr()) }; + match evtype { + rdsys::RD_KAFKA_EVENT_FETCH => { + if let Some(result) = self.handle_fetch_event(event) { + return Some(result); + } + } + rdsys::RD_KAFKA_EVENT_ERROR => { + if let Some(err) = self.handle_error_event(event) { + return Some(Err(err)); + } + } + rdsys::RD_KAFKA_EVENT_REBALANCE => { + self.handle_rebalance_event(event); + } + rdsys::RD_KAFKA_EVENT_OFFSET_COMMIT => { + self.handle_offset_commit_event(event); + } + _ => { + let buf = unsafe { + let evname = rdsys::rd_kafka_event_name(event.ptr()); + CStr::from_ptr(evname).to_bytes() + }; + let evname = String::from_utf8(buf.to_vec()).unwrap(); + warn!("Ignored event '{}' on consumer poll", evname); + } + } + } + + timeout = timeout.saturating_sub(now.elapsed()); + if timeout.is_zero() { + return None; + } + } + } + + fn handle_fetch_event( + &self, + event: NativePtr, + ) -> Option>> { + unsafe { + NativePtr::from_ptr(rdsys::rd_kafka_event_message_next(event.ptr()) as *mut _) + .map(|ptr| BorrowedMessage::from_client(ptr, Arc::new(event), self.client())) + } + } + + fn handle_rebalance_event(&self, event: NativePtr) { + let err = unsafe { rdsys::rd_kafka_event_error(event.ptr()) }; + match err { + rdsys::rd_kafka_resp_err_t::RD_KAFKA_RESP_ERR__ASSIGN_PARTITIONS + | rdsys::rd_kafka_resp_err_t::RD_KAFKA_RESP_ERR__REVOKE_PARTITIONS => { + let tpl = unsafe { + let native_tpl = 
rdsys::rd_kafka_event_topic_partition_list(event.ptr()); + TopicPartitionList::from_ptr(native_tpl) + }; + // The TPL is owned by the Event and will be destroyed when the event is destroyed. + // Dropping it here will lead to double free. + let mut tpl = ManuallyDrop::new(tpl); + self.context() + .rebalance(self.client.native_client(), err, &mut tpl); + } + _ => { + let buf = unsafe { + let err_name = + rdsys::rd_kafka_err2name(rdsys::rd_kafka_event_error(event.ptr())); + CStr::from_ptr(err_name).to_bytes() + }; + let err = String::from_utf8(buf.to_vec()).unwrap(); + warn!("invalid rebalance event: {:?}", err); + } + } + } + + fn handle_offset_commit_event(&self, event: NativePtr) { + let err = unsafe { rdsys::rd_kafka_event_error(event.ptr()) }; + let commit_error = if err.is_error() { + Err(KafkaError::ConsumerCommit(err.into())) + } else { + Ok(()) + }; + + let offsets = unsafe { rdsys::rd_kafka_event_topic_partition_list(event.ptr()) }; + if offsets.is_null() { + let tpl = TopicPartitionList::new(); + self.context().commit_callback(commit_error, &tpl); + } else { + // The TPL is owned by the Event and will be destroyed when the event is destroyed. + // Dropping it here will lead to double free. + let tpl = ManuallyDrop::new(unsafe { TopicPartitionList::from_ptr(offsets) }); + self.context().commit_callback(commit_error, &tpl); + } + } + + fn handle_error_event(&self, event: NativePtr) -> Option { + let rdkafka_err = unsafe { rdsys::rd_kafka_event_error(event.ptr()) }; + if rdkafka_err.is_error() { + if rdkafka_err == rdsys::rd_kafka_resp_err_t::RD_KAFKA_RESP_ERR__PARTITION_EOF { + let tp_ptr = unsafe { rdsys::rd_kafka_event_topic_partition(event.ptr()) }; + let partition = unsafe { (*tp_ptr).partition }; + unsafe { rdsys::rd_kafka_topic_partition_destroy(tp_ptr) }; + Some(KafkaError::PartitionEOF(partition)) + } else if unsafe { rdsys::rd_kafka_event_error_is_fatal(event.ptr()) } != 0 { + Some(KafkaError::MessageConsumptionFatal(rdkafka_err.into())) + } else { + Some(KafkaError::MessageConsumption(rdkafka_err.into())) + } + } else { + None + } } /// Returns an iterator over the available messages. @@ -202,6 +285,10 @@ where Iter(self) } + pub(crate) fn get_queue(&self) -> &NativeQueue { + &self.queue + } + /// Splits messages for the specified partition into their own queue. /// /// If the `topic` or `partition` is invalid, returns `None`. @@ -245,6 +332,27 @@ where PartitionQueue::new(self.clone(), queue) }) } + + /// Close the queue used by a consumer. + /// Only exposed for advanced usage of this API and should not be used under normal circumstances. + pub fn close_queue(&self) -> KafkaResult<()> { + let err = unsafe { + RDKafkaError::from_ptr(rdsys::rd_kafka_consumer_close_queue( + self.client.native_ptr(), + self.queue.ptr(), + )) + }; + if err.is_error() { + Err(KafkaError::ConsumerQueueClose(err.code())) + } else { + Ok(()) + } + } + + /// Returns true if the consumer is closed, else false. + pub fn closed(&self) -> bool { + unsafe { rdsys::rd_kafka_consumer_closed(self.client.native_ptr()) == 1 } + } } impl Consumer for BaseConsumer @@ -607,8 +715,16 @@ where C: ConsumerContext, { fn drop(&mut self) { - trace!("Destroying consumer: {:?}", self.client.native_ptr()); // TODO: fix me (multiple executions ?) 
- unsafe { rdsys::rd_kafka_consumer_close(self.client.native_ptr()) }; + trace!("Destroying consumer: {:?}", self.client.native_ptr()); + if self.group_id.is_some() { + if let Err(err) = self.close_queue() { + error!("Failed to close consumer queue on drop: {}", err); + } else { + while !self.closed() { + self.poll(Duration::from_millis(100)); + } + } + } trace!("Consumer destroyed: {:?}", self.client.native_ptr()); } } @@ -654,7 +770,7 @@ where C: ConsumerContext, { consumer: Arc>, - queue: NativeQueue, + pub(crate) queue: NativeQueue, nonempty_callback: Option>>, } @@ -677,15 +793,9 @@ where /// /// Remember that you must also call [`BaseConsumer::poll`] on the /// associated consumer regularly, even if no messages are expected, to - /// serve callbacks. + /// serve events. pub fn poll>(&self, timeout: T) -> Option>> { - unsafe { - NativePtr::from_ptr(rdsys::rd_kafka_consume_queue( - self.queue.ptr(), - timeout.into().as_millis(), - )) - } - .map(|ptr| unsafe { BorrowedMessage::from_consumer(ptr, &self.consumer) }) + self.consumer.poll_queue(&self.queue, timeout) } /// Sets a callback that will be invoked whenever the queue becomes diff --git a/src/consumer/mod.rs b/src/consumer/mod.rs index 65bbd215a..fb3ff0460 100644 --- a/src/consumer/mod.rs +++ b/src/consumer/mod.rs @@ -113,12 +113,12 @@ pub trait ConsumerContext: ClientContext { fn commit_callback(&self, result: KafkaResult<()>, offsets: &TopicPartitionList) {} /// Returns the minimum interval at which to poll the main queue, which - /// services the logging, stats, and error callbacks. + /// services the logging, stats, and error events. /// /// The main queue is polled once whenever [`BaseConsumer::poll`] is called. /// If `poll` is called with a timeout that is larger than this interval, /// then the main queue will be polled at that interval while the consumer - /// queue is blocked. + /// queue is blocked. This allows serving events while there are no messages. /// /// For example, if the main queue's minimum poll interval is 200ms and /// `poll` is called with a timeout of 1s, then `poll` may block for up to diff --git a/src/consumer/stream_consumer.rs b/src/consumer/stream_consumer.rs index 0c959f329..5a7f60552 100644 --- a/src/consumer/stream_consumer.rs +++ b/src/consumer/stream_consumer.rs @@ -1,6 +1,5 @@ //! High-level consumers with a [`Stream`](futures_util::Stream) interface. -use std::ffi::CString; use std::marker::PhantomData; use std::os::raw::c_void; use std::pin::Pin; @@ -21,7 +20,7 @@ use rdkafka_sys::types::*; use crate::client::{Client, NativeQueue}; use crate::config::{ClientConfig, FromClientConfig, FromClientConfigAndContext}; -use crate::consumer::base_consumer::BaseConsumer; +use crate::consumer::base_consumer::{BaseConsumer, PartitionQueue}; use crate::consumer::{ CommitMode, Consumer, ConsumerContext, ConsumerGroupMetadata, DefaultConsumerContext, RebalanceProtocol, @@ -31,7 +30,7 @@ use crate::groups::GroupList; use crate::message::BorrowedMessage; use crate::metadata::Metadata; use crate::topic_partition_list::{Offset, TopicPartitionList}; -use crate::util::{AsyncRuntime, DefaultRuntime, NativePtr, Timeout}; +use crate::util::{AsyncRuntime, DefaultRuntime, Timeout}; unsafe extern "C" fn native_message_queue_nonempty_cb(_: *mut RDKafka, opaque_ptr: *mut c_void) { let wakers = &*(opaque_ptr as *const WakerSlab); @@ -89,31 +88,50 @@ impl WakerSlab { /// A stream of messages from a [`StreamConsumer`]. /// /// See the documentation of [`StreamConsumer::stream`] for details. 
-pub struct MessageStream<'a> { +pub struct MessageStream<'a, C: ConsumerContext> { wakers: &'a WakerSlab, - queue: &'a NativeQueue, + consumer: &'a BaseConsumer, + partition_queue: Option<&'a NativeQueue>, slot: usize, } -impl<'a> MessageStream<'a> { - fn new(wakers: &'a WakerSlab, queue: &'a NativeQueue) -> MessageStream<'a> { +impl<'a, C: ConsumerContext> MessageStream<'a, C> { + fn new(wakers: &'a WakerSlab, consumer: &'a BaseConsumer) -> MessageStream<'a, C> { + Self::new_with_optional_partition_queue(wakers, consumer, None) + } + + fn new_with_partition_queue( + wakers: &'a WakerSlab, + consumer: &'a BaseConsumer, + partition_queue: &'a NativeQueue, + ) -> MessageStream<'a, C> { + Self::new_with_optional_partition_queue(wakers, consumer, Some(partition_queue)) + } + + fn new_with_optional_partition_queue( + wakers: &'a WakerSlab, + consumer: &'a BaseConsumer, + partition_queue: Option<&'a NativeQueue>, + ) -> MessageStream<'a, C> { let slot = wakers.register(); MessageStream { wakers, - queue, + consumer, + partition_queue, slot, } } fn poll(&self) -> Option>> { - unsafe { - NativePtr::from_ptr(rdsys::rd_kafka_consume_queue(self.queue.ptr(), 0)) - .map(|p| BorrowedMessage::from_consumer(p, self.queue)) + if let Some(queue) = self.partition_queue { + self.consumer.poll_queue(queue, Duration::ZERO) + } else { + self.consumer.poll(Duration::ZERO) } } } -impl<'a> Stream for MessageStream<'a> { +impl<'a, C: ConsumerContext> Stream for MessageStream<'a, C> { type Item = KafkaResult>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -140,7 +158,7 @@ impl<'a> Stream for MessageStream<'a> { } } -impl<'a> Drop for MessageStream<'a> { +impl<'a, C: ConsumerContext> Drop for MessageStream<'a, C> { fn drop(&mut self) { self.wakers.unregister(self.slot); } @@ -165,8 +183,7 @@ pub struct StreamConsumer where C: ConsumerContext, { - queue: NativeQueue, // queue must be dropped before the base to avoid deadlock - base: BaseConsumer, + base: Arc>, wakers: Arc, _shutdown_trigger: oneshot::Sender<()>, _runtime: PhantomData, @@ -197,19 +214,11 @@ where Duration::from_millis(millis) }; - let base = BaseConsumer::new(config, native_config, context)?; + let base = Arc::new(BaseConsumer::new(config, native_config, context)?); let native_ptr = base.client().native_ptr() as usize; - // Redirect rdkafka's main queue to the consumer queue so that we only - // need to listen to the consumer queue to observe events like - // rebalancings and stats. - unsafe { rdsys::rd_kafka_poll_set_consumer(base.client().native_ptr()) }; - - let queue = base.client().consumer_queue().ok_or_else(|| { - KafkaError::ClientCreation("librdkafka failed to create consumer queue".into()) - })?; let wakers = Arc::new(WakerSlab::new()); - unsafe { enable_nonempty_callback(&queue, &wakers) } + unsafe { enable_nonempty_callback(base.get_queue(), &wakers) } // We need to make sure we poll the consumer at least once every max // poll interval, *unless* the processing task has wedged. To accomplish @@ -241,7 +250,6 @@ where Ok(StreamConsumer { base, wakers, - queue, _shutdown_trigger: shutdown_trigger, _runtime: PhantomData, }) @@ -264,8 +272,8 @@ where /// /// If you want multiple independent views of a Kafka topic, create multiple /// consumers, not multiple message streams. - pub fn stream(&self) -> MessageStream<'_> { - MessageStream::new(&self.wakers, &self.queue) + pub fn stream(&self) -> MessageStream<'_, C> { + MessageStream::new(&self.wakers, &self.base) } /// Receives the next message from the stream. 
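Note: the hunks above make `MessageStream` generic over the consumer context and route its polling through `BaseConsumer::poll_queue`, but the public surface of `stream()` is otherwise unchanged. A minimal usage sketch (not part of this diff; the topic name is a placeholder):

```rust
use futures::StreamExt;
use rdkafka::consumer::{Consumer, StreamConsumer};
use rdkafka::Message;

async fn consume(consumer: &StreamConsumer) -> Result<(), Box<dyn std::error::Error>> {
    consumer.subscribe(&["my-topic"])?; // placeholder topic name
    let mut stream = consumer.stream();
    // Each `next()` now drives the event-based poll loop introduced in this change.
    while let Some(message) = stream.next().await {
        let m = message?;
        println!("partition {} offset {}", m.partition(), m.offset());
    }
    Ok(())
}
```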
@@ -308,7 +316,7 @@ where /// `StreamConsumer::recv`. /// /// You must periodically await `StreamConsumer::recv`, even if no messages - /// are expected, to serve callbacks. Consider using a background task like: + /// are expected, to serve events. Consider using a background task like: /// /// ``` /// # use rdkafka::consumer::StreamConsumer; @@ -334,29 +342,17 @@ where topic: &str, partition: i32, ) -> Option> { - let topic = match CString::new(topic) { - Ok(topic) => topic, - Err(_) => return None, - }; - let queue = unsafe { - NativeQueue::from_ptr(rdsys::rd_kafka_queue_get_partition( - self.base.client().native_ptr(), - topic.as_ptr(), - partition, - )) - }; - queue.map(|queue| { - let wakers = Arc::new(WakerSlab::new()); - unsafe { - rdsys::rd_kafka_queue_forward(queue.ptr(), ptr::null_mut()); - enable_nonempty_callback(&queue, &wakers); - } - StreamPartitionQueue { - queue, - wakers, - _consumer: self.clone(), - } - }) + self.base + .split_partition_queue(topic, partition) + .map(|queue| { + let wakers = Arc::new(WakerSlab::new()); + unsafe { enable_nonempty_callback(&queue.queue, &wakers) }; + StreamPartitionQueue { + queue, + wakers, + _consumer: self.clone(), + } + }) } } @@ -551,7 +547,7 @@ pub struct StreamPartitionQueue where C: ConsumerContext, { - queue: NativeQueue, + queue: PartitionQueue, wakers: Arc, _consumer: Arc>, } @@ -572,8 +568,12 @@ where /// /// If you want multiple independent views of a Kafka partition, create /// multiple consumers, not multiple partition streams. - pub fn stream(&self) -> MessageStream<'_> { - MessageStream::new(&self.wakers, &self.queue) + pub fn stream(&self) -> MessageStream<'_, C> { + MessageStream::new_with_partition_queue( + &self.wakers, + &self._consumer.base, + &self.queue.queue, + ) } /// Receives the next message from the stream. @@ -612,6 +612,6 @@ where C: ConsumerContext, { fn drop(&mut self) { - unsafe { disable_nonempty_callback(&self.queue) } + unsafe { disable_nonempty_callback(&self.queue.queue) } } } diff --git a/src/error.rs b/src/error.rs index 72e364479..312a6bb65 100644 --- a/src/error.rs +++ b/src/error.rs @@ -147,6 +147,8 @@ pub enum KafkaError { ClientCreation(String), /// Consumer commit failed. ConsumerCommit(RDKafkaErrorCode), + /// Consumer queue close failed. + ConsumerQueueClose(RDKafkaErrorCode), /// Flushing failed Flush(RDKafkaErrorCode), /// Global error. @@ -155,6 +157,8 @@ pub enum KafkaError { GroupListFetch(RDKafkaErrorCode), /// Message consumption failed. MessageConsumption(RDKafkaErrorCode), + /// Message consumption failed with fatal error. + MessageConsumptionFatal(RDKafkaErrorCode), /// Message production error. MessageProduction(RDKafkaErrorCode), /// Metadata fetch error. 
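Since the event-based consumer can now surface `MessageConsumptionFatal` (and `close_queue` can fail with `ConsumerQueueClose`), callers that match on `KafkaError` may want to treat the fatal variant as unrecoverable. A hedged sketch of such handling (not part of this diff):

```rust
use rdkafka::error::KafkaError;

/// Returns true when the consumer should be torn down and recreated
/// instead of being polled again.
fn is_unrecoverable(err: &KafkaError) -> bool {
    match err {
        // Fatal errors reported by librdkafka through the event API.
        KafkaError::MessageConsumptionFatal(_) => true,
        // Transient consumption errors can simply be retried.
        KafkaError::MessageConsumption(_) => false,
        _ => false,
    }
}
```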
@@ -204,6 +208,9 @@ impl fmt::Debug for KafkaError { KafkaError::ConsumerCommit(err) => { write!(f, "KafkaError (Consumer commit error: {})", err) } + KafkaError::ConsumerQueueClose(err) => { + write!(f, "KafkaError (Consumer queue close error: {})", err) + } KafkaError::Flush(err) => write!(f, "KafkaError (Flush error: {})", err), KafkaError::Global(err) => write!(f, "KafkaError (Global error: {})", err), KafkaError::GroupListFetch(err) => { @@ -212,6 +219,9 @@ impl fmt::Debug for KafkaError { KafkaError::MessageConsumption(err) => { write!(f, "KafkaError (Message consumption error: {})", err) } + KafkaError::MessageConsumptionFatal(err) => { + write!(f, "(Fatal) KafkaError (Message consumption error: {})", err) + } KafkaError::MessageProduction(err) => { write!(f, "KafkaError (Message production error: {})", err) } @@ -255,10 +265,14 @@ impl fmt::Display for KafkaError { } KafkaError::ClientCreation(ref err) => write!(f, "Client creation error: {}", err), KafkaError::ConsumerCommit(err) => write!(f, "Consumer commit error: {}", err), + KafkaError::ConsumerQueueClose(err) => write!(f, "Consumer queue close error: {}", err), KafkaError::Flush(err) => write!(f, "Flush error: {}", err), KafkaError::Global(err) => write!(f, "Global error: {}", err), KafkaError::GroupListFetch(err) => write!(f, "Group list fetch error: {}", err), KafkaError::MessageConsumption(err) => write!(f, "Message consumption error: {}", err), + KafkaError::MessageConsumptionFatal(err) => { + write!(f, "(Fatal) Message consumption error: {}", err) + } KafkaError::MessageProduction(err) => write!(f, "Message production error: {}", err), KafkaError::MetadataFetch(err) => write!(f, "Meta data fetch error: {}", err), KafkaError::NoMessageReceived => { @@ -288,10 +302,12 @@ impl Error for KafkaError { KafkaError::ClientConfig(..) => None, KafkaError::ClientCreation(_) => None, KafkaError::ConsumerCommit(err) => Some(err), + KafkaError::ConsumerQueueClose(err) => Some(err), KafkaError::Flush(err) => Some(err), KafkaError::Global(err) => Some(err), KafkaError::GroupListFetch(err) => Some(err), KafkaError::MessageConsumption(err) => Some(err), + KafkaError::MessageConsumptionFatal(err) => Some(err), KafkaError::MessageProduction(err) => Some(err), KafkaError::MetadataFetch(err) => Some(err), KafkaError::NoMessageReceived => None, @@ -327,10 +343,12 @@ impl KafkaError { KafkaError::ClientConfig(..) => None, KafkaError::ClientCreation(_) => None, KafkaError::ConsumerCommit(err) => Some(*err), + KafkaError::ConsumerQueueClose(err) => Some(*err), KafkaError::Flush(err) => Some(*err), KafkaError::Global(err) => Some(*err), KafkaError::GroupListFetch(err) => Some(*err), KafkaError::MessageConsumption(err) => Some(*err), + KafkaError::MessageConsumptionFatal(err) => Some(*err), KafkaError::MessageProduction(err) => Some(*err), KafkaError::MetadataFetch(err) => Some(*err), KafkaError::NoMessageReceived => None, diff --git a/src/message.rs b/src/message.rs index 0f47baebe..76bac9c39 100644 --- a/src/message.rs +++ b/src/message.rs @@ -6,11 +6,13 @@ use std::marker::PhantomData; use std::os::raw::c_void; use std::ptr; use std::str; +use std::sync::Arc; use std::time::SystemTime; use rdkafka_sys as rdsys; use rdkafka_sys::types::*; +use crate::admin::NativeEvent; use crate::error::{IsError, KafkaError, KafkaResult}; use crate::util::{self, millis_to_epoch, KafkaDrop, NativePtr}; @@ -306,17 +308,26 @@ impl Headers for BorrowedHeaders { /// [`detach`](BorrowedMessage::detach) method. 
pub struct BorrowedMessage<'a> { ptr: NativePtr, + _event: Arc, _owner: PhantomData<&'a u8>, } +// When using the Event API, messages must not be freed with rd_kafka_message_destroy +unsafe extern "C" fn no_op(_: *mut RDKafkaMessage) {} + unsafe impl KafkaDrop for RDKafkaMessage { const TYPE: &'static str = "message"; - const DROP: unsafe extern "C" fn(*mut Self) = rdsys::rd_kafka_message_destroy; + const DROP: unsafe extern "C" fn(*mut Self) = no_op; } impl<'a> fmt::Debug for BorrowedMessage<'a> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Message {{ ptr: {:?} }}", self.ptr()) + write!( + f, + "Message {{ ptr: {:?}, event_ptr: {:?} }}", + self.ptr(), + self._event.ptr() + ) } } @@ -327,14 +338,15 @@ impl<'a> BorrowedMessage<'a> { /// should only be used with messages coming from consumers. If the message /// contains an error, only the error is returned and the message structure /// is freed. - pub(crate) unsafe fn from_consumer( + pub(crate) unsafe fn from_client( ptr: NativePtr, - _consumer: &'a C, + event: Arc, + _client: &'a C, ) -> KafkaResult> { if ptr.err.is_error() { let err = match ptr.err { rdsys::rd_kafka_resp_err_t::RD_KAFKA_RESP_ERR__PARTITION_EOF => { - KafkaError::PartitionEOF((*ptr).partition) + KafkaError::PartitionEOF(ptr.partition) } e => KafkaError::MessageConsumption(e.into()), }; @@ -342,22 +354,24 @@ impl<'a> BorrowedMessage<'a> { } else { Ok(BorrowedMessage { ptr, + _event: event, _owner: PhantomData, }) } } /// Creates a new `BorrowedMessage` that wraps the native Kafka message - /// pointer returned by the delivery callback of a producer. The lifetime of - /// the message will be bound to the lifetime of the reference passed as - /// parameter. This method should only be used with messages coming from the - /// delivery callback. The message will not be freed in any circumstance. - pub(crate) unsafe fn from_dr_callback( + /// pointer returned via the delivery report event. The lifetime of + /// the message will be bound to the lifetime of the client passed as + /// parameter. + pub(crate) unsafe fn from_dr_event( ptr: *mut RDKafkaMessage, - _owner: &'a O, + event: Arc, + _client: &'a C, ) -> DeliveryResult<'a> { let borrowed_message = BorrowedMessage { ptr: NativePtr::from_ptr(ptr).unwrap(), + _event: event, _owner: PhantomData, }; if (*ptr).err.is_error() { diff --git a/src/producer/base_producer.rs b/src/producer/base_producer.rs index 48acd925d..886d5bdee 100644 --- a/src/producer/base_producer.rs +++ b/src/producer/base_producer.rs @@ -57,7 +57,7 @@ use rdkafka_sys as rdsys; use rdkafka_sys::rd_kafka_vtype_t::*; use rdkafka_sys::types::*; -use crate::client::Client; +use crate::client::{Client, NativeQueue}; use crate::config::{ClientConfig, FromClientConfig, FromClientConfigAndContext}; use crate::consumer::ConsumerGroupMetadata; use crate::error::{IsError, KafkaError, KafkaResult, RDKafkaError}; @@ -67,33 +67,12 @@ use crate::producer::{ DefaultProducerContext, Partitioner, Producer, ProducerContext, PurgeConfig, }; use crate::topic_partition_list::TopicPartitionList; -use crate::util::{IntoOpaque, Timeout}; +use crate::util::{IntoOpaque, NativePtr, Timeout}; pub use crate::message::DeliveryResult; use super::NoCustomPartitioner; -/// Callback that gets called from librdkafka every time a message succeeds or fails to be -/// delivered. 
-unsafe extern "C" fn delivery_cb>( - _client: *mut RDKafka, - msg: *const RDKafkaMessage, - opaque: *mut c_void, -) { - let producer_context = &mut *(opaque as *mut C); - let delivery_opaque = C::DeliveryOpaque::from_ptr((*msg)._private); - let owner = 42u8; - // Wrap the message pointer into a BorrowedMessage that will only live for the body of this - // function. - let delivery_result = BorrowedMessage::from_dr_callback(msg as *mut RDKafkaMessage, &owner); - trace!("Delivery event received: {:?}", delivery_result); - producer_context.delivery(&delivery_result, delivery_opaque); - match delivery_result { - // Do not free the message, librdkafka will do it for us - Ok(message) | Err((_, message)) => mem::forget(message), - } -} - // // ********** BASE PRODUCER ********** // @@ -294,7 +273,13 @@ where } unsafe { - rdsys::rd_kafka_conf_set_dr_msg_cb(native_config.ptr(), Some(delivery_cb::)) + rdsys::rd_kafka_conf_set_events( + native_config.ptr(), + rdsys::RD_KAFKA_EVENT_DR + | rdsys::RD_KAFKA_EVENT_STATS + | rdsys::RD_KAFKA_EVENT_ERROR + | rdsys::RD_KAFKA_EVENT_OAUTHBEARER_TOKEN_REFRESH, + ) }; let client = Client::new_context_arc( config, @@ -351,7 +336,9 @@ where C: ProducerContext, { client: Client, + queue: NativeQueue, _partitioner: PhantomData, + min_poll_interval: Timeout, } impl BaseProducer @@ -361,18 +348,58 @@ where { /// Creates a base producer starting from a Client. fn from_client(client: Client) -> BaseProducer { + let queue = client.main_queue(); BaseProducer { client, + queue, _partitioner: PhantomData, + min_poll_interval: Timeout::After(Duration::from_millis(100)), } } - /// Polls the producer, returning the number of events served. + /// Polls the producer /// /// Regular calls to `poll` are required to process the events and execute /// the message delivery callbacks. - pub fn poll>(&self, timeout: T) -> i32 { - unsafe { rdsys::rd_kafka_poll(self.native_ptr(), timeout.into().as_millis()) } + pub fn poll>(&self, timeout: T) { + let event = self.client().poll_event(&self.queue, timeout.into()); + if let Some(ev) = event { + let evtype = unsafe { rdsys::rd_kafka_event_type(ev.ptr()) }; + match evtype { + rdsys::RD_KAFKA_EVENT_DR => self.handle_delivery_report_event(ev), + _ => { + let buf = unsafe { + let evname = rdsys::rd_kafka_event_name(ev.ptr()); + CStr::from_ptr(evname).to_bytes() + }; + let evname = String::from_utf8(buf.to_vec()).unwrap(); + warn!("Ignored event '{}' on base producer poll", evname); + } + } + } + } + + fn handle_delivery_report_event(&self, event: NativePtr) { + let max_messages = unsafe { rdsys::rd_kafka_event_message_count(event.ptr()) }; + let messages: Vec<*const RDKafkaMessage> = Vec::with_capacity(max_messages); + + let mut messages = mem::ManuallyDrop::new(messages); + let messages = unsafe { + let msgs_cnt = rdsys::rd_kafka_event_message_array( + event.ptr(), + messages.as_mut_ptr(), + max_messages, + ); + Vec::from_raw_parts(messages.as_mut_ptr(), msgs_cnt, max_messages) + }; + + let ev = Arc::new(event); + for msg in messages { + let delivery_result = + unsafe { BorrowedMessage::from_dr_event(msg as *mut _, ev.clone(), self.client()) }; + let delivery_opaque = unsafe { C::DeliveryOpaque::from_ptr((*msg)._private) }; + self.context().delivery(&delivery_result, delivery_opaque); + } } /// Returns a pointer to the native Kafka client. @@ -464,12 +491,28 @@ where &self.client } + // As this library uses the rdkafka Event API, flush will not call rd_kafka_poll() but instead wait for + // the librdkafka-handled message count to reach zero. 
+    // Runs until the in-flight message count reaches zero or the timeout expires.
     fn flush<T: Into<Timeout>>(&self, timeout: T) -> KafkaResult<()> {
-        let ret = unsafe { rdsys::rd_kafka_flush(self.native_ptr(), timeout.into().as_millis()) };
-        if ret.is_error() {
-            Err(KafkaError::Flush(ret.into()))
-        } else {
-            Ok(())
+        let mut timeout = timeout.into();
+        loop {
+            let op_timeout = std::cmp::min(timeout, self.min_poll_interval);
+            if self.in_flight_count() > 0 {
+                unsafe { rdsys::rd_kafka_flush(self.native_ptr(), 0) };
+                self.poll(op_timeout);
+            } else {
+                return Ok(());
+            }
+
+            if op_timeout >= timeout {
+                let ret = unsafe { rdsys::rd_kafka_flush(self.native_ptr(), 0) };
+                if ret.is_error() {
+                    return Err(KafkaError::Flush(ret.into()));
+                } else {
+                    return Ok(());
+                }
+            }
+            timeout -= op_timeout;
         }
     }
 
@@ -534,10 +577,17 @@ where
     }
 
     fn commit_transaction<T: Into<Timeout>>(&self, timeout: T) -> KafkaResult<()> {
+        // rd_kafka_commit_transaction will call flush, but the user must call poll in order to
+        // serve the event queue. To avoid blocking here forever on the base producer, we call
+        // flush, which flushes the outstanding messages and serves the event queue.
+        // https://github.com/confluentinc/librdkafka/blob/95a542c87c61d2c45b445f91c73dd5442eb04f3c/src/rdkafka.h#L10231
+        // The recommended timeout here is -1 (never, i.e., infinite).
+        let timeout = timeout.into();
+        self.flush(timeout)?;
         let ret = unsafe {
             RDKafkaError::from_ptr(rdsys::rd_kafka_commit_transaction(
                 self.native_ptr(),
-                timeout.into().as_millis(),
+                timeout.as_millis(),
             ))
         };
         if ret.is_error() {
@@ -568,8 +618,13 @@ where
 {
     fn drop(&mut self) {
         self.purge(PurgeConfig::default().queue().inflight());
-        // Still have to poll after purging to get the results that have been made ready by the purge
-        self.poll(Timeout::After(Duration::ZERO));
+        // Still have to flush after purging to get the results that have been made ready by the purge
+        if let Err(err) = self.flush(Timeout::After(Duration::from_millis(500))) {
+            warn!(
+                "Failed to flush outstanding messages while dropping the producer: {:?}",
+                err
+            );
+        }
     }
 }
 
@@ -618,15 +673,11 @@ where
             .spawn(move || {
                 trace!("Polling thread loop started");
                 loop {
-                    let n = producer.poll(Duration::from_millis(100));
-                    if n == 0 {
-                        if should_stop.load(Ordering::Relaxed) {
-                            // We received nothing and the thread should
-                            // stop, so break the loop.
-                            break;
-                        }
-                    } else {
-                        trace!("Received {} events", n);
+                    producer.poll(Duration::from_millis(100));
+                    if should_stop.load(Ordering::Relaxed) {
+                        // The polling thread has been asked to stop,
+                        // so break the loop.
+                        break;
                     }
                 }
                 trace!("Polling thread loop terminated");
diff --git a/src/util.rs b/src/util.rs
index 16b146f58..543481d3f 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -48,6 +48,22 @@ impl Timeout {
             Timeout::Never => -1,
         }
     }
+
+    /// Saturating subtraction of a `Duration` from a `Timeout`.
+    pub(crate) fn saturating_sub(&self, rhs: Duration) -> Timeout {
+        match (self, rhs) {
+            (Timeout::After(lhs), rhs) => Timeout::After(lhs.saturating_sub(rhs)),
+            (Timeout::Never, _) => Timeout::Never,
+        }
+    }
+
+    /// Returns `true` if the timeout is zero.
+ pub(crate) fn is_zero(&self) -> bool { + match self { + Timeout::After(d) => d.is_zero(), + Timeout::Never => false, + } + } } impl std::ops::SubAssign for Timeout { diff --git a/tests/test_high_consumers.rs b/tests/test_high_consumers.rs index d139127b0..97ca4f5a0 100644 --- a/tests/test_high_consumers.rs +++ b/tests/test_high_consumers.rs @@ -7,6 +7,7 @@ use std::sync::Arc; use futures::future::{self, FutureExt}; use futures::stream::StreamExt; use maplit::hashmap; +use rdkafka_sys::RDKafkaErrorCode; use tokio::time::{self, Duration}; use rdkafka::consumer::{CommitMode, Consumer, ConsumerContext, StreamConsumer}; @@ -546,13 +547,41 @@ async fn test_consume_partition_order() { let mut i = 0; while i < 12 { if let Some(m) = consumer.recv().now_or_never() { - let partition = m.unwrap().partition(); + // retry on transient errors until we get a message + let m = match m { + Err(KafkaError::MessageConsumption( + RDKafkaErrorCode::BrokerTransportFailure, + )) + | Err(KafkaError::MessageConsumption(RDKafkaErrorCode::AllBrokersDown)) + | Err(KafkaError::MessageConsumption(RDKafkaErrorCode::OperationTimedOut)) => { + continue + } + Err(err) => { + panic!("Unexpected error receiving message: {:?}", err); + } + Ok(m) => m, + }; + let partition: i32 = m.partition(); assert!(partition == 0 || partition == 2); i += 1; } if let Some(m) = partition1.recv().now_or_never() { - assert_eq!(m.unwrap().partition(), 1); + // retry on transient errors until we get a message + let m = match m { + Err(KafkaError::MessageConsumption( + RDKafkaErrorCode::BrokerTransportFailure, + )) + | Err(KafkaError::MessageConsumption(RDKafkaErrorCode::AllBrokersDown)) + | Err(KafkaError::MessageConsumption(RDKafkaErrorCode::OperationTimedOut)) => { + continue + } + Err(err) => { + panic!("Unexpected error receiving message: {:?}", err); + } + Ok(m) => m, + }; + assert_eq!(m.partition(), 1); i += 1; } } diff --git a/tests/test_low_consumers.rs b/tests/test_low_consumers.rs index e1ce16bdf..c4aa305f7 100644 --- a/tests/test_low_consumers.rs +++ b/tests/test_low_consumers.rs @@ -288,13 +288,41 @@ async fn test_consume_partition_order() { let mut i = 0; while i < 12 { if let Some(m) = consumer.poll(Timeout::After(Duration::from_secs(0))) { - let partition = m.unwrap().partition(); + // retry on transient errors until we get a message + let m = match m { + Err(KafkaError::MessageConsumption( + RDKafkaErrorCode::BrokerTransportFailure, + )) + | Err(KafkaError::MessageConsumption(RDKafkaErrorCode::AllBrokersDown)) + | Err(KafkaError::MessageConsumption(RDKafkaErrorCode::OperationTimedOut)) => { + continue + } + Err(err) => { + panic!("Unexpected error receiving message: {:?}", err); + } + Ok(m) => m, + }; + let partition = m.partition(); assert!(partition == 0 || partition == 2); i += 1; } if let Some(m) = partition1.poll(Timeout::After(Duration::from_secs(0))) { - assert_eq!(m.unwrap().partition(), 1); + // retry on transient errors until we get a message + let m = match m { + Err(KafkaError::MessageConsumption( + RDKafkaErrorCode::BrokerTransportFailure, + )) + | Err(KafkaError::MessageConsumption(RDKafkaErrorCode::AllBrokersDown)) + | Err(KafkaError::MessageConsumption(RDKafkaErrorCode::OperationTimedOut)) => { + continue + } + Err(err) => { + panic!("Unexpected error receiving message: {:?}", err); + } + Ok(m) => m, + }; + assert_eq!(m.partition(), 1); i += 1; } } diff --git a/tests/test_metadata.rs b/tests/test_metadata.rs index 3b2667a9c..e62bee556 100644 --- a/tests/test_metadata.rs +++ b/tests/test_metadata.rs @@ -22,6 +22,7 
@@ fn create_consumer(group_id: &str) -> StreamConsumer {
         .set("session.timeout.ms", "6000")
         .set("api.version.request", "true")
         .set("debug", "all")
+        .set("auto.offset.reset", "earliest")
         .create()
         .expect("Failed to create StreamConsumer")
 }
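With the producer side now driven by the event API as well, applications still own the polling: a hedged end-to-end sketch of producing and flushing (not part of this diff; broker address and topic are placeholders):

```rust
use std::time::Duration;

use rdkafka::config::ClientConfig;
use rdkafka::producer::{BaseProducer, BaseRecord, Producer};

fn produce_one() -> Result<(), Box<dyn std::error::Error>> {
    let producer: BaseProducer = ClientConfig::new()
        .set("bootstrap.servers", "localhost:9092") // placeholder broker
        .create()?;

    producer
        .send(BaseRecord::to("my-topic").payload("hello").key("k")) // placeholder topic
        .map_err(|(err, _record)| err)?;

    // `flush` now drives the event queue itself (serving delivery reports)
    // until the in-flight count drops to zero or the timeout expires.
    producer.flush(Duration::from_secs(10))?;
    Ok(())
}
```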