Add parallel module using joblib for Spark #5924

Merged: 13 commits, Jun 14, 2023
1 change: 1 addition & 0 deletions setup.py
@@ -160,6 +160,7 @@
TESTS_REQUIRE = [
    # test dependencies
    "absl-py",
    "joblibspark",
    "pytest",
    "pytest-datadir",
    "pytest-xdist",
1 change: 1 addition & 0 deletions src/datasets/parallel/__init__.py
@@ -0,0 +1 @@
from .parallel import parallel_backend, parallel_map, ParallelBackendConfig
103 changes: 103 additions & 0 deletions src/datasets/parallel/parallel.py
@@ -0,0 +1,103 @@
from multiprocessing import Pool, RLock
from tqdm.auto import tqdm
from typing import List
import contextlib
import joblib

from ..utils import logging

logger = logging.get_logger(__name__)


class ParallelBackendConfig:
    backend_name = None
    steps = []


def parallel_map(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func):
"""
Apply a function to iterable elements in parallel, where the implementation uses either multiprocessing.Pool or
es94129 marked this conversation as resolved.
Show resolved Hide resolved
es94129 marked this conversation as resolved.
Show resolved Hide resolved
joblib for parallelization.
"""
if ParallelBackendConfig.backend_name is None:
return _map_with_multiprocessing_pool(
function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func
)

return _map_with_joblib(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func)


def _map_with_multiprocessing_pool(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func):
    num_proc = num_proc if num_proc <= len(iterable) else len(iterable)
    split_kwds = []  # We organize the splits ourselves (contiguous splits)
    for index in range(num_proc):
        div = len(iterable) // num_proc
        mod = len(iterable) % num_proc
        start = div * index + min(index, mod)
        end = start + div + (1 if index < mod else 0)
        split_kwds.append((function, iterable[start:end], types, index, disable_tqdm, desc))

    if len(iterable) != sum(len(i[1]) for i in split_kwds):
        raise ValueError(
            f"Error dividing inputs iterable among processes. "
            f"Total number of objects {len(iterable)}, "
            f"length: {sum(len(i[1]) for i in split_kwds)}"
        )
Review comment on lines +43 to +56

@lhoestq (Member) commented on Jun 8, 2023:

Maybe this can still be in map_nested, so that the signature of parallel_map could be

def parallel_map(function, iterable, num_proc):

and map_nested would call

parallel_map(_single_map_nested, split_kwds, num_proc=num_proc)

This way it would be easier to start using parallel_map in other places in the code, no?

@es94129 (Contributor, Author) replied:

If so, _map_with_joblib would also take split_kwds as input, which is arbitrarily split according to num_proc rather than decided by joblib.
Are there other places where you are thinking of using parallel_map? I thought it was just a replacement for the multiprocessing part of map_nested.

@lhoestq (Member) replied:

n_jobs is passed to joblib anyway, no? Not a big deal either way.

@es94129 (Contributor, Author) replied:

I might leave it like this so that n_jobs=-1 can be used when the user wants to let joblib decide the number of workers / processes.

@lhoestq (Member) replied:

Ah, good idea!


    logger.info(
        f"Spawning {num_proc} processes for {len(iterable)} objects in slices of {[len(i[1]) for i in split_kwds]}"
    )
    initargs, initializer = None, None
    if not disable_tqdm:
        initargs, initializer = (RLock(),), tqdm.set_lock
    with Pool(num_proc, initargs=initargs, initializer=initializer) as pool:
        mapped = pool.map(single_map_nested_func, split_kwds)
    logger.info(f"Finished {num_proc} processes")
    mapped = [obj for proc_res in mapped for obj in proc_res]
    logger.info(f"Unpacked {len(mapped)} objects")

    return mapped


def _map_with_joblib(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func):
    # The progress bar is not yet supported for _map_with_joblib, because tqdm cannot be applied accurately to joblib,
    # and doing so would require monkey-patching joblib internal classes, which are subject to change.

    with joblib.parallel_backend(ParallelBackendConfig.backend_name, n_jobs=num_proc):
        return joblib.Parallel()(
            joblib.delayed(single_map_nested_func)((function, obj, types, None, True, None)) for obj in iterable
        )


@contextlib.contextmanager
def parallel_backend(backend_name: str, steps: List[str]):
    """
    Configures the parallel backend for parallelized dataset loading; supported steps include "download" and "prepare".

    Example usage:
    ```py
    with parallel_backend('spark', steps=["download"]):
        dataset = load_dataset(..., num_proc=2)
    ```
    """
    if "prepare" in steps:
        raise NotImplementedError(
            "The 'prepare' step that converts the raw data files to Arrow is not compatible "
            "with the parallel_backend context manager yet"
        )

    ParallelBackendConfig.backend_name = backend_name
    ParallelBackendConfig.steps = steps

    if backend_name == "spark":
        from joblibspark import register_spark

        register_spark()

    # TODO: call create_cache_and_write_probe if "download" in steps

    try:
        yield
    finally:
        ParallelBackendConfig.backend_name = None
        ParallelBackendConfig.steps = []
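For reviewers who want to try this out, here is a minimal end-to-end sketch of the new module (not part of the diff), mirroring the tests below. It assumes pyspark and joblibspark are installed and a Spark session is available; `add_one` is only an illustrative helper.

```py
from datasets.parallel import parallel_backend
from datasets.utils.py_utils import map_nested


def add_one(i):  # must be picklable so the joblib / multiprocessing workers can run it
    return i + 1


if __name__ == "__main__":
    # Route the parallelized part of map_nested through the Spark joblib backend.
    # num_proc=-1 is forwarded to joblib as n_jobs=-1, letting joblib/Spark decide
    # how many workers to use; a positive value pins the worker count instead.
    with parallel_backend("spark", steps=["download"]):
        print(map_nested(add_one, [1, 2, 3], num_proc=-1))  # [2, 3, 4]

    # Outside the context manager, parallel_map falls back to the multiprocessing.Pool
    # path (or runs sequentially when the input is shorter than parallel_min_length).
    print(map_nested(add_one, [1, 2, 3], num_proc=2))  # [2, 3, 4]
```

Keeping num_proc in the parallel_map signature (rather than pre-splitting in map_nested, as discussed in the review thread above) is what lets -1 be forwarded straight to joblib.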
33 changes: 4 additions & 29 deletions src/datasets/utils/py_utils.py
@@ -28,7 +28,7 @@
from contextlib import contextmanager
from dataclasses import fields, is_dataclass
from io import BytesIO as StringIO
from multiprocessing import Manager, Pool, RLock
from multiprocessing import Manager
from queue import Empty
from shutil import disk_usage
from types import CodeType, FunctionType
@@ -44,6 +44,7 @@

from .. import config
from . import logging
from ..parallel import parallel_map


try: # pragma: no branch
@@ -439,39 +440,13 @@ def map_nested(

    if num_proc is None:
        num_proc = 1
    if num_proc <= 1 or len(iterable) < parallel_min_length:
    if num_proc != -1 and num_proc <= 1 or len(iterable) < parallel_min_length:
        mapped = [
            _single_map_nested((function, obj, types, None, True, None))
            for obj in logging.tqdm(iterable, disable=disable_tqdm, desc=desc)
        ]
    else:
        num_proc = num_proc if num_proc <= len(iterable) else len(iterable)
        split_kwds = []  # We organize the splits ourselve (contiguous splits)
        for index in range(num_proc):
            div = len(iterable) // num_proc
            mod = len(iterable) % num_proc
            start = div * index + min(index, mod)
            end = start + div + (1 if index < mod else 0)
            split_kwds.append((function, iterable[start:end], types, index, disable_tqdm, desc))

        if len(iterable) != sum(len(i[1]) for i in split_kwds):
            raise ValueError(
                f"Error dividing inputs iterable among processes. "
                f"Total number of objects {len(iterable)}, "
                f"length: {sum(len(i[1]) for i in split_kwds)}"
            )

        logger.info(
            f"Spawning {num_proc} processes for {len(iterable)} objects in slices of {[len(i[1]) for i in split_kwds]}"
        )
        initargs, initializer = None, None
        if not disable_tqdm:
            initargs, initializer = (RLock(),), tqdm.set_lock
        with Pool(num_proc, initargs=initargs, initializer=initializer) as pool:
            mapped = pool.map(_single_map_nested, split_kwds)
        logger.info(f"Finished {num_proc} processes")
        mapped = [obj for proc_res in mapped for obj in proc_res]
        logger.info(f"Unpacked {len(mapped)} objects")
        mapped = parallel_map(function, iterable, num_proc, types, disable_tqdm, desc, _single_map_nested)

    if isinstance(data_struct, dict):
        return dict(zip(data_struct.keys(), mapped))
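As a side note on the block that moved out of map_nested, here is a small standalone sketch of the contiguous-split arithmetic now housed in _map_with_multiprocessing_pool; the helper name `contiguous_splits` and the example sizes are illustrative, not part of the PR.

```py
def contiguous_splits(iterable, num_proc):
    # Same div/mod scheme as the diff above: the first `mod` slices get one
    # extra element, so every item is assigned to exactly one slice.
    div, mod = divmod(len(iterable), num_proc)
    splits = []
    for index in range(num_proc):
        start = div * index + min(index, mod)
        end = start + div + (1 if index < mod else 0)
        splits.append(iterable[start:end])
    return splits


print(contiguous_splits(list(range(10)), 3))
# [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]] -- lengths 4 + 3 + 3 = 10
```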
51 changes: 51 additions & 0 deletions tests/test_parallel.py
@@ -0,0 +1,51 @@
import pytest

from datasets.parallel import parallel_backend, ParallelBackendConfig
from datasets.utils.py_utils import map_nested

from .utils import require_joblibspark


def add_one(i):  # picklable for multiprocessing
    return i + 1


@require_joblibspark
def test_parallel_backend_input():
    with parallel_backend("spark", steps=["downloading"]):
        assert ParallelBackendConfig.backend_name == "spark"

    with pytest.raises(NotImplementedError):
        with parallel_backend("spark", steps=["downloading", "prepare"]):
            pass

    lst = [1, 2, 3]
    with pytest.raises(ValueError):
        with parallel_backend("unsupported backend", steps=["downloading"]):
            map_nested(add_one, lst, num_proc=2)

    with pytest.raises(ValueError):
        with parallel_backend("unsupported backend", steps=["downloading"]):
            map_nested(add_one, lst, num_proc=-1)


@require_joblibspark
@pytest.mark.parametrize("num_proc", [2, -1])
def test_parallel_backend_map_nested(num_proc):
    s1 = [1, 2]
    s2 = {"a": 1, "b": 2}
    s3 = {"a": [1, 2], "b": [3, 4]}
    s4 = {"a": {"1": 1}, "b": 2}
    s5 = {"a": 1, "b": 2, "c": 3, "d": 4}
    expected_map_nested_s1 = [2, 3]
    expected_map_nested_s2 = {"a": 2, "b": 3}
    expected_map_nested_s3 = {"a": [2, 3], "b": [4, 5]}
    expected_map_nested_s4 = {"a": {"1": 2}, "b": 3}
    expected_map_nested_s5 = {"a": 2, "b": 3, "c": 4, "d": 5}

    with parallel_backend("spark", steps=["downloading"]):
        assert map_nested(add_one, s1, num_proc=num_proc) == expected_map_nested_s1
        assert map_nested(add_one, s2, num_proc=num_proc) == expected_map_nested_s2
        assert map_nested(add_one, s3, num_proc=num_proc) == expected_map_nested_s3
        assert map_nested(add_one, s4, num_proc=num_proc) == expected_map_nested_s4
        assert map_nested(add_one, s5, num_proc=num_proc) == expected_map_nested_s5
2 changes: 1 addition & 1 deletion tests/test_py_utils.py
@@ -127,7 +127,7 @@ class Foo:
)
def test_map_nested_num_proc(iterable_length, num_proc, expected_num_proc):
    with patch("datasets.utils.py_utils._single_map_nested") as mock_single_map_nested, patch(
        "datasets.utils.py_utils.Pool"
        "datasets.parallel.parallel.Pool"
    ) as mock_multiprocessing_pool:
        data_struct = {f"{i}": i for i in range(iterable_length)}
        _ = map_nested(lambda x: x + 10, data_struct, num_proc=num_proc, parallel_min_length=16)
15 changes: 15 additions & 0 deletions tests/utils.py
@@ -264,6 +264,21 @@ def require_pyspark(test_case):
        return test_case


def require_joblibspark(test_case):
    """
    Decorator marking a test that requires joblibspark.

    These tests are skipped when joblibspark isn't installed.

    """
    try:
        import joblibspark  # noqa F401
    except ImportError:
        return unittest.skip("test requires joblibspark")(test_case)
    else:
        return test_case


def slow(test_case):
"""
Decorator marking a test as slow.