Skip to content

Commit

Permalink
Merge pull request PSLmodels#30 from PeterDSteinberg/feature/celery_status_server
Browse files Browse the repository at this point in the history

celery status server in thread background with tests
  • Loading branch information
talumbau committed Nov 12, 2015
2 parents 29f126e + 150d922 commit 58a1549
Show file tree
Hide file tree
Showing 3 changed files with 393 additions and 14 deletions.
14 changes: 10 additions & 4 deletions celery_tasks.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from celery import Celery
import json
import os
import time

from celery import Celery
import pandas as pd

import dropq
import taxcalc
import json

# Celery application for the dropq tasks: both the message broker and the
# result backend are the RedisGreen instance named by REDISGREEN_URL
# (raises KeyError at import time if the env var is missing).
celery_app = Celery('tasks2', broker=os.environ['REDISGREEN_URL'], backend=os.environ['REDISGREEN_URL'])

Expand Down Expand Up @@ -33,5 +36,8 @@ def dropq_task_async(year, user_mods):
json_res = json.dumps(results)
return json_res



@celery_app.task
def example_async():
    """Example asynchronous task used for testing.

    Sleeps for 40 seconds to simulate a long-running job, then returns a
    small JSON payload so callers can verify completion.
    """
    delay_seconds = 40
    time.sleep(delay_seconds)
    payload = {'example': 'ok'}
    return json.dumps(payload)
238 changes: 228 additions & 10 deletions flask_server.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,67 @@
from __future__ import division, unicode_literals, print_function
import argparse
from contextlib import contextmanager
import copy
import datetime
from functools import partial
import json
import os
from threading import Thread, Event, Lock
import time

from flask import Flask, request, make_response
import pandas as pd
import dropq
import taxcalc
import json
import time
from functools import partial
from pandas.util.testing import assert_frame_equal
import requests
from retrying import retry

from celery_tasks import celery_app, dropq_task_async
from celery_tasks import celery_app, dropq_task_async, example_async

import os

app = Flask('sampleapp')

# Base URL clients/tests use to reach this server.
server_url = "http://localhost:5050"

# job_id -> ticket dict ({'params', 'started', 'callback', 'job_id'}).
# Guarded by TICKET_LOCK (see ticket_lock_context).
TRACKING_TICKETS = {}
# Seconds the background checker sleeps between status sweeps.
SLEEP_INTERVAL_TICKET_CHECK = 30
# File where TRACKING_TICKETS is persisted across restarts.
TRACKING_TICKETS_PATH = os.path.join(os.path.dirname(__file__),
                                     '.TRACKING_TICKETS')
# Count of status sweeps since the last persistence write.
TICKET_CHECK_COUNTER = 0
# Persist tickets every TICKET_WRITE_MOD sweeps (when they changed).
TICKET_WRITE_MOD = int(os.environ.get('TICKET_WRITE_MOD', '1'))
# Created in main(): Event that tells the checker thread to exit, and the
# Lock guarding TRACKING_TICKETS; None until main() initializes them.
EXIT_EVENT = None
TICKET_LOCK = None

@contextmanager
def ticket_lock_context():
    """Hold TICKET_LOCK for the duration of the enclosed block."""
    global TICKET_LOCK
    with TICKET_LOCK:
        yield


@app.route("/dropq_start_job", methods=['GET', 'POST'])
def dropq_endpoint():
    """Kick off an asynchronous dropq calculation.

    Reads the tax `year` and the JSON-encoded `user_mods` either from POST
    form fields or from GET query arguments, and returns the celery job id
    as plain text.
    """
    print("stuff here")
    if request.method == 'POST':
        year = request.form['year']
        user_mods = json.loads(request.form['user_mods'])
    else:
        year = request.args.get('year', '')
        # BUG FIX: this branch previously never assigned user_mods, so every
        # GET request raised NameError below; read it from the query string.
        user_mods = json.loads(request.args.get('user_mods', '{}'))
    # JSON object keys are always strings; re-key by int for dropq.
    # .items() (not Python-2-only .iteritems()) keeps this py2/py3 compatible,
    # matching the file's __future__ imports.
    user_mods = {int(k): v for k, v in user_mods.items()}
    year = int(year)
    raw_results = dropq_task_async.delay(year, user_mods)
    return str(raw_results)


@app.route("/dropq_get_result", methods=['GET'])
def dropq_results():
print("results here")
job_id = request.args.get('job_id','')
job_id = request.args.get('job_id', '')
results = celery_app.AsyncResult(job_id)
if results.ready():
tax_result = results.result
Expand All @@ -40,17 +70,205 @@ def dropq_results():
resp = make_response('not ready', 202)
return resp


@app.route("/dropq_query_result", methods=['GET'])
def query_results():
    """Answer "YES" or "NO": has the given celery job finished yet?"""
    job_id = request.args.get('job_id', '')
    async_result = celery_app.AsyncResult(job_id)
    return "YES" if async_result.ready() else "NO"


@app.route('/example_async', methods=['POST'])
def example():
    """Launch the 40-second example task; respond with its job id as JSON."""
    async_result = example_async.delay()
    print('job_id', async_result)
    return json.dumps({'job_id': str(async_result)})


def sleep_function(seconds):
    """Return a decorator that re-runs the wrapped function forever,
    pausing roughly *seconds* between runs, until EXIT_EVENT is set.

    The pause is taken in one-second slices so a newly-set EXIT_EVENT is
    noticed within a second instead of after one long sleep.  Ctrl-C sets
    EXIT_EVENT so the loop (and any sibling loops) stop cleanly.
    """
    pause_slices = int(seconds)

    def sleep_function_dec(func):
        def new_func(*args, **kwargs):
            global EXIT_EVENT
            try:
                while True:
                    if EXIT_EVENT.is_set():
                        return
                    func(*args, **kwargs)
                    for _ in range(pause_slices):
                        if EXIT_EVENT.is_set():
                            return
                        time.sleep(1)
            except KeyboardInterrupt:
                EXIT_EVENT.set()
                print('KeyboardInterrupt')
        return new_func
    return sleep_function_dec


class BadResponse(ValueError):
    """Raised when a callback replies with a well-formed JSON error payload;
    signals retry_exception_or_not that the callback must NOT be retried."""
    pass


def retry_exception_or_not(exception):
    """Retry predicate for @retry: retry everything except BadResponse."""
    return not isinstance(exception, BadResponse)


@retry(wait_fixed=7000, stop_max_attempt_number=4,
       retry_on_exception=retry_exception_or_not)
def do_success_callback(job_id, callback, params):
    """POST *params* to the *callback* URL for a finished job and return
    the callback's decoded JSON response.

    The @retry decorator re-runs this up to stop_max_attempt_number times,
    waiting wait_fixed milliseconds between attempts.  A BadResponse --
    raised when the callback returned valid JSON containing an 'error'
    key -- is never retried (see retry_exception_or_not); any other
    exception (timeout, bad JSON, connection error) is.
    """
    callback_response = None
    try:
        callback_response = requests.post(callback, params=params, timeout=20)
        js = callback_response.json()
        if 'error' in js:
            # DO not retry if an error message is returned
            print('ERROR (no retry) in callback_response: {0}'.format(js))
            raise BadResponse(js['error'])
        return js
    except Exception as e:
        # Distinguish "got a body but it wasn't usable JSON" from
        # "no response at all (probable timeout)" in the diagnostic.
        if callback_response is not None:
            content = callback_response._content  # raw bytes of the reply
            first_message = "Failed to json.loads callback_response"
        else:
            first_message = "No content. Probable timeout."
            content = ""
        msg = first_message +\
            " with exception {0}".format(repr(e)) +\
            " for ticket id {0}:{1}".format(job_id, content) +\
            ". May retry."
        print(msg)
        # Re-raise so @retry can decide whether to try again.
        raise


def job_id_check():
    """One sweep of the ticket tracker: fire callbacks for finished jobs.

    Runs entirely under TICKET_LOCK.  Each tracked job_id is asked for its
    celery status; completed ones are popped from TRACKING_TICKETS and
    their 'callback' URL is POSTed via do_success_callback.  Every
    TICKET_WRITE_MOD-th sweep, TRACKING_TICKETS is persisted to
    TRACKING_TICKETS_PATH if it changed during the sweep.
    """
    global TRACKING_TICKETS
    global TICKET_CHECK_COUNTER
    with ticket_lock_context():
        # Snapshot so we can tell at the end whether anything changed.
        old_tracking_tickets = copy.deepcopy(TRACKING_TICKETS)
        TICKET_CHECK_COUNTER += 1
        # Collect finished ids first; don't mutate the dict while iterating.
        to_pop = []
        for job_id in TRACKING_TICKETS:
            results = celery_app.AsyncResult(job_id)
            if results.ready():
                to_pop.append(job_id)
        for job_id in to_pop:
            ticket_dict = TRACKING_TICKETS.pop(job_id)
            callback = ticket_dict['callback']
            # TODO decide on exception handling here
            # raise exception if 1 ticket's callback fails? or just log it?
            resp = do_success_callback(job_id, callback, ticket_dict['params'])
            print('Success on callback {0} with response {1}'.format(callback,
                                                                     resp))

        if TICKET_CHECK_COUNTER >= TICKET_WRITE_MOD:
            # periodically dump the TRACKING_TICKETS to json
            # if changed
            if old_tracking_tickets != TRACKING_TICKETS:
                TICKET_CHECK_COUNTER = 0
                with open(TRACKING_TICKETS_PATH, 'w') as f:
                    f.write(json.dumps(TRACKING_TICKETS))


@app.route("/register_job", methods=['POST'])
def register_job():
    """Register a celery job_id for background tracking.

    Query args: `job_id` and `callback` are required; `params` (a JSON
    object) is optional.  Once the job finishes, the background checker
    POSTs `params` to the `callback` URL.  Responds 400 when a required
    argument is missing.
    """
    global TRACKING_TICKETS
    callback = request.args.get('callback', False)
    job_id = request.args.get('job_id', False)
    params = json.loads(request.args.get('params', "{}"))
    if not callback or not job_id:
        error_body = json.dumps({'error': "Expected arguments of " +
                                          "job_id, callback, and " +
                                          "optionally params."})
        return make_response(error_body, 400)
    msg = "Start checking: job_id {0} with params {1} and callback {2}"
    print(msg.format(job_id, params, callback))
    started_at = datetime.datetime.utcnow().isoformat()
    ticket = {'params': params,
              'started': started_at,
              'callback': callback,
              'job_id': job_id,
              }
    with ticket_lock_context():
        TRACKING_TICKETS[job_id] = ticket
    return json.dumps({'registered': ticket, })


@app.route('/pop_job_id', methods=['POST'])
def pop():
    """Stop tracking a job_id; 400 when it is not currently tracked."""
    global TRACKING_TICKETS
    with ticket_lock_context():
        job_id = request.args.get('job_id', '')
        try:
            popped = TRACKING_TICKETS.pop(job_id)
        except KeyError:
            return make_response(json.dumps({'job_id': job_id,
                                             'error': 'job_id not present'}),
                                 400)
        return json.dumps({'popped': popped})


@app.route('/current_tickets_tracker', methods=['GET'])
def current_tickets_tracker():
    """Return the entire TRACKING_TICKETS mapping as JSON (read under lock)."""
    with ticket_lock_context():
        snapshot = json.dumps(TRACKING_TICKETS)
    return snapshot


@app.route('/example_success_callback', methods=['POST'])
def example_success_callback():
    """Trivial callback target used by the tests; always reports success."""
    body = {'ok': 'example_success_callback'}
    return json.dumps(body)


def cli(argv=None):
    """Parse command-line options for the flask server.

    argv: optional list of argument strings; defaults to sys.argv[1:].
    Accepting it as a parameter (backward-compatible) makes the parser
    unit-testable without touching sys.argv.
    """
    parser = argparse.ArgumentParser(description="Run flask server")
    parser.add_argument('-s', '--sleep-interval',
                        help="How long to sleep between status checks",
                        default=SLEEP_INTERVAL_TICKET_CHECK,
                        type=float)
    # type=int so app.run() gets a numeric port even when -p is supplied
    # on the command line (argparse produces strings by default).
    parser.add_argument('-p', '--port', help="Port on which to run",
                        default=5050, required=False, type=int)
    # BUG FIX: this is used as a boolean flag in main(); without
    # action='store_true' argparse demanded a value for -i.
    parser.add_argument('-i', '--ignore-cached-tickets',
                        action='store_true',
                        help="Skip the loading of tracking tickets from " +
                             "{0}. Testing only.".format(TRACKING_TICKETS_PATH))
    return parser.parse_args(argv)


def main():
    """Entry point: restore persisted tickets, start the background
    ticket-checking thread, then run the Flask app (blocking).

    On any exception from the server the checker thread is signalled to
    stop via EXIT_EVENT; in every case TRACKING_TICKETS is dumped to
    TRACKING_TICKETS_PATH on the way out.
    """
    global TRACKING_TICKETS
    global EXIT_EVENT
    global TICKET_LOCK
    # Create the synchronization primitives the rest of the module uses.
    EXIT_EVENT = Event()
    TICKET_LOCK = Lock()
    args = cli()
    if os.path.exists(TRACKING_TICKETS_PATH) and not args.ignore_cached_tickets:
        # load any tickets if they exist
        with open(TRACKING_TICKETS_PATH, 'r') as f:
            TRACKING_TICKETS = json.load(f)

    # Wrap job_id_check so it re-runs every sleep-interval seconds until
    # EXIT_EVENT is set (see sleep_function).
    @sleep_function(args.sleep_interval)
    def checking_tickets_at_interval():
        return job_id_check()
    try:
        checker_thread = Thread(target=checking_tickets_at_interval)
        checker_thread.daemon = True  # don't let the checker block exit
        checker_thread.start()
        app.debug = True
        app.run(host='0.0.0.0', port=args.port)
    except Exception as e:
        # Ask the checker thread to stop, give it a moment, then re-raise.
        EXIT_EVENT.set()
        time.sleep(3)
        raise
    finally:
        # dump all the standing tickets no matter what
        with ticket_lock_context():
            with open(TRACKING_TICKETS_PATH, 'w') as f:
                f.write(json.dumps(TRACKING_TICKETS))
# Script entry: everything (debug mode, host, port) is configured in main();
# the pre-merge direct app.run() lines interleaved by the diff are removed.
if __name__ == "__main__":
    main()
Loading

0 comments on commit 58a1549

Please sign in to comment.