Skip to content

Commit

Permalink
Replace deprecated flask shutdown with thread kill. #1780
Browse files Browse the repository at this point in the history
  • Loading branch information
T4rk1n committed Mar 18, 2022
1 parent 6ee2328 commit e19dc6d
Showing 1 changed file with 25 additions and 15 deletions.
40 changes: 25 additions & 15 deletions dash/testing/application_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import subprocess
import logging
import inspect
import ctypes

import runpy
import flask
import requests

from dash.testing.errors import NoAppFoundError, TestingTimeoutError, ServerCloseError
Expand Down Expand Up @@ -102,6 +102,26 @@ def tmp_app_path(self):
return self._tmp_app_path


class StoppableThread(threading.Thread):
def get_id(self): # pylint: disable=R1710
if hasattr(self, "_thread_id"):
return self._thread_id
for thread_id, thread in threading._active.items(): # pylint: disable=W0212
if thread is self:
return thread_id

def kill(self):
thread_id = self.get_id()
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
ctypes.c_long(thread_id), ctypes.py_object(SystemExit)
)
if res == 0:
raise ValueError(f"Invalid thread id: {thread_id}")
if res > 1:
ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(thread_id), None)
raise SystemExit("Stopping thread failure")


class ThreadedRunner(BaseDashRunner):
"""Runs a dash application in a thread.
Expand All @@ -110,25 +130,14 @@ class ThreadedRunner(BaseDashRunner):

def __init__(self, keep_open=False, stop_timeout=3):
super().__init__(keep_open=keep_open, stop_timeout=stop_timeout)
self.stop_route = "/_stop-{}".format(uuid.uuid4().hex)
self.thread = None

@staticmethod
def _stop_server():
# https://werkzeug.palletsprojects.com/en/0.15.x/serving/#shutting-down-the-server
stopper = flask.request.environ.get("werkzeug.server.shutdown")
if stopper is None:
raise RuntimeError("Not running with the Werkzeug Server")
stopper()
return "Flask server is shutting down"

# pylint: disable=arguments-differ
def start(self, app, **kwargs):
"""Start the app server in threading flavor."""
app.server.add_url_rule(self.stop_route, self.stop_route, self._stop_server)

def _handle_error():
self._stop_server()
self.stop()

app.server.errorhandler(500)(_handle_error)

Expand All @@ -141,7 +150,7 @@ def run():
self.port = kwargs["port"]
app.run_server(threaded=True, **kwargs)

self.thread = threading.Thread(target=run)
self.thread = StoppableThread(target=run)
self.thread.daemon = True
try:
self.thread.start()
Expand All @@ -155,7 +164,8 @@ def run():
wait.until(lambda: self.accessible(self.url), timeout=1)

def stop(self):
requests.get("{}{}".format(self.url, self.stop_route))
self.thread.kill()
self.thread.join()
wait.until_not(self.thread.is_alive, self.stop_timeout)


Expand Down

0 comments on commit e19dc6d

Please sign in to comment.