(Re)Implement batching for Celery.
clenk committed Jun 13, 2017
1 parent 5db0eeb commit f181f91
Showing 2 changed files with 288 additions and 15 deletions.
259 changes: 259 additions & 0 deletions libs/celery_batches.py
@@ -0,0 +1,259 @@
# -*- coding: utf-8 -*-
"""
celery.contrib.batches
======================

Experimental task class that buffers messages and processes them as a list.

.. warning::

    For this to work you have to set
    :setting:`CELERYD_PREFETCH_MULTIPLIER` to zero, or some value where
    the final multiplied value is higher than ``flush_every``.

    In the future we hope to add the ability to direct batching tasks
    to a channel with different QoS requirements than the task channel.
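
    For example, a minimal sketch (``CELERYD_PREFETCH_MULTIPLIER`` is the
    Celery 3 setting name; Celery 4's lowercase equivalent is
    ``worker_prefetch_multiplier``):

    .. code-block:: python

        # Disable the prefetch limit so the worker keeps fetching
        # messages for the buffer instead of capping the unacked count.
        app.conf.CELERYD_PREFETCH_MULTIPLIER = 0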

**Simple Example**

A click counter that flushes the buffer every 100 messages, and every
10 seconds. Does not do anything with the data, but can easily be modified
to store it in a database.

.. code-block:: python

    # Flush after 100 messages, or 10 seconds.
    @app.task(base=Batches, flush_every=100, flush_interval=10)
    def count_click(requests):
        from collections import Counter
        counts = Counter(request.kwargs['url'] for request in requests)
        for url, count in counts.items():
            print('>>> Clicks: {0} -> {1}'.format(url, count))

Then you can ask for a click to be counted by doing::

    >>> count_click.delay(url='http://example.com')

**Example returning results**

An interface to the Web of Trust API that flushes the buffer every 100
messages, and every 10 seconds.

.. code-block:: python

    import requests
    from urlparse import urlparse

    from celery.contrib.batches import Batches

    wot_api_target = 'https://api.mywot.com/0.4/public_link_json'

    @app.task(base=Batches, flush_every=100, flush_interval=10)
    def wot_api(requests):
        sig = lambda url: url
        responses = wot_api_real(
            (sig(*request.args, **request.kwargs) for request in requests)
        )
        # use mark_as_done to manually return response data
        for response, request in zip(responses, requests):
            app.backend.mark_as_done(request.id, response)

    def wot_api_real(urls):
        domains = [urlparse(url).netloc for url in urls]
        response = requests.get(
            wot_api_target,
            params={'hosts': '/'.join(set(domains)) + '/'}
        )
        return [response.json()[domain] for domain in domains]

Using the API is done as follows::

    >>> wot_api.delay('http://example.com')

.. note::

    If you don't have an ``app`` instance then use the current app proxy
    instead::

        from celery import current_app
        current_app.backend.mark_as_done(request.id, response)

"""
from __future__ import absolute_import

from itertools import count

from celery.task import Task
from celery.five import Empty, Queue
from celery.utils.log import get_logger
from celery.worker.request import Request
from celery.utils import noop

__all__ = ['Batches']

logger = get_logger(__name__)


def consume_queue(queue):
    """Iterator yielding all immediately available items in a
    :class:`Queue.Queue`.

    The iterator stops as soon as the queue raises :exc:`Queue.Empty`.

    *Examples*

        >>> q = Queue()
        >>> for i in range(4):
        ...     q.put(i)
        >>> list(consume_queue(q))
        [0, 1, 2, 3]
        >>> list(consume_queue(q))
        []

    """
    get = queue.get_nowait
    while 1:
        try:
            yield get()
        except Empty:
            break


def apply_batches_task(task, args, loglevel, logfile):
    task.push_request(loglevel=loglevel, logfile=logfile)
    try:
        result = task(*args)
    except Exception as exc:
        result = None
        logger.error('Error: %r', exc, exc_info=True)
    finally:
        task.pop_request()
    return result


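# The pool ships each batch to a child process by pickling its arguments
# (see apply_buffer() below), which is why the full worker Request must
# first be reduced to this simple, pickleable snapshot.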
class SimpleRequest(object):
    """Pickleable request."""

    #: task id
    id = None

    #: task name
    name = None

    #: positional arguments
    args = ()

    #: keyword arguments
    kwargs = {}

    #: message delivery information.
    delivery_info = None

    #: worker node name
    hostname = None

    # added for celery 4 compatibility
    chord = None
    reply_to = None
    correlation_id = None

    def __init__(self, id, name, args, kwargs, delivery_info, hostname):
        self.id = id
        self.name = name
        self.args = args
        self.kwargs = kwargs
        self.delivery_info = delivery_info
        self.hostname = hostname

    @classmethod
    def from_request(cls, request):
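        # Indexing into the raw body is an assumption based on Celery 4's
        # message protocol v2, where the body is the three-tuple
        # (args, kwargs, embed); e.g. for ``add.delay(2, 2)`` the body is
        # roughly ((2, 2), {}, {'callbacks': None, ...}), so body[0] and
        # body[1] hold the positional and keyword arguments.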
        return cls(request.id, request.name, request.body[0],
                   request.body[1], request.delivery_info, request.hostname)


class Batches(Task):
    abstract = True

    #: Maximum number of messages in the buffer.
    flush_every = 10

    #: Timeout in seconds before buffer is flushed anyway.
    flush_interval = 30

    def __init__(self):
        self._buffer = Queue()
        self._count = count(1)
        self._tref = None
        self._pool = None

    def run(self, requests):
        raise NotImplementedError('must implement run(requests)')

    def Strategy(self, task, app, consumer):
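        # How this hooks in (an assumption about Celery worker internals):
        # the consumer calls task.start_strategy(), which resolves this
        # method and installs the returned task_message_handler as the
        # handler for every incoming message of this task type.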
        self._pool = consumer.pool
        hostname = consumer.hostname
        eventer = consumer.event_dispatcher
        Req = Request
        connection_errors = consumer.connection_errors
        timer = consumer.timer
        put_buffer = self._buffer.put
        flush_buffer = self._do_flush

        def task_message_handler(message, body, ack, reject, callbacks, **kw):
            request = Req(message, body=message.payload, on_ack=ack, app=app,
                          events=eventer, task=task,
                          connection_errors=connection_errors,
                          delivery_info=message.delivery_info)
            put_buffer(request)

            if self._tref is None:  # first request starts flush timer.
                self._tref = timer.call_repeatedly(
                    self.flush_interval, flush_buffer,
                )

            if not next(self._count) % self.flush_every:
                flush_buffer()

        return task_message_handler

    def flush(self, requests):
        return self.apply_buffer(requests, ([SimpleRequest.from_request(r)
                                             for r in requests], ))

    def _do_flush(self):
        logger.debug('Batches: Wake-up to flush buffer...')
        requests = None
        if self._buffer.qsize():
            requests = list(consume_queue(self._buffer))
        if requests:
            logger.debug('Batches: Buffer complete: %s', len(requests))
            self.flush(requests)
        if not requests:
            logger.debug('Batches: Canceling timer: Nothing in buffer.')
            if self._tref:
                self._tref.cancel()  # cancel timer.
                self._tref = None

    def apply_buffer(self, requests, args=(), kwargs={}):
        acks_late = [], []
        [acks_late[r.task.acks_late].append(r) for r in requests]
        assert requests and (acks_late[True] or acks_late[False])

        def on_accepted(pid, time_accepted):
            [req.acknowledge() for req in acks_late[False]]

        def on_return(result):
            [req.acknowledge() for req in acks_late[True]]

        return self._pool.apply_async(
            apply_batches_task,
            (self, args, 0, None),
            accept_callback=on_accepted,
            callback=acks_late[True] and on_return or noop,
        )
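
A minimal sketch of how a task built on this class can hand a result back per
message (the task name, broker/backend URLs, and squaring logic are
illustrative assumptions, not part of this commit):

    from celery import Celery
    from celery_batches import Batches

    app = Celery('demo', broker='redis://', backend='redis://')

    # Flush after 50 buffered messages, or every 5 seconds.
    @app.task(base=Batches, flush_every=50, flush_interval=5)
    def square_batch(requests):
        for request in requests:
            n = request.args[0]
            # Store one result per original message id so callers can
            # fetch it with AsyncResult as usual.
            app.backend.mark_as_done(request.id, n * n)

A client could then call ``square_batch.delay(3)`` and read the stored value
back with ``.get()`` on the returned result.
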
44 changes: 29 additions & 15 deletions utils/celery_worker.py
@@ -14,12 +14,13 @@
 # Add the storage dir to the sys.path. Allows import of sql_driver module
 if os.path.join(MS_WD, 'storage') not in sys.path:
     sys.path.insert(0, os.path.join(MS_WD, 'storage'))
-# Add the libs dir to the sys.path. Allows import of common module
+# Add the libs dir to the sys.path. Allows import of common, celery_batches modules
 if os.path.join(MS_WD, 'libs') not in sys.path:
     sys.path.insert(0, os.path.join(MS_WD, 'libs'))
 import multiscanner
 import common
 import sql_driver as database
+from celery_batches import Batches

 from celery import Celery

@@ -56,21 +57,12 @@
     ))
 db = database.Database(config=db_config)

-@app.task
-def multiscanner_celery(file_, original_filename, task_id, file_hash, config=multiscanner.CONFIG):
-    '''
-    TODO: Figure out how to do batching.
-    This function essentially takes in a file list and runs
-    multiscanner on them. Results are stored in the
-    storage configured in storage.ini.
-
-    Usage:
-    from celery_worker import multiscanner_celery
-    multiscanner_celery.delay([list, of, files, to, scan])
-    '''
+def celery_task(file_, original_filename, task_id, file_hash, config=multiscanner.CONFIG):
+    '''
+    Run multiscanner on the given file and store the results in the storage
+    handler(s) specified in the storage configuration file.
+    '''
     # Initialize the connection to the task DB
     db.init_db()

     print('\n\n{}{}Got file: {}.\nOriginal filename: {}.\n'.format('='*48, '\n', file_hash, original_filename))

     # Get the storage config
@@ -96,10 +88,32 @@ def multiscanner_celery(file_, original_filename, task_id, file_hash, config=mul
         report_id=file_hash,
     )

     print('Results of the scan:\n{}'.format(results))

     return results

+
+# Flush after 100 messages, or 10 seconds.
+@app.task(base=Batches, flush_every=100, flush_interval=10)
+def multiscanner_celery(requests, *args, **kwargs):
+    '''
+    Queue up multiscanner tasks and then run a batch of them at a time for
+    better performance.
+
+    Usage:
+    from celery_worker import multiscanner_celery
+    multiscanner_celery.delay(full_path, original_filename, task_id,
+                              hashed_filename, config=multiscanner.CONFIG)
+    '''
+    # Initialize the connection to the task DB
+    db.init_db()
+
+    for request in requests:
+        file_ = request.args[0]
+        original_filename = request.args[1]
+        task_id = request.args[2]
+        file_hash = request.args[3]
+        celery_task(file_, original_filename, task_id, file_hash)
+

 if __name__ == '__main__':
     app.start()

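For context, a sketch of how the batched task would be fed from the client
side (the file paths are illustrative, and in MultiScanner the task id would
come from the task database rather than a counter):

    import hashlib
    import os

    from celery_worker import multiscanner_celery

    for task_id, full_path in enumerate(['/tmp/a.exe', '/tmp/b.pdf'], 1):
        with open(full_path, 'rb') as f:
            file_hash = hashlib.sha256(f.read()).hexdigest()
        # Each call becomes one buffered request; the worker's batching
        # loop unpacks these four positional args per request.
        multiscanner_celery.delay(full_path, os.path.basename(full_path),
                                  task_id, file_hash)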