diff --git a/rstream/consumer.py b/rstream/consumer.py index d759e72..8523c1d 100644 --- a/rstream/consumer.py +++ b/rstream/consumer.py @@ -220,7 +220,9 @@ async def subscribe( initial_credit: int = 10, properties: Optional[dict[str, Any]] = None, subscriber_name: Optional[str] = None, - consumer_update_listener: Optional[Callable[[bool, EventContext], Awaitable[Any]]] = None, + consumer_update_listener: Optional[ + Callable[[bool, EventContext], Awaitable[OffsetSpecification]] + ] = None, filter_input: Optional[FilterConfiguration] = None, ) -> str: logger.debug("Consumer subscribe()") @@ -408,7 +410,9 @@ async def _on_consumer_update_query_response( frame: schema.ConsumerUpdateResponse, subscriber: _Subscriber, reference: str, - consumer_update_listener: Optional[Callable[[bool, EventContext], Awaitable[Any]]] = None, + consumer_update_listener: Optional[ + Callable[[bool, EventContext], Awaitable[OffsetSpecification]] + ] = None, ) -> None: # event the consumer is not active, we need to send a ConsumerUpdateResponse # by protocol definition. the offsetType can't be null so we use OffsetTypeNext as default @@ -420,6 +424,8 @@ async def _on_consumer_update_query_response( is_active = bool(frame.active) event_context = EventContext(self, subscriber.reference, reference) offset_specification = await consumer_update_listener(is_active, event_context) + subscriber.offset_type = OffsetType(offset_specification.offset_type) + subscriber.offset = offset_specification.offset await subscriber.client.consumer_update(frame.correlation_id, offset_specification) async def create_stream( diff --git a/tests/test_consumer.py b/tests/test_consumer.py index 750a350..df25263 100644 --- a/tests/test_consumer.py +++ b/tests/test_consumer.py @@ -12,8 +12,10 @@ AMQPMessage, Consumer, ConsumerOffsetSpecification, + EventContext, FilterConfiguration, MessageContext, + OffsetSpecification, OffsetType, OnClosedErrorInfo, Producer, @@ -28,7 +30,6 @@ from .util import ( consumer_update_handler_first, consumer_update_handler_next, - consumer_update_handler_offset, on_message, routing_extractor_generic, run_consumer, @@ -358,6 +359,47 @@ async def test_consume_multiple_streams(consumer: Consumer, producer: Producer) await asyncio.gather(*(consumer.delete_stream(stream) for stream in streams)) +async def test_consume_with_sac_custom_consumer_update_listener_cb( + consumer: Consumer, producer: Producer +) -> None: + stream_name = "stream" + await producer.create_stream(stream=stream_name) + try: + # necessary to use send_batch, since in this case, upon delivery, rabbitmq will deliver + # this batch as a whole, and not one message at a time, like send_wait + await producer.send_batch(stream_name, [AMQPMessage(body=f"{i}".encode()) for i in range(10)]) + + received_offsets = [] + + async def consumer_cb(message: bytes, message_context: MessageContext) -> None: + received_offsets.append(message_context.offset) + + async def consumer_update_listener_with_custom_offset( + is_active: bool, event_context: EventContext + ) -> OffsetSpecification: + if is_active: + return OffsetSpecification(offset_type=OffsetType.OFFSET, offset=5) + return OffsetSpecification(offset_type=OffsetType.FIRST, offset=0) + + properties = {"single-active-consumer": "true", "name": "sac_name"} + async with consumer: + await consumer.subscribe( + stream=stream_name, + callback=consumer_cb, + properties=properties, + offset_specification=ConsumerOffsetSpecification(OffsetType.FIRST), + consumer_update_listener=consumer_update_listener_with_custom_offset, + ) + + await wait_for(lambda: len(received_offsets) >= 1) + + assert received_offsets[0] == 5 + + finally: + await producer.delete_stream(stream=stream_name) + await producer.close() + + async def test_consume_superstream_with_sac_all_active( super_stream: str, super_stream_consumer_for_sac1: SuperStreamConsumer, @@ -545,11 +587,11 @@ async def test_consume_superstream_with_callback_offset( consumer_stream_list2: list[str] = [] consumer_stream_list3: list[str] = [] - await run_consumer(super_stream_consumer_for_sac1, consumer_stream_list1, consumer_update_handler_offset) - await run_consumer(super_stream_consumer_for_sac2, consumer_stream_list2, consumer_update_handler_offset) - await run_consumer(super_stream_consumer_for_sac3, consumer_stream_list3, consumer_update_handler_offset) + await run_consumer(super_stream_consumer_for_sac1, consumer_stream_list1, consumer_update_handler_first) + await run_consumer(super_stream_consumer_for_sac2, consumer_stream_list2, consumer_update_handler_first) + await run_consumer(super_stream_consumer_for_sac3, consumer_stream_list3, consumer_update_handler_first) - for i in range(10000): + for i in range(10_000): amqp_message = AMQPMessage( body=bytes("a:{}".format(i), "utf-8"), properties=Properties(message_id=str(i)), diff --git a/tests/util.py b/tests/util.py index 401b99a..48ccb0c 100644 --- a/tests/util.py +++ b/tests/util.py @@ -46,10 +46,6 @@ async def consumer_update_handler_first(is_active: bool, event_context: EventCon return OffsetSpecification(OffsetType.FIRST, 0) -async def consumer_update_handler_offset(is_active: bool, event_context: EventContext) -> OffsetSpecification: - return OffsetSpecification(OffsetType.OFFSET, 10) - - async def on_publish_confirm_client_callback( confirmation: ConfirmationStatus, confirmed_messages: list[int], errored_messages: list[int] ) -> None: