Skip to content

Commit

Permalink
- TQDM connection details updated. There was no need to put them into…
Browse files Browse the repository at this point in the history
… process aware objects (#114)

- Progress bar didn't need connection details to be serialized, as a thread simply uses the same memory pool.
- Progress bar also works on Windows now.

Co-authored-by: sybrenjansen <sybren.jansen@gmail.com>
  • Loading branch information
sybrenjansen and sybrenjansen authored Jan 3, 2024
1 parent d78ad96 commit 62d4530
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 86 deletions.
1 change: 1 addition & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Unreleased
* Import escape directly from markupsafe, instead of from flask. (`#106`_)
* Insights now also work when using the ``forkserver`` and ``spawn`` start methods. (`#104`_)
* When using insights on Windows the arguments of the top 5 longest tasks are now available as well.
* Progress bars are now supported on Windows.

.. _#108: https://github.com/sybrenjansen/mpire/pull/107
.. _#107: https://github.com/sybrenjansen/mpire/issues/106
Expand Down
1 change: 0 additions & 1 deletion docs/troubleshooting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ Windows

Windows support has some caveats:

* Progress bar is not supported when using threading as start method;
* When using ``dill`` and an exception occurs, or when the exception occurs in an exit function, it can print additional
``OSError`` messages in the terminal, but they can be safely ignored.

Expand Down
7 changes: 1 addition & 6 deletions mpire/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from unittest.mock import patch
from tqdm import tqdm, TqdmKeyError

from mpire.context import DEFAULT_START_METHOD, RUNNING_WINDOWS
from mpire.context import DEFAULT_START_METHOD
from mpire.tqdm_utils import get_tqdm

# Typedefs
Expand Down Expand Up @@ -207,11 +207,6 @@ def check_map_parameters(pool_params: WorkerPoolParams, iterable_of_args: Union[
# If worker lifespan is not None or not a positive integer, raise
check_number(worker_lifespan, 'worker_lifespan', allowed_types=(int,), none_allowed=True, min_=1)

# Progress bar is currently not supported on Windows when using threading as start method. For reasons still
# unknown we get a TypeError: cannot pickle '_thread.Lock' object.
if RUNNING_WINDOWS and progress_bar and pool_params.start_method == "threading":
raise ValueError("Progress bar is currently not supported on Windows when using start_method='threading'")

# Check progress bar parameters and set default values
progress_bar_options = check_progress_bar_options(progress_bar_options, progress_bar_position, n_tasks,
progress_bar_style)
Expand Down
19 changes: 3 additions & 16 deletions mpire/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,11 @@
from tqdm import tqdm as tqdm_type

from mpire.comms import WorkerComms, POISON_PILL
from mpire.dashboard.connection_utils import (DashboardConnectionDetails, get_dashboard_connection_details,
set_dashboard_connection)
from mpire.exception import remove_highlighting
from mpire.insights import WorkerInsights
from mpire.params import WorkerMapParams, WorkerPoolParams
from mpire.signal import DisableKeyboardInterruptSignal
from mpire.tqdm_utils import get_tqdm, TqdmConnectionDetails, TqdmManager
from mpire.tqdm_utils import get_tqdm, TqdmManager
from mpire.utils import format_seconds

# If a user has not installed the dashboard dependencies then the imports below will fail
Expand Down Expand Up @@ -81,8 +79,7 @@ def __enter__(self) -> 'ProgressBarHandler':

# Disable the interrupt signal. We let the thread die gracefully
with DisableKeyboardInterruptSignal():
self.thread = Thread(target=self._progress_bar_handler, args=(TqdmManager.get_connection_details(),
get_dashboard_connection_details()))
self.thread = Thread(target=self._progress_bar_handler)
self.thread.start()
self.thread_started.wait()

Expand All @@ -104,23 +101,13 @@ def __exit__(self, exc_type: Type, *_) -> None:
self.worker_comms.signal_progress_bar_shutdown()
self.thread.join()

def _progress_bar_handler(self, tqdm_connection_details: TqdmConnectionDetails,
dashboard_connection_details: DashboardConnectionDetails) -> None:
def _progress_bar_handler(self) -> None:
"""
Keeps track of the progress made by the workers and updates the progress bar accordingly
:param tqdm_connection_details: Tqdm manager host, and whether the manager is started/connected
:param dashboard_connection_details: Dashboard manager host, port_nr and whether a dashboard is
started/connected
"""
# Obtain the progress bar tqdm class
tqdm, in_notebook = get_tqdm(self.progress_bar_style)

# Set tqdm and dashboard connection details. This is needed for nested pools and in the case forkserver or
# spawn is used as start method
TqdmManager.set_connection_details(tqdm_connection_details)
set_dashboard_connection(dashboard_connection_details)

# Connect to the tqdm manager
tqdm_manager = TqdmManager()
tqdm_lock, tqdm_position_register = tqdm_manager.get_lock_and_position_register()
Expand Down
72 changes: 29 additions & 43 deletions mpire/tqdm_utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import logging
from ctypes import c_char
from multiprocessing import Array, Event, Lock
from multiprocessing import Lock
from multiprocessing.managers import SyncManager
from typing import Optional, Tuple, Type
import os
from typing import Optional, Tuple, Type, Union

from tqdm import tqdm as tqdm_std
from tqdm.notebook import tqdm as tqdm_notebook

from mpire.signal import DisableKeyboardInterruptSignal

PROGRESS_BAR_DEFAULT_STYLE = 'std'
TqdmConnectionDetails = Tuple[Optional[bytes], bool]
TqdmConnectionDetails = Tuple[Union[str, bytes, None], bytes, bool]

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -118,15 +118,16 @@ class TqdmManager:
POSITION_REGISTER = TqdmPositionRegister()

MANAGER = None
MANAGER_HOST = Array(c_char, 10000, lock=True)
MANAGER_STARTED = Event()
MANAGER_HOST: Union[str, bytes, None] = None
MANAGER_AUTHKEY = b""
MANAGER_STARTED = False

def __init__(self) -> None:
"""
Tqdm manager wrapper for syncing multiple progress bars, independent of process start method used.
"""
# Connect to existing manager, if it exists
if self.MANAGER_STARTED.is_set():
if self.MANAGER_STARTED:
self.connect_to_manager()

@classmethod
Expand All @@ -137,49 +138,33 @@ def start_manager(cls) -> bool:
:return: Whether the manager was started
"""
# Don't do anything when there's already a connected tqdm manager
if cls.MANAGER_STARTED.is_set():
if cls.MANAGER_STARTED:
return False

logger.debug("Starting TQDM manager")

# Create manager
with DisableKeyboardInterruptSignal():
cls.MANAGER = SyncManager(authkey=b'mpire_tqdm')
cls.MANAGER = SyncManager(authkey=os.urandom(24))
cls.MANAGER.register('get_tqdm_lock', cls._get_tqdm_lock)
cls.MANAGER.register('get_tqdm_position_register', cls._get_tqdm_position_register)
cls.MANAGER.start()
cls.MANAGER_STARTED.set()

# Set host so other processes know where to connect to. On some systems and Python versions address is a bytes
# object, on others it's a string. On some it's also prefixed by a null byte which needs to be removed (null
# byte doesn't work with Array).
address = cls.MANAGER.address
if isinstance(address, str):
address = address.encode()
if address[0] == 0:
address = address[1:]
cls.MANAGER_HOST.value = address

# Set host and authkey so other processes know where to connect to
cls.MANAGER_HOST = cls.MANAGER.address
cls.MANAGER_AUTHKEY = bytes(cls.MANAGER._authkey)
cls.MANAGER_STARTED = True

return True

def connect_to_manager(self) -> None:
"""
Connect to the tqdm manager
"""
# Connect to a server. On some systems and Python versions the address is prefixed by a null byte (which was
# stripped when setting the host value, due to restrictions in Array). Address needs to be a string.
address = self.MANAGER_HOST.value.decode()
try:
self.MANAGER = SyncManager(address=address, authkey=b'mpire_tqdm')
self.MANAGER.register('get_tqdm_lock')
self.MANAGER.register('get_tqdm_position_register')
self.MANAGER.connect()
except FileNotFoundError:
address = f"\x00{address}"
self.MANAGER = SyncManager(address=address, authkey=b'mpire_tqdm')
self.MANAGER.register('get_tqdm_lock')
self.MANAGER.register('get_tqdm_position_register')
self.MANAGER.connect()
self.MANAGER = SyncManager(address=self.MANAGER_HOST, authkey=self.MANAGER_AUTHKEY)
self.MANAGER.register('get_tqdm_lock')
self.MANAGER.register('get_tqdm_position_register')
self.MANAGER.connect()

@classmethod
def stop_manager(cls) -> None:
Expand All @@ -188,8 +173,9 @@ def stop_manager(cls) -> None:
"""
cls.MANAGER.shutdown()
cls.MANAGER = None
cls.MANAGER_HOST.value = b''
cls.MANAGER_STARTED.clear()
cls.MANAGER_HOST = None
cls.MANAGER_AUTHKEY = b""
cls.MANAGER_STARTED = False

@staticmethod
def _get_tqdm_lock() -> Lock:
Expand Down Expand Up @@ -223,9 +209,9 @@ def get_connection_details(cls) -> TqdmConnectionDetails:
Obtains the connection details of the tqdm manager. These details need to be passed on to child processes
when the start method is either forkserver or spawn.
:return: TQDM manager host and whether a manager is started/connected
:return: TQDM manager host, authkey, and whether a manager is started/connected
"""
return cls.MANAGER_HOST.value, cls.MANAGER_STARTED.is_set()
return cls.MANAGER_HOST, cls.MANAGER_AUTHKEY, cls.MANAGER_STARTED

@classmethod
def set_connection_details(cls, tqdm_connection_details: TqdmConnectionDetails) -> None:
Expand All @@ -234,8 +220,8 @@ def set_connection_details(cls, tqdm_connection_details: TqdmConnectionDetails)
:param tqdm_connection_details: TQDM manager host, authkey, and whether a manager is started/connected
"""
tqdm_manager_host, tqdm_manager_started = tqdm_connection_details
if not cls.MANAGER_STARTED.is_set():
cls.MANAGER_HOST.value = tqdm_manager_host
if tqdm_manager_started:
cls.MANAGER_STARTED.set()
tqdm_manager_host, tqdm_manager_authkey, tqdm_manager_started = tqdm_connection_details
if not cls.MANAGER_STARTED:
cls.MANAGER_HOST = tqdm_manager_host
cls.MANAGER_AUTHKEY = tqdm_manager_authkey
cls.MANAGER_STARTED = tqdm_manager_started
10 changes: 0 additions & 10 deletions tests/test_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,6 @@ def test_n_tasks(self):
"""
Should raise when wrong parameter values are used
"""
pool_params = WorkerPoolParams(None, None)

# Get number of tasks
with self.subTest('get n_tasks', iterable_of_args=range(100)):
n_tasks, *_ = self.check_map_parameters_func(iterable_of_args=range(100), iterable_len=None)
Expand Down Expand Up @@ -373,14 +371,6 @@ def test_worker_lifespan(self):
self.assertEqual(args[0], 11)
self.assertDictEqual(kwargs, {"allowed_types": (int,), "none_allowed": True, "min_": 1})

def test_windows_threading_progress_bar_error(self):
"""
Check that when a progress bar is requested on windows when threading is used, a ValueError is raised.
"""
with patch('mpire.params.RUNNING_WINDOWS', side_effect=True), self.assertRaises(ValueError):
pool_params = WorkerPoolParams(2, None, start_method='threading')
self.check_map_parameters_func(pool_params=pool_params, progress_bar=True)

def test_timeout(self):
"""
Check task_timeout, worker_init_timeout, and worker_exit_timeout. Should raise when wrong parameter values are
Expand Down
14 changes: 4 additions & 10 deletions tests/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,16 +1068,10 @@ def test_start_methods(self):
"""
print()
for start_method in TEST_START_METHODS:
with self.subTest(start_method=start_method), \
WorkerPool(n_jobs=2, start_method=start_method) as pool:
# Progress bar on Windows with threading is currently not supported
if RUNNING_WINDOWS and start_method == 'threading':
with self.assertRaises(ValueError):
pool.map(square, self.test_data, progress_bar=True)
else:
results_list = pool.map(square, self.test_data, progress_bar=True)
self.assertIsInstance(results_list, list)
self.assertEqual(self.test_desired_output, results_list)
with self.subTest(start_method=start_method), WorkerPool(n_jobs=2, start_method=start_method) as pool:
results_list = pool.map(square, self.test_data, progress_bar=True)
self.assertIsInstance(results_list, list)
self.assertEqual(self.test_desired_output, results_list)


class KeepAliveTest(unittest.TestCase):
Expand Down

0 comments on commit 62d4530

Please sign in to comment.