From 44ac9e056602543d9c13ce60d94ad07b1a36f7b9 Mon Sep 17 00:00:00 2001 From: Peter Steinberg Date: Mon, 2 Nov 2015 13:09:08 -0800 Subject: [PATCH 1/4] celery status server in thread background with tests --- celery_tasks.py | 14 ++- flask_server.py | 213 +++++++++++++++++++++++++++++++++++++++++-- test_flask_server.py | 156 +++++++++++++++++++++++++++++++ 3 files changed, 369 insertions(+), 14 deletions(-) create mode 100644 test_flask_server.py diff --git a/celery_tasks.py b/celery_tasks.py index 79c203ad7..36254819f 100644 --- a/celery_tasks.py +++ b/celery_tasks.py @@ -1,9 +1,12 @@ -from celery import Celery +import json import os +import time + +from celery import Celery import pandas as pd + import dropq import taxcalc -import json celery_app = Celery('tasks2', broker=os.environ['REDISGREEN_URL'], backend=os.environ['REDISGREEN_URL']) @@ -29,5 +32,8 @@ def dropq_task_async(year, user_mods): json_res = json.dumps(results) return json_res - - +@celery_app.task +def example_async(): + "example async 40 second function for testing" + time.sleep(40) + return json.dumps({'example': 'ok'}) diff --git a/flask_server.py b/flask_server.py index b3a6a0343..7ae0ca30c 100644 --- a/flask_server.py +++ b/flask_server.py @@ -1,13 +1,21 @@ +from __future__ import division, unicode_literals, print_function +import argparse +import datetime +from functools import partial +import json +import os +from threading import Thread, Event +import time + from flask import Flask, request, make_response import pandas as pd import dropq import taxcalc -import json -import time -from functools import partial from pandas.util.testing import assert_frame_equal +import requests +from retrying import retry -from celery_tasks import celery_app, dropq_task_async +from celery_tasks import celery_app, dropq_task_async, example_async import os @@ -15,23 +23,33 @@ server_url = "http://localhost:5050" +TRACKING_TICKETS = {} +SLEEP_INTERVAL_TICKET_CHECK = 30 +TRACKING_TICKETS_PATH = os.path.join(os.path.dirname(__file__), + '.TRACKING_TICKETS') +TICKET_CHECK_COUNTER = 0 +TICKET_WRITE_MOD = 200 +EXIT_EVENT = None + + @app.route("/dropq_start_job", methods=['GET', 'POST']) def dropq_endpoint(): print("stuff here") if request.method == 'POST': year = request.form['year'] user_mods = json.loads(request.form['user_mods']) - user_mods = {int(k):v for k,v in user_mods.iteritems()} + user_mods = {int(k): v for k, v in user_mods.iteritems()} else: - year = request.args.get('year','') + year = request.args.get('year', '') year = int(year) raw_results = dropq_task_async.delay(year, user_mods) return str(raw_results) + @app.route("/dropq_get_result", methods=['GET']) def dropq_results(): print("results here") - job_id = request.args.get('job_id','') + job_id = request.args.get('job_id', '') results = celery_app.AsyncResult(job_id) if results.ready(): tax_result = results.result @@ -40,9 +58,10 @@ def dropq_results(): resp = make_response('not ready', 202) return resp + @app.route("/dropq_query_result", methods=['GET']) def query_results(): - job_id = request.args.get('job_id','') + job_id = request.args.get('job_id', '') results = celery_app.AsyncResult(job_id) if results.ready(): return "YES" @@ -50,7 +69,181 @@ def query_results(): return "NO" +@app.route('/example_async', methods=['POST']) +def example(): + + job_id = example_async.delay() + print('job_id', job_id) + return json.dumps({'job_id': str(job_id)}) + + +def sleep_function(seconds): + def sleep_function_dec(func): + def new_func(*args, **kwargs): + global EXIT_EVENT + try: + while not EXIT_EVENT.isSet(): + func(*args, **kwargs) + for repeat in range(int(seconds)): + if EXIT_EVENT.isSet(): + return + time.sleep(1) + except KeyboardInterrupt: + EXIT_EVENT.set() + print('KeyboardInterrupt') + return + return new_func + return sleep_function_dec + + +class BadResponse(ValueError): + pass + + +def retry_exception_or_not(exception): + if isinstance(exception, BadResponse): + return False + return True + + +@retry(wait_fixed=7000, stop_max_attempt_number=4, + retry_on_exception=retry_exception_or_not) +def do_success_callback(job_id, callback, params): + """wait wait_fixed milliseconds between each separate attempt to + do callback this without exception. Retry it up to stop_max_attempt_number + retries, unless the exception is a BadResponse, which is a specific + error dictionary from returned by callback response that is successfully + json.loaded.""" + callback_response = None + try: + callback_response = requests.post(callback, params=params, timeout=20) + js = callback_response.json() + if 'error' in js: + # DO not retry if an error message is returned + print('ERROR (no retry) in callback_response: {0}'.format(js)) + raise BadResponse(js['error']) + return js + except Exception as e: + if callback_response is not None: + content = callback_response._content + first_message = "Failed to json.loads callback_response" + else: + first_message = "No content. Probable timeout." + content = "" + msg = first_message +\ + " with exception {0}".format(repr(e)) +\ + " for ticket id {0}:{1}".format(job_id, content) +\ + ". May retry." + print(msg) + raise + + +def job_id_check(): + global TRACKING_TICKETS + global TICKET_CHECK_COUNTER + + TICKET_CHECK_COUNTER += 1 + to_pop = [] + for job_id in TRACKING_TICKETS: + results = celery_app.AsyncResult(job_id) + if results.ready(): + to_pop.append(job_id) + for job_id in to_pop: + ticket_dict = TRACKING_TICKETS.pop(job_id) + callback = ticket_dict['callback'] + # TODO decide on exception handling here + # raise exception if 1 ticket's callback fails? or just log it? + resp = do_success_callback(job_id, callback, ticket_dict['params']) + print('Success on callback {0} with response {1}'.format(callback, + resp)) + if TICKET_CHECK_COUNTER >= TICKET_WRITE_MOD: + # periodically dump the TRACKING_TICKETS to json + TICKET_CHECK_COUNTER = 0 + with open(TRACKING_TICKETS_PATH, 'w') as f: + f.write(json.dumps(TRACKING_TICKETS)) + + +@app.route("/register_job", methods=['POST']) +def register_job(): + global TRACKING_TICKETS + callback = request.args.get('callback', False) + job_id = request.args.get('job_id', False) + params = json.loads(request.args.get('params', "{}")) + if not callback or not job_id: + return make_response(json.dumps({'error': "Expected arguments of " + + "job_id, callback, and " + + "optionally params."}), 400) + msg = "Start checking: job_id {0} with params {1} and callback {2}" + print(msg.format(job_id, params, callback)) + now = datetime.datetime.utcnow().isoformat() + TRACKING_TICKETS[job_id] = {'params': params, + 'started': now, + 'callback': callback, + 'job_id': job_id, + } + return json.dumps({'registered': TRACKING_TICKETS[job_id], }) + + +@app.route('/pop_job_id', methods=['POST']) +def pop(): + global TRACKING_TICKETS + job_id = request.args.get('job_id', '') + if job_id in TRACKING_TICKETS: + return json.dumps({'popped': TRACKING_TICKETS.pop(job_id)}) + return make_response(json.dumps({'job_id': job_id, + 'error': 'job_id not present'}), 400) + + +@app.route('/current_tickets_tracker', methods=['GET']) +def current_tickets_tracker(): + return json.dumps(TRACKING_TICKETS) + + +@app.route('/example_success_callback', methods=['POST']) +def example_success_callback(): + return json.dumps({'ok': 'example_success_callback'}) + + +def cli(): + parser = argparse.ArgumentParser(description="Run flask server") + parser.add_argument('-s', '--sleep-interval', + help="How long to sleep between status checks", + default=SLEEP_INTERVAL_TICKET_CHECK, + type=float) + parser.add_argument('-p', '--port', help="Port on which to run", + default=5050, required=False) + parser.add_argument('-i', '--ignore-cached-tickets', + help="Skip the loading of tracking tickets from " + + "{0}. Testing only.".format(TRACKING_TICKETS_PATH)) + return parser.parse_args() + + +def main(): + global TRACKING_TICKETS + global EXIT_EVENT + EXIT_EVENT = Event() + args = cli() + if os.path.exists(TRACKING_TICKETS_PATH) and not args.ignore_cached_tickets: + # load any tickets if they exist + with open(TRACKING_TICKETS_PATH, 'r') as f: + TRACKING_TICKETS = json.load(f) + @sleep_function(args.sleep_interval) + def checking_tickets_at_interval(): + return job_id_check() + try: + checker_thread = Thread(target=checking_tickets_at_interval) + checker_thread.daemon = True + checker_thread.start() + app.debug = True + app.run(host='0.0.0.0', port=args.port) + except Exception as e: + EXIT_EVENT.set() + time.sleep(3) + raise + finally: + # dump all the standing tickets no matter what + with open(TRACKING_TICKETS_PATH, 'w') as f: + f.write(json.dumps(TRACKING_TICKETS)) if __name__ == "__main__": - app.debug = True - app.run(host='0.0.0.0', port=5050) + main() diff --git a/test_flask_server.py b/test_flask_server.py new file mode 100644 index 000000000..2ab927eb8 --- /dev/null +++ b/test_flask_server.py @@ -0,0 +1,156 @@ +"""Testing protocol: +In four separate terminals on localhost, do these four commands: + +redis-server + +REDISGREEN_URL=redis://localhost:6379 celery -A celery_tasks worker -P eventlet -l info + +REDISGREEN_URL=redis://localhost:6379 python flask_server.py + +REDISGREEN_URL=redis://localhost:6379 py.test ./ -p no:django -m has_services --pep8 + +Takes about 2 minutes to run. +""" +from __future__ import division, unicode_literals, print_function +import subprocess as sp +import time + +import psutil +import pytest +import requests + +from celery_tasks import celery_app, dropq_task_async, example_async +import flask_server + +service_exes_required = ('python', 'celery', 'redis-server') + + +def flask_url(ending): + return "{0}{1}".format(flask_server.server_url, ending) + +DEFAULT_PARAMS = { + 'callback': flask_url('/example_success_callback'), + 'params': '{}', + +} +NOT_RUNNING_MSG = "Trying to run a test marked with has_services " +\ + "but services are not running." + + +def has_services_for_test(): + found_services = set() + for p in psutil.process_iter(): + try: + # may not have access to all processes, so try + for service in service_exes_required: + if any(service in cmdi for cmdi in p.cmdline()): + found_services.add(service) + except psutil.AccessDenied: + pass + if not len(found_services) == 3: + info = (service_exes_required, found_services) + msg = 'Expected services of {0} but found {1}'.format(*info) + raise ValueError(msg) + + response = requests.post(flask_url('/example_async')) + if not response.status_code == 200: + msg = 'Bad response from /example_async:{0}'.format(response._content) + raise ValueError(msg) + + try: + js = response.json() + if not 'job_id' in js: + msg = 'Json.loaded response from ' +\ + '/example_async but no job_id {}'.format(js) + raise ValueError(msg) + return True + except Exception as e: + msg = '''Got 200 response from /example_async but +could not json.loads:{0}'''.format(response._content) + raise ValueError(msg) + + +def request_assert(method, url_ending, params=None, + status_code=200, as_json=True): + if not params: + params = {} + response = getattr(requests, method)(flask_url(url_ending), params=params) + assert response.status_code == status_code + if as_json: + return response.json() + return response + + +def example_with_register(): + 'Do an example_async task and register the job_id' + example = request_assert('post', '/example_async') + job_id = example['job_id'] + params = DEFAULT_PARAMS.copy() + params['job_id'] = job_id + register = request_assert('post', + '/register_job', + params=params) + return (register, job_id) + + +@pytest.mark.has_services +def test_register_ticket_process(): + '''After submitting an example, we should see its results + eventually.''' + assert has_services_for_test(), NOT_RUNNING_MSG + register, job_id = example_with_register() + err = None + for repeats in range(8): + try: + results = request_assert('get', + '/dropq_query_result', + params={'job_id': job_id}, + as_json=False) + assert results._content == "YES" + return + + except Exception as err: + time.sleep(8) + if err is not None: + raise + + +@pytest.mark.has_services +def test_pop_and_inspect_celery(): + '''Can only get a 'popped' key back if the job_id has been registered''' + assert has_services_for_test(), NOT_RUNNING_MSG + register, job_id = example_with_register() + time.sleep(4) + popped = request_assert('post', + '/pop_job_id', + params={'job_id': job_id}) + assert 'popped' in popped + ticket_dict = popped['popped'] + assert 'callback' in ticket_dict + assert 'job_id' in ticket_dict + + +@pytest.mark.has_services +def test_tracking_tickets(): + '''Test job_id disappears after job completion''' + assert has_services_for_test(), NOT_RUNNING_MSG + job_ids = [] + for repeat in range(3): + register, job_id = example_with_register() + job_ids.append(job_id) + time.sleep(4) + tickets = request_assert('get', '/current_tickets_tracker') + for job_id in job_ids: + assert job_id in tickets + for repeat in range(3): + time.sleep(30) + # we have waited through completion of jobs + tickets = request_assert('get', '/current_tickets_tracker') + try: + assert not any(job_id in tickets for job_id in job_ids) + return + except: + pass + raise + + From 2cece434a1f61ebec558ab5149b1217e4b24e20f Mon Sep 17 00:00:00 2001 From: Peter Steinberg Date: Mon, 2 Nov 2015 14:05:19 -0800 Subject: [PATCH 2/4] use threading.Lock around TRACKING_TICKETS dict --- flask_server.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/flask_server.py b/flask_server.py index 7ae0ca30c..2284c7378 100644 --- a/flask_server.py +++ b/flask_server.py @@ -4,7 +4,7 @@ from functools import partial import json import os -from threading import Thread, Event +from threading import Thread, Event, Lock import time from flask import Flask, request, make_response @@ -30,7 +30,7 @@ TICKET_CHECK_COUNTER = 0 TICKET_WRITE_MOD = 200 EXIT_EVENT = None - +TICKET_LOCK = None @app.route("/dropq_start_job", methods=['GET', 'POST']) def dropq_endpoint(): @@ -141,7 +141,7 @@ def do_success_callback(job_id, callback, params): def job_id_check(): global TRACKING_TICKETS global TICKET_CHECK_COUNTER - + TICKET_LOCK.acquire() TICKET_CHECK_COUNTER += 1 to_pop = [] for job_id in TRACKING_TICKETS: @@ -156,16 +156,19 @@ def job_id_check(): resp = do_success_callback(job_id, callback, ticket_dict['params']) print('Success on callback {0} with response {1}'.format(callback, resp)) + if TICKET_CHECK_COUNTER >= TICKET_WRITE_MOD: # periodically dump the TRACKING_TICKETS to json TICKET_CHECK_COUNTER = 0 with open(TRACKING_TICKETS_PATH, 'w') as f: f.write(json.dumps(TRACKING_TICKETS)) + TICKET_LOCK.release() @app.route("/register_job", methods=['POST']) def register_job(): global TRACKING_TICKETS + global TICKET_LOCK callback = request.args.get('callback', False) job_id = request.args.get('job_id', False) params = json.loads(request.args.get('params', "{}")) @@ -176,22 +179,30 @@ def register_job(): msg = "Start checking: job_id {0} with params {1} and callback {2}" print(msg.format(job_id, params, callback)) now = datetime.datetime.utcnow().isoformat() + TICKET_LOCK.acquire() TRACKING_TICKETS[job_id] = {'params': params, 'started': now, 'callback': callback, 'job_id': job_id, } + TICKET_LOCK.release() return json.dumps({'registered': TRACKING_TICKETS[job_id], }) @app.route('/pop_job_id', methods=['POST']) def pop(): + global TICKET_LOCK global TRACKING_TICKETS + TICKET_LOCK.acquire() + job_id = request.args.get('job_id', '') if job_id in TRACKING_TICKETS: - return json.dumps({'popped': TRACKING_TICKETS.pop(job_id)}) - return make_response(json.dumps({'job_id': job_id, + resp = json.dumps({'popped': TRACKING_TICKETS.pop(job_id)}) + else: + resp = make_response(json.dumps({'job_id': job_id, 'error': 'job_id not present'}), 400) + TICKET_LOCK.release() + return resp @app.route('/current_tickets_tracker', methods=['GET']) @@ -221,7 +232,9 @@ def cli(): def main(): global TRACKING_TICKETS global EXIT_EVENT + global TICKET_LOCK EXIT_EVENT = Event() + TICKET_LOCK = Lock() args = cli() if os.path.exists(TRACKING_TICKETS_PATH) and not args.ignore_cached_tickets: # load any tickets if they exist From 1d80841aa7a71b4bafe6c0769706ef4893d65817 Mon Sep 17 00:00:00 2001 From: Peter Steinberg Date: Tue, 3 Nov 2015 10:51:14 -0800 Subject: [PATCH 3/4] contextmanager for lock around tickets dictionary --- flask_server.py | 98 ++++++++++++++++++++++++-------------------- test_flask_server.py | 1 - 2 files changed, 53 insertions(+), 46 deletions(-) diff --git a/flask_server.py b/flask_server.py index 2284c7378..c7e5b203d 100644 --- a/flask_server.py +++ b/flask_server.py @@ -1,5 +1,6 @@ from __future__ import division, unicode_literals, print_function import argparse +from contextlib import contextmanager import datetime from functools import partial import json @@ -28,10 +29,20 @@ TRACKING_TICKETS_PATH = os.path.join(os.path.dirname(__file__), '.TRACKING_TICKETS') TICKET_CHECK_COUNTER = 0 -TICKET_WRITE_MOD = 200 +TICKET_WRITE_MOD = int(os.environ.get('TICKET_WRITE_MOD', '1')) EXIT_EVENT = None TICKET_LOCK = None +@contextmanager +def ticket_lock_context(): + global TICKET_LOCK + TICKET_LOCK.acquire() + try: + yield + finally: + TICKET_LOCK.release() + + @app.route("/dropq_start_job", methods=['GET', 'POST']) def dropq_endpoint(): print("stuff here") @@ -141,34 +152,32 @@ def do_success_callback(job_id, callback, params): def job_id_check(): global TRACKING_TICKETS global TICKET_CHECK_COUNTER - TICKET_LOCK.acquire() - TICKET_CHECK_COUNTER += 1 - to_pop = [] - for job_id in TRACKING_TICKETS: - results = celery_app.AsyncResult(job_id) - if results.ready(): - to_pop.append(job_id) - for job_id in to_pop: - ticket_dict = TRACKING_TICKETS.pop(job_id) - callback = ticket_dict['callback'] - # TODO decide on exception handling here - # raise exception if 1 ticket's callback fails? or just log it? - resp = do_success_callback(job_id, callback, ticket_dict['params']) - print('Success on callback {0} with response {1}'.format(callback, - resp)) - - if TICKET_CHECK_COUNTER >= TICKET_WRITE_MOD: - # periodically dump the TRACKING_TICKETS to json - TICKET_CHECK_COUNTER = 0 - with open(TRACKING_TICKETS_PATH, 'w') as f: - f.write(json.dumps(TRACKING_TICKETS)) - TICKET_LOCK.release() + with ticket_lock_context(): + TICKET_CHECK_COUNTER += 1 + to_pop = [] + for job_id in TRACKING_TICKETS: + results = celery_app.AsyncResult(job_id) + if results.ready(): + to_pop.append(job_id) + for job_id in to_pop: + ticket_dict = TRACKING_TICKETS.pop(job_id) + callback = ticket_dict['callback'] + # TODO decide on exception handling here + # raise exception if 1 ticket's callback fails? or just log it? + resp = do_success_callback(job_id, callback, ticket_dict['params']) + print('Success on callback {0} with response {1}'.format(callback, + resp)) + + if TICKET_CHECK_COUNTER >= TICKET_WRITE_MOD: + # periodically dump the TRACKING_TICKETS to json + TICKET_CHECK_COUNTER = 0 + with open(TRACKING_TICKETS_PATH, 'w') as f: + f.write(json.dumps(TRACKING_TICKETS)) @app.route("/register_job", methods=['POST']) def register_job(): global TRACKING_TICKETS - global TICKET_LOCK callback = request.args.get('callback', False) job_id = request.args.get('job_id', False) params = json.loads(request.args.get('params', "{}")) @@ -179,35 +188,33 @@ def register_job(): msg = "Start checking: job_id {0} with params {1} and callback {2}" print(msg.format(job_id, params, callback)) now = datetime.datetime.utcnow().isoformat() - TICKET_LOCK.acquire() - TRACKING_TICKETS[job_id] = {'params': params, - 'started': now, - 'callback': callback, - 'job_id': job_id, - } - TICKET_LOCK.release() - return json.dumps({'registered': TRACKING_TICKETS[job_id], }) + with ticket_lock_context(): + TRACKING_TICKETS[job_id] = {'params': params, + 'started': now, + 'callback': callback, + 'job_id': job_id, + } + return json.dumps({'registered': TRACKING_TICKETS[job_id], }) @app.route('/pop_job_id', methods=['POST']) def pop(): - global TICKET_LOCK global TRACKING_TICKETS - TICKET_LOCK.acquire() + with ticket_lock_context(): - job_id = request.args.get('job_id', '') - if job_id in TRACKING_TICKETS: - resp = json.dumps({'popped': TRACKING_TICKETS.pop(job_id)}) - else: - resp = make_response(json.dumps({'job_id': job_id, - 'error': 'job_id not present'}), 400) - TICKET_LOCK.release() - return resp + job_id = request.args.get('job_id', '') + if job_id in TRACKING_TICKETS: + resp = json.dumps({'popped': TRACKING_TICKETS.pop(job_id)}) + else: + resp = make_response(json.dumps({'job_id': job_id, + 'error': 'job_id not present'}), 400) + return resp @app.route('/current_tickets_tracker', methods=['GET']) def current_tickets_tracker(): - return json.dumps(TRACKING_TICKETS) + with ticket_lock_context(): + return json.dumps(TRACKING_TICKETS) @app.route('/example_success_callback', methods=['POST']) @@ -256,7 +263,8 @@ def checking_tickets_at_interval(): raise finally: # dump all the standing tickets no matter what - with open(TRACKING_TICKETS_PATH, 'w') as f: - f.write(json.dumps(TRACKING_TICKETS)) + with ticket_lock_context(): + with open(TRACKING_TICKETS_PATH, 'w') as f: + f.write(json.dumps(TRACKING_TICKETS)) if __name__ == "__main__": main() diff --git a/test_flask_server.py b/test_flask_server.py index 2ab927eb8..4726ce137 100644 --- a/test_flask_server.py +++ b/test_flask_server.py @@ -120,7 +120,6 @@ def test_pop_and_inspect_celery(): '''Can only get a 'popped' key back if the job_id has been registered''' assert has_services_for_test(), NOT_RUNNING_MSG register, job_id = example_with_register() - time.sleep(4) popped = request_assert('post', '/pop_job_id', params={'job_id': job_id}) From 150d9223c6e106157ac1c87d56676f616a070581 Mon Sep 17 00:00:00 2001 From: Peter Steinberg Date: Tue, 3 Nov 2015 13:52:55 -0800 Subject: [PATCH 4/4] reduce writing of json status by has_changed flag --- flask_server.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/flask_server.py b/flask_server.py index c7e5b203d..e369401f0 100644 --- a/flask_server.py +++ b/flask_server.py @@ -1,6 +1,7 @@ from __future__ import division, unicode_literals, print_function import argparse from contextlib import contextmanager +import copy import datetime from functools import partial import json @@ -153,6 +154,7 @@ def job_id_check(): global TRACKING_TICKETS global TICKET_CHECK_COUNTER with ticket_lock_context(): + old_tracking_tickets = copy.deepcopy(TRACKING_TICKETS) TICKET_CHECK_COUNTER += 1 to_pop = [] for job_id in TRACKING_TICKETS: @@ -170,9 +172,11 @@ def job_id_check(): if TICKET_CHECK_COUNTER >= TICKET_WRITE_MOD: # periodically dump the TRACKING_TICKETS to json - TICKET_CHECK_COUNTER = 0 - with open(TRACKING_TICKETS_PATH, 'w') as f: - f.write(json.dumps(TRACKING_TICKETS)) + # if changed + if old_tracking_tickets != TRACKING_TICKETS: + TICKET_CHECK_COUNTER = 0 + with open(TRACKING_TICKETS_PATH, 'w') as f: + f.write(json.dumps(TRACKING_TICKETS)) @app.route("/register_job", methods=['POST'])