Add parallel module using joblib for Spark #5924

Merged
merged 13 commits on Jun 14, 2023
Add comments
es94129 committed Jun 6, 2023
commit 4b657cbdf2d17463bcba3ea21104bc9db9096564
21 changes: 18 additions & 3 deletions src/datasets/parallel/parallel.py
@@ -15,12 +15,16 @@ class ParallelBackendConfig:
 
 
 def parallel_map(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func):
+    """
+    Apply a function to iterable elements in parallel, where the implementation uses either multiprocessing.Pool or
+    joblib for parallelization.
+    """
     if ParallelBackendConfig.backend_name is None:
         return _map_with_multiprocessing_pool(
             function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func
         )
 
-    return _map_with_joblib(function, iterable, num_proc, types, single_map_nested_func)
+    return _map_with_joblib(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func)
 
 
 def _map_with_multiprocessing_pool(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func):
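
For context, here is a minimal, self-contained sketch of the dispatch pattern this hunk implements: a module-level `ParallelBackendConfig.backend_name` flag routes work either to `multiprocessing.Pool` or to whatever joblib backend is active. The simplified signature and the toy `square` function are illustrative, not the library's actual API.

```py
import multiprocessing

import joblib


class ParallelBackendConfig:
    backend_name = None  # set by the parallel_backend() context manager


def square(x):  # toy payload, stands in for single_map_nested_func
    return x * x


def parallel_map_sketch(function, iterable, num_proc):
    # default path: a plain multiprocessing pool
    if ParallelBackendConfig.backend_name is None:
        with multiprocessing.Pool(num_proc) as pool:
            return pool.map(function, iterable)
    # joblib path: defer to whichever backend is registered under this name
    with joblib.parallel_backend(ParallelBackendConfig.backend_name, n_jobs=num_proc):
        return joblib.Parallel()(joblib.delayed(function)(x) for x in iterable)


if __name__ == "__main__":
    print(parallel_map_sketch(square, range(4), num_proc=2))  # [0, 1, 4, 9]
```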
@@ -55,8 +59,10 @@ def _map_with_multiprocessing_pool(function, iterable, num_proc, types, disable_
     return mapped
 
 
-def _map_with_joblib(function, iterable, num_proc, types, single_map_nested_func):
-    # TODO: take num_proc, tqdm args
+def _map_with_joblib(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func):
+    # progress bar is not yet supported for _map_with_joblib, because tqdm couldn't accurately be applied to joblib,
+    # and it requires monkey-patching joblib internal classes which is subject to change
+
     with joblib.parallel_backend(ParallelBackendConfig.backend_name, n_jobs=num_proc):
         return joblib.Parallel()(
             joblib.delayed(single_map_nested_func)((function, obj, types, None, True, None)) for obj in iterable
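
The joblib idiom in `_map_with_joblib` can be exercised on its own; the sketch below runs it against joblib's default `loky` backend rather than Spark, and replaces the `single_map_nested_func` tuple argument with a plain function for readability. Note that `n_jobs` is set on the `parallel_backend` context, so the inner `joblib.Parallel()` call inherits it.

```py
import joblib


def process(obj):  # stand-in for single_map_nested_func
    return obj * 10


# n_jobs configured on the backend context is inherited by Parallel()
with joblib.parallel_backend("loky", n_jobs=2):
    results = joblib.Parallel()(joblib.delayed(process)(obj) for obj in range(5))

print(results)  # [0, 10, 20, 30, 40]
```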
@@ -65,6 +71,15 @@ def _map_with_joblib(function, iterable, num_proc, types, single_map_nested_func
 
 
 @contextlib.contextmanager
 def parallel_backend(backend_name: str, steps: List[str]):
+    """
+    Configures the parallel backend for parallelized dataset loading, steps including download and prepare.
+
+    Example usage:
+    ```py
+    with parallel_backend('spark', steps=["download"]):
+        dataset = load_dataset(..., num_proc=2)
+    ```
+    """
     if "prepare" in steps:
         raise NotImplementedError(
             "The 'prepare' step that converts the raw data files to Arrow is not compatible "