diff --git a/jupyter_dash/_stoppable_thread.py b/jupyter_dash/_stoppable_thread.py new file mode 100644 index 0000000..f6b3a74 --- /dev/null +++ b/jupyter_dash/_stoppable_thread.py @@ -0,0 +1,22 @@ +import ctypes +import threading + + +class StoppableThread(threading.Thread): + def get_id(self): + if hasattr(self, "_thread_id"): + return self._thread_id + for thread_id, thread in threading._active.items(): + if thread is self: + return thread_id + + def kill(self): + thread_id = self.get_id() + res = ctypes.pythonapi.PyThreadState_SetAsyncExc( + ctypes.c_long(thread_id), ctypes.py_object(SystemExit) + ) + if res == 0: + raise ValueError(f"Invalid thread id: {thread_id}") + if res > 1: + ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(thread_id), None) + raise SystemExit("Stopping thread failure") diff --git a/jupyter_dash/jupyter_app.py b/jupyter_dash/jupyter_app.py index 0e71ee8..8eda95b 100644 --- a/jupyter_dash/jupyter_app.py +++ b/jupyter_dash/jupyter_app.py @@ -1,9 +1,7 @@ import dash import os import requests -from flask import request import flask.cli -from threading import Thread from retrying import retry import io import re @@ -21,6 +19,7 @@ from werkzeug.debug.tbtools import get_current_traceback from .comms import _dash_comm, _jupyter_config, _request_jupyter_config +from ._stoppable_thread import StoppableThread class JupyterDash(dash.Dash): @@ -41,6 +40,8 @@ class JupyterDash(dash.Dash): _in_colab = "google.colab" in sys.modules _token = str(uuid.uuid4()) + _server_threads = {} + @classmethod def infer_jupyter_proxy_config(cls): """ @@ -130,15 +131,6 @@ def __init__(self, name=None, server_url=None, **kwargs): self.server_url = server_url - # Register route to shut down server - @self.server.route('/_shutdown_' + JupyterDash._token, methods=['GET']) - def shutdown(): - func = request.environ.get('werkzeug.server.shutdown') - if func is None: - raise RuntimeError('Not running with the Werkzeug Server') - func() - return 'Server shutting down...' - # Register route that we can use to poll to see when server is running @self.server.route('/_alive_' + JupyterDash._token, methods=['GET']) def alive(): @@ -217,7 +209,9 @@ def run_server( inline_exceptions = mode == "inline" # Terminate any existing server using this port - self._terminate_server_for_port(host, port) + old_server = self._server_threads.get((host, port)) + if old_server: + old_server.kill() # Configure pathname prefix requests_pathname_prefix = self.config.get('requests_pathname_prefix', None) @@ -291,10 +285,12 @@ def run_server( def run(): super_run_server(**kwargs) - thread = Thread(target=run) + thread = StoppableThread(target=run) thread.setDaemon(True) thread.start() + self._server_threads[(host, port)] = thread + # Wait for server to start up alive_url = "http://{host}:{port}/_alive_{token}".format( host=host, port=port, token=JupyterDash._token @@ -412,16 +408,6 @@ def _wrap_errors(_): return html_str, 500 - @classmethod - def _terminate_server_for_port(cls, host, port): - shutdown_url = "http://{host}:{port}/_shutdown_{token}".format( - host=host, port=port, token=JupyterDash._token - ) - try: - response = requests.get(shutdown_url) - except Exception as e: - pass - def _custom_formatargvalues( args, varargs, varkw, locals,