diff --git a/src/asynqp/__init__.py b/src/asynqp/__init__.py index 3a508ea..ccd3b02 100644 --- a/src/asynqp/__init__.py +++ b/src/asynqp/__init__.py @@ -5,12 +5,13 @@ from .connection import Connection from .channel import Channel from .exchange import Exchange -from .queue import Queue, QueueBinding, Consumer +from .queue import Queue, QueueBinding, Consumer, QueuedConsumer __all__ = [ "Message", "IncomingMessage", "Connection", "Channel", "Exchange", "Queue", "QueueBinding", "Consumer", + "QueuedConsumer", "connect", "connect_and_open_channel" ] __all__ += exceptions.__all__ diff --git a/src/asynqp/exceptions.py b/src/asynqp/exceptions.py index b54a7fa..c9b1094 100644 --- a/src/asynqp/exceptions.py +++ b/src/asynqp/exceptions.py @@ -35,6 +35,10 @@ class ServerConnectionClosed(AlreadyClosed): """ Connection was closed by server """ +class ConsumerCancelled(Exception): + pass + + class UndeliverableMessage(ValueError): pass diff --git a/src/asynqp/queue.py b/src/asynqp/queue.py index 4251634..572001d 100644 --- a/src/asynqp/queue.py +++ b/src/asynqp/queue.py @@ -2,7 +2,7 @@ import re from operator import delitem from . import spec -from .exceptions import Deleted, AlreadyClosed +from .exceptions import Deleted, AlreadyClosed, ConsumerCancelled VALID_QUEUE_NAME_RE = re.compile(r'^(?!amq\.)(\w|[-.:])*$', flags=re.A) @@ -314,3 +314,75 @@ def error(self, exc): for consumer in self.consumers.values(): if hasattr(consumer.callback, 'on_error'): consumer.callback.on_error(exc) + + +class QueuedConsumer: + + def __init__(self, *, loop): + self.loop = loop + self._queue = asyncio.Queue(loop=loop) + self._exc = None + self._waiters = [] + self._cancelled = False + + def __call__(self, msg): + self._queue.put_nowait(msg) + + def on_error(self, exc): + # So future calls raise error + self._exc = exc + # Purge all messages, that were in queue. They are to be treated as + # nack'ed + while True: + try: + self._queue.get_nowait() + except asyncio.QueueEmpty: + break + + # All pending waiters if any must be killed with the same exception + for waiter in self._waiters: + waiter.set_exception(exc) + + def on_cancel(self): + self._cancelled = True + + def _check_error(self): + if self._exc: + raise self._exc + + def empty(self): + self._check_error() + return self._queue.empty() + + @asyncio.coroutine + def get(self): + """ Get 1 message from queue. Wait if no arrived yet """ + self._check_error() + + if not self.empty(): + return self._queue.get_nowait() + elif self._cancelled: + # We only cancel if no more data received from consumer + raise ConsumerCancelled() + + task = asyncio.async(self._queue.get(), loop=self.loop) + self._waiters.append(task) + # waiters will never be big. You will not want to call `.get()` + # too many times. This code serves for the case when you really want + # to call it 2 and more in parallel. + task.add_done_callback(lambda t, w=self._waiters: w.remove(t)) + return task + + @asyncio.coroutine + def getmany(self): + """ Get all accumulated messages from queue. Wait if no arrived yet """ + self._check_error() + + if self.empty(): + msg = yield from self.get() + return [msg] + else: + res = [] + while not self.empty(): + res.append(self._queue.get_nowait()) + return res