diff --git a/mpire/dashboard/dashboard.py b/mpire/dashboard/dashboard.py index 8974ce0..1a87985 100644 --- a/mpire/dashboard/dashboard.py +++ b/mpire/dashboard/dashboard.py @@ -1,6 +1,9 @@ +import atexit import getpass import importlib.resources import logging +import os +import signal import socket from datetime import datetime from multiprocessing import Event, Process @@ -11,7 +14,6 @@ from markupsafe import escape from werkzeug.serving import make_server -from mpire.signal import DisableKeyboardInterruptSignal, ignore_keyboard_interrupt from mpire.dashboard.manager import (DASHBOARD_MANAGER_CONNECTION_DETAILS, get_manager_client_dicts, shutdown_manager_server, start_manager_server) from mpire.dashboard.utils import get_two_available_ports @@ -117,35 +119,34 @@ def start_dashboard(port_range: Sequence = range(8080, 8100)) -> Dict[str, Union dashboard_port_nr, manager_port_nr = get_two_available_ports(port_range) - # Prevent signal from propagating to child process - with DisableKeyboardInterruptSignal(): + # Set up manager server + _DASHBOARD_MANAGER = start_manager_server(manager_port_nr) - # Set up manager server - _DASHBOARD_MANAGER = start_manager_server(manager_port_nr) + # Start flask server + logging.getLogger('werkzeug').setLevel(logging.WARN) + _server_process = Process(target=_run, args=(DASHBOARD_STARTED_EVENT, dashboard_port_nr, + get_manager_client_dicts()), + daemon=True, name='dashboard-process') + _server_process.start() + DASHBOARD_STARTED_EVENT.wait() - # Start flask server - logging.getLogger('werkzeug').setLevel(logging.WARN) - _server_process = Process(target=_run, args=(DASHBOARD_STARTED_EVENT, dashboard_port_nr, - get_manager_client_dicts()), - daemon=True, name='dashboard-process') - _server_process.start() - DASHBOARD_STARTED_EVENT.wait() - - # Return connect information - return {'dashboard_port_nr': dashboard_port_nr, - 'manager_host': DASHBOARD_MANAGER_CONNECTION_DETAILS.host or socket.gethostname(), - 'manager_port_nr': DASHBOARD_MANAGER_CONNECTION_DETAILS.port} + # Return connect information + return {'dashboard_port_nr': dashboard_port_nr, + 'manager_host': DASHBOARD_MANAGER_CONNECTION_DETAILS.host or socket.gethostname(), + 'manager_port_nr': DASHBOARD_MANAGER_CONNECTION_DETAILS.port} else: raise RuntimeError("You already have a running dashboard") +@atexit.register def shutdown_dashboard() -> None: """ Shuts down the dashboard """ if DASHBOARD_STARTED_EVENT.is_set(): global _server_process, _DASHBOARD_MANAGER, _DASHBOARD_TQDM_DICT, _DASHBOARD_TQDM_DETAILS_DICT if _server_process is not None: - _server_process.terminate() + # Send SIGINT to the server process, which is the only way to stop it without causing semaphore leaks + os.kill(_server_process.pid, signal.SIGINT) _server_process.join() shutdown_manager_server(_DASHBOARD_MANAGER) _DASHBOARD_MANAGER = None @@ -190,8 +191,6 @@ def _run(started: Event, dashboard_port_nr: int, manager_client_dicts: Tuple[Bas :param manager_port_nr: Dashboard manager port number :param dashboard_port_nr: Dashboard port number """ - ignore_keyboard_interrupt() # For Windows compatibility - global _DASHBOARD_TQDM_DICT, _DASHBOARD_TQDM_DETAILS_DICT _DASHBOARD_TQDM_DICT, _DASHBOARD_TQDM_DETAILS_DICT, _ = manager_client_dicts