Add parallel module using joblib for Spark #5924

Merged · 13 commits · Jun 14, 2023
1 change: 1 addition & 0 deletions setup.py
@@ -160,6 +160,7 @@
TESTS_REQUIRE = [
# test dependencies
"absl-py",
"joblibspark",
"pytest",
"pytest-datadir",
"pytest-xdist",
1 change: 1 addition & 0 deletions src/datasets/parallel/__init__.py
@@ -0,0 +1 @@
from .parallel import parallel_backend, parallel_map, ParallelBackendConfig # noqa F401
113 changes: 113 additions & 0 deletions src/datasets/parallel/parallel.py
@@ -0,0 +1,113 @@
import contextlib
from multiprocessing import Pool, RLock

import joblib
from tqdm.auto import tqdm

from ..utils import experimental, logging


logger = logging.get_logger(__name__)


class ParallelBackendConfig:
backend_name = None


@experimental
def parallel_map(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func):
"""
**Experimental.** Apply a function to iterable elements in parallel, where the implementation uses either
multiprocessing.Pool or joblib for parallelization.

Args:
function (`Callable[[Any], Any]`): Function to be applied to `iterable`.
iterable (`list`, `tuple` or `np.ndarray`): Iterable elements to apply function to.
num_proc (`int`): Number of processes (if no backend specified) or jobs (using joblib).
types (`tuple`): Additional types (besides `dict` values) whose elements `function` should be applied to recursively.
disable_tqdm (`bool`): Whether to disable the tqdm progressbar.
desc (`str`): Prefix for the tqdm progressbar.
single_map_nested_func (`Callable`): Map function that applies `function` to an element from `iterable`.
Takes a tuple of function, data_struct, types, rank, disable_tqdm, desc as input, where data_struct is an
element of `iterable`, and `rank` is used for the progress bar.
"""
if ParallelBackendConfig.backend_name is None:
return _map_with_multiprocessing_pool(
function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func
)

return _map_with_joblib(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func)


def _map_with_multiprocessing_pool(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func):
num_proc = num_proc if num_proc <= len(iterable) else len(iterable)
split_kwds = [] # We organize the splits ourselves (contiguous splits)
for index in range(num_proc):
div = len(iterable) // num_proc
mod = len(iterable) % num_proc
start = div * index + min(index, mod)
end = start + div + (1 if index < mod else 0)
split_kwds.append((function, iterable[start:end], types, index, disable_tqdm, desc))

if len(iterable) != sum(len(i[1]) for i in split_kwds):
raise ValueError(
f"Error dividing inputs iterable among processes. "
f"Total number of objects {len(iterable)}, "
f"length: {sum(len(i[1]) for i in split_kwds)}"
)
Comment on lines +43 to +56

lhoestq (Member) commented on Jun 8, 2023:

Maybe this can still be in map_nested, so that the signature of parallel_map could be

    def parallel_map(function, iterable, num_proc):

and map_nested would call

    parallel_map(_single_map_nested, split_kwds, num_proc=num_proc)

This way it will be easier to start using parallel_map in other places in the code, no?

es94129 (Contributor, Author) replied:

If so, _map_with_joblib would also take split_kwds as input, which would be arbitrarily split according to num_proc rather than decided by joblib.
Are there other places where you are thinking of using parallel_map? I thought it was just a replacement for the multiprocessing part of map_nested.

lhoestq (Member) replied:

n_jobs is specified to joblib anyway, no? Not a big deal anyway.

es94129 (Contributor, Author) replied:

Might leave it like this so that n_jobs=-1 can be used when the user wants to let joblib decide the number of workers / processes etc.

lhoestq (Member) replied:

Ah, good idea!
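(As background for the n_jobs=-1 point above: this is standard joblib behavior rather than code from this PR. With n_jobs=-1, the currently active joblib backend chooses the worker count. A minimal sketch:)

```python
import joblib

# Plain joblib, independent of this PR: n_jobs=-1 defers the choice of
# worker count to the currently active backend.
squares = joblib.Parallel(n_jobs=-1)(joblib.delayed(pow)(i, 2) for i in range(8))
assert squares == [0, 1, 4, 9, 16, 25, 36, 49]
```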


logger.info(
f"Spawning {num_proc} processes for {len(iterable)} objects in slices of {[len(i[1]) for i in split_kwds]}"
)
initargs, initializer = None, None
if not disable_tqdm:
initargs, initializer = (RLock(),), tqdm.set_lock
with Pool(num_proc, initargs=initargs, initializer=initializer) as pool:
mapped = pool.map(single_map_nested_func, split_kwds)
logger.info(f"Finished {num_proc} processes")
mapped = [obj for proc_res in mapped for obj in proc_res]
logger.info(f"Unpacked {len(mapped)} objects")

return mapped


def _map_with_joblib(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func):
# The progress bar is not yet supported for _map_with_joblib, because tqdm cannot be applied accurately to
# joblib without monkey-patching joblib internal classes, which are subject to change

with joblib.parallel_backend(ParallelBackendConfig.backend_name, n_jobs=num_proc):
return joblib.Parallel()(
joblib.delayed(single_map_nested_func)((function, obj, types, None, True, None)) for obj in iterable
)


@experimental
@contextlib.contextmanager
def parallel_backend(backend_name: str):
"""
**Experimental.** Configures the parallel backend for parallelized dataset loading, which uses the parallelization
implemented by joblib.

Args:
backend_name (`str`): Name of the backend for the parallelization implementation; it has to be supported by joblib.

Example usage:
```py
with parallel_backend('spark'):
dataset = load_dataset(..., num_proc=2)
```
"""
ParallelBackendConfig.backend_name = backend_name

if backend_name == "spark":
from joblibspark import register_spark

register_spark()

# TODO: call create_cache_and_write_probe if "download" in steps
# TODO: raise NotImplementedError when Dataset.map etc is called

try:
yield
finally:
ParallelBackendConfig.backend_name = None
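Taken together, the new module lets callers opt into joblib-based parallelism (e.g. Spark via joblibspark) while keeping the multiprocessing.Pool path as the default. Below is a minimal sketch of the intended usage, mirroring the parallel_backend docstring example and the tests added later in this PR; it assumes pyspark and joblibspark are installed and a Spark environment is available.

```python
from datasets.parallel import parallel_backend
from datasets.utils.py_utils import map_nested


def add_one(i):
    return i + 1


# Inside the context manager, ParallelBackendConfig.backend_name is set to "spark",
# so map_nested routes its work through joblib (via joblibspark) instead of
# multiprocessing.Pool. num_proc=-1 lets joblib decide the number of workers.
with parallel_backend("spark"):
    assert map_nested(add_one, [1, 2, 3, 4], num_proc=-1) == [2, 3, 4, 5]
```

The same pattern applies at the loading level, as in the docstring example: `load_dataset(..., num_proc=2)` inside the context manager.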
2 changes: 2 additions & 0 deletions src/datasets/utils/__init__.py
@@ -22,8 +22,10 @@
"disable_progress_bar",
"enable_progress_bar",
"is_progress_bar_enabled",
"experimental",
]

from .info_utils import VerificationMode
from .logging import disable_progress_bar, enable_progress_bar, is_progress_bar_enabled
from .version import Version
from .experimental import experimental
42 changes: 42 additions & 0 deletions src/datasets/utils/experimental.py
@@ -0,0 +1,42 @@
"""Contains utilities to flag a feature as "experimental" in datasets."""
import warnings
from functools import wraps
from typing import Callable


def experimental(fn: Callable) -> Callable:
"""Decorator to flag a feature as experimental.

An experimental feature triggers a warning when used, as it might be subject to breaking changes in the future.

Args:
fn (`Callable`):
The function to flag as experimental.

Returns:
`Callable`: The decorated function.

Example:

```python
>>> from datasets.utils import experimental

>>> @experimental
... def my_function():
... print("Hello world!")

>>> my_function()
UserWarning: 'my_function' is experimental and might be subject to breaking changes in the future.
Hello world!
```
"""

@wraps(fn)
def _inner_fn(*args, **kwargs):
warnings.warn(
(f"'{fn.__name__}' is experimental and might be subject to breaking changes in the future."),
UserWarning,
)
return fn(*args, **kwargs)

return _inner_fn
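Since the decorator emits a plain UserWarning through the standard warnings module, downstream code can silence it with the usual filters. A small sketch (the function name is just an illustration):

```python
import warnings

from datasets.utils import experimental


@experimental
def my_function():  # placeholder function for illustration
    return "ok"


# Filter only this warning category locally; the decorated function still runs normally.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", UserWarning)
    assert my_function() == "ok"
```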
33 changes: 4 additions & 29 deletions src/datasets/utils/py_utils.py
@@ -28,7 +28,7 @@
from contextlib import contextmanager
from dataclasses import fields, is_dataclass
from io import BytesIO as StringIO
from multiprocessing import Manager, Pool, RLock
from multiprocessing import Manager
from queue import Empty
from shutil import disk_usage
from types import CodeType, FunctionType
@@ -43,6 +43,7 @@
from tqdm.auto import tqdm

from .. import config
from ..parallel import parallel_map
from . import logging


@@ -439,39 +440,13 @@ def map_nested(

if num_proc is None:
num_proc = 1
if num_proc <= 1 or len(iterable) < parallel_min_length:
if num_proc != -1 and num_proc <= 1 or len(iterable) < parallel_min_length:
mapped = [
_single_map_nested((function, obj, types, None, True, None))
for obj in logging.tqdm(iterable, disable=disable_tqdm, desc=desc)
]
else:
num_proc = num_proc if num_proc <= len(iterable) else len(iterable)
split_kwds = [] # We organize the splits ourselve (contiguous splits)
for index in range(num_proc):
div = len(iterable) // num_proc
mod = len(iterable) % num_proc
start = div * index + min(index, mod)
end = start + div + (1 if index < mod else 0)
split_kwds.append((function, iterable[start:end], types, index, disable_tqdm, desc))

if len(iterable) != sum(len(i[1]) for i in split_kwds):
raise ValueError(
f"Error dividing inputs iterable among processes. "
f"Total number of objects {len(iterable)}, "
f"length: {sum(len(i[1]) for i in split_kwds)}"
)

logger.info(
f"Spawning {num_proc} processes for {len(iterable)} objects in slices of {[len(i[1]) for i in split_kwds]}"
)
initargs, initializer = None, None
if not disable_tqdm:
initargs, initializer = (RLock(),), tqdm.set_lock
with Pool(num_proc, initargs=initargs, initializer=initializer) as pool:
mapped = pool.map(_single_map_nested, split_kwds)
logger.info(f"Finished {num_proc} processes")
mapped = [obj for proc_res in mapped for obj in proc_res]
logger.info(f"Unpacked {len(mapped)} objects")
mapped = parallel_map(function, iterable, num_proc, types, disable_tqdm, desc, _single_map_nested)

if isinstance(data_struct, dict):
return dict(zip(data_struct.keys(), mapped))
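One subtlety in the updated guard: `and` binds tighter than `or`, so `num_proc=-1` bypasses the `num_proc <= 1` check and can reach `parallel_map`, while short inputs still run sequentially. A standalone restatement of the predicate, for clarity only (the helper name is hypothetical; 16 is just the `parallel_min_length` value used in the existing tests):

```python
def uses_sequential_path(num_proc: int, n_items: int, parallel_min_length: int) -> bool:
    # Mirrors the updated condition in map_nested.
    return (num_proc != -1 and num_proc <= 1) or n_items < parallel_min_length


assert uses_sequential_path(num_proc=1, n_items=100, parallel_min_length=16)       # single process
assert not uses_sequential_path(num_proc=-1, n_items=100, parallel_min_length=16)  # pool / joblib path
assert uses_sequential_path(num_proc=-1, n_items=4, parallel_min_length=16)        # input too short
```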
17 changes: 17 additions & 0 deletions tests/test_experimental.py
@@ -0,0 +1,17 @@
import unittest
import warnings

from datasets.utils import experimental


@experimental
def dummy_function():
return "success"


class TestExperimentalFlag(unittest.TestCase):
def test_experimental_warning(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.assertEqual(dummy_function(), "success")
self.assertEqual(len(w), 1)
51 changes: 51 additions & 0 deletions tests/test_parallel.py
@@ -0,0 +1,51 @@
import pytest

from datasets.parallel import ParallelBackendConfig, parallel_backend
from datasets.utils.py_utils import map_nested

from .utils import require_dill_gt_0_3_2, require_joblibspark, require_not_windows


def add_one(i): # picklable for multiprocessing
return i + 1


@require_dill_gt_0_3_2
@require_joblibspark
@require_not_windows
def test_parallel_backend_input():
with parallel_backend("spark"):
assert ParallelBackendConfig.backend_name == "spark"

lst = [1, 2, 3]
with pytest.raises(ValueError):
with parallel_backend("unsupported backend"):
map_nested(add_one, lst, num_proc=2)

with pytest.raises(ValueError):
with parallel_backend("unsupported backend"):
map_nested(add_one, lst, num_proc=-1)


@require_dill_gt_0_3_2
@require_joblibspark
@require_not_windows
@pytest.mark.parametrize("num_proc", [2, -1])
def test_parallel_backend_map_nested(num_proc):
s1 = [1, 2]
s2 = {"a": 1, "b": 2}
s3 = {"a": [1, 2], "b": [3, 4]}
s4 = {"a": {"1": 1}, "b": 2}
s5 = {"a": 1, "b": 2, "c": 3, "d": 4}
expected_map_nested_s1 = [2, 3]
expected_map_nested_s2 = {"a": 2, "b": 3}
expected_map_nested_s3 = {"a": [2, 3], "b": [4, 5]}
expected_map_nested_s4 = {"a": {"1": 2}, "b": 3}
expected_map_nested_s5 = {"a": 2, "b": 3, "c": 4, "d": 5}

with parallel_backend("spark"):
assert map_nested(add_one, s1, num_proc=num_proc) == expected_map_nested_s1
assert map_nested(add_one, s2, num_proc=num_proc) == expected_map_nested_s2
assert map_nested(add_one, s3, num_proc=num_proc) == expected_map_nested_s3
assert map_nested(add_one, s4, num_proc=num_proc) == expected_map_nested_s4
assert map_nested(add_one, s5, num_proc=num_proc) == expected_map_nested_s5
2 changes: 1 addition & 1 deletion tests/test_py_utils.py
@@ -127,7 +127,7 @@ class Foo:
)
def test_map_nested_num_proc(iterable_length, num_proc, expected_num_proc):
with patch("datasets.utils.py_utils._single_map_nested") as mock_single_map_nested, patch(
"datasets.utils.py_utils.Pool"
"datasets.parallel.parallel.Pool"
) as mock_multiprocessing_pool:
data_struct = {f"{i}": i for i in range(iterable_length)}
_ = map_nested(lambda x: x + 10, data_struct, num_proc=num_proc, parallel_min_length=16)
15 changes: 15 additions & 0 deletions tests/utils.py
@@ -264,6 +264,21 @@ def require_pyspark(test_case):
return test_case


def require_joblibspark(test_case):
"""
Decorator marking a test that requires joblibspark.

These tests are skipped when joblibspark isn't installed.

"""
try:
import joblibspark # noqa F401
except ImportError:
return unittest.skip("test requires joblibspark")(test_case)
else:
return test_case


def slow(test_case):
"""
Decorator marking a test as slow.