Multithreaded downloads #6794

Merged (13 commits, Apr 15, 2024). Changes shown from all commits.
src/datasets/config.py (3 additions, 0 deletions)

@@ -182,6 +182,9 @@
     os.environ.get("HF_UPDATE_DOWNLOAD_COUNTS", "AUTO").upper() in ENV_VARS_TRUE_AND_AUTO_VALUES
 )
 
+# For downloads and to check remote files metadata
+HF_DATASETS_MULTITHREADING_MAX_WORKERS = 16
+
 # Remote dataset scripts support
 __HF_DATASETS_TRUST_REMOTE_CODE = os.environ.get("HF_DATASETS_TRUST_REMOTE_CODE", "1")
 HF_DATASETS_TRUST_REMOTE_CODE: Optional[bool] = (
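
The new constant caps the thread pool used both for downloading files and for checking remote file metadata. It is a plain module constant in this PR rather than an environment variable, so a caller who wants a different ceiling would patch it at runtime; a minimal sketch of that, assuming nothing beyond the import path shown in the diff:

```python
# Hedged sketch: raise the download thread ceiling before loading a dataset.
# The constant is module-level state (not env-configurable in this PR),
# so reassigning it is the only override hook.
import datasets.config

datasets.config.HF_DATASETS_MULTITHREADING_MAX_WORKERS = 32
```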

src/datasets/data_files.py (2 additions, 1 deletion)

@@ -544,9 +544,10 @@ def _get_single_origin_metadata(
 
 def _get_origin_metadata(
     data_files: List[str],
-    max_workers=64,
     download_config: Optional[DownloadConfig] = None,
+    max_workers: Optional[int] = None,
 ) -> Tuple[str]:
+    max_workers = max_workers if max_workers is not None else config.HF_DATASETS_MULTITHREADING_MAX_WORKERS
     return thread_map(
         partial(_get_single_origin_metadata, download_config=download_config),
         data_files,
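
`_get_origin_metadata` previously hard-coded 64 workers; it now defaults to `None`, meaning "fall back to the shared config constant", while still honoring an explicit per-call value. A minimal sketch of the same fallback pattern, with hypothetical names:

```python
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional

DEFAULT_MAX_WORKERS = 16  # stand-in for config.HF_DATASETS_MULTITHREADING_MAX_WORKERS

def fetch_metadata(urls: List[str], max_workers: Optional[int] = None) -> List[int]:
    # None defers to the library-wide setting; an explicit int wins per call.
    max_workers = max_workers if max_workers is not None else DEFAULT_MAX_WORKERS
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(len, urls))  # trivial per-item work, for the sketch
```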

src/datasets/download/download_config.py (3 additions, 0 deletions)

@@ -59,6 +59,8 @@ class DownloadConfig:
             Key/value pairs to be passed on to the dataset file-system backend, if any.
         download_desc (`str`, *optional*):
             A description to be displayed alongside with the progress bar while downloading the files.
+        disable_tqdm (`bool`, defaults to `False`):
+            Whether to disable the individual files download progress bar
     """

@@ -78,6 +80,7 @@ class DownloadConfig:
     ignore_url_params: bool = False
     storage_options: Dict[str, Any] = field(default_factory=dict)
     download_desc: Optional[str] = None
+    disable_tqdm: bool = False
 
     def __post_init__(self, use_auth_token):
         if use_auth_token != "deprecated":
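
The flag silences only the per-file progress bars; the aggregate "Downloading data files" bar in the download manager still renders. A minimal usage sketch, assuming the public `DownloadConfig` API (the URL is illustrative):

```python
from datasets import DownloadConfig, load_dataset

# Hedged sketch: drop the per-file bars (useful when a batch of many small
# files would otherwise spawn one tqdm bar each); the overall bar remains.
config = DownloadConfig(disable_tqdm=True)
dataset = load_dataset(
    "csv",
    data_files="https://example.com/data.csv",  # hypothetical URL
    download_config=config,
)
```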

src/datasets/download/download_manager.py (50 additions, 4 deletions)

@@ -17,6 +17,7 @@
 
 import enum
 import io
+import multiprocessing
 import os
 import posixpath
 import tarfile

@@ -27,6 +28,10 @@
 from itertools import chain
 from typing import Callable, Dict, Generator, List, Optional, Tuple, Union
 
+import fsspec
+from fsspec.core import url_to_fs
+from tqdm.contrib.concurrent import thread_map
+
 from .. import config
 from ..utils import tqdm as hf_tqdm
 from ..utils.deprecation_utils import DeprecatedEnum, deprecated

@@ -39,7 +44,7 @@
     url_or_path_join,
 )
 from ..utils.info_utils import get_size_checksum_dict
-from ..utils.logging import get_logger
+from ..utils.logging import get_logger, tqdm
 from ..utils.py_utils import NestedDataStructure, map_nested, size_str
 from ..utils.track import TrackedIterable, tracked_str
 from .download_config import DownloadConfig

@@ -427,7 +432,7 @@ def download(self, url_or_urls):
         if download_config.download_desc is None:
             download_config.download_desc = "Downloading data"
 
-        download_func = partial(self._download, download_config=download_config)
+        download_func = partial(self._download_batched, download_config=download_config)
 
         start_time = datetime.now()
         with stack_multiprocessing_download_progress_bars():

@@ -437,6 +442,8 @@ def download(self, url_or_urls):
                 map_tuple=True,
                 num_proc=download_config.num_proc,
                 desc="Downloading data files",
+                batched=True,
+                batch_size=-1,
             )
         duration = datetime.now() - start_time
         logger.info(f"Downloading took {duration.total_seconds() // 60} min")
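
`download()` now hands `map_nested` a batched function: with `batched=True` and `batch_size=-1`, each worker process receives its entire contiguous share of URLs in a single call instead of one URL at a time (reading `-1` as "one batch per worker" is an inference from this diff, not documented behavior). A hedged sketch of that dispatch shape:

```python
from concurrent.futures import ProcessPoolExecutor
from typing import List

def download_batch(batch: List[str]) -> List[str]:
    # Stand-in for DownloadManager._download_batched: one call per batch.
    return [url.rsplit("/", 1)[-1] for url in batch]

if __name__ == "__main__":
    urls = [f"https://example.com/file{i}.bin" for i in range(100)]
    num_proc = 4
    # One contiguous batch per process, mirroring the assumed batch_size=-1.
    step = len(urls) // num_proc
    batches = [urls[i * step:(i + 1) * step] for i in range(num_proc)]
    with ProcessPoolExecutor(max_workers=num_proc) as pool:
        names = [n for result in pool.map(download_batch, batches) for n in result]
```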

@@ -451,7 +458,46 @@ def download(self, url_or_urls):
 
         return downloaded_path_or_paths.data
 
-    def _download(self, url_or_filename: str, download_config: DownloadConfig) -> str:
+    def _download_batched(
+        self,
+        url_or_filenames: List[str],
+        download_config: DownloadConfig,
+    ) -> List[str]:
+        if len(url_or_filenames) >= 16:
+            download_config = download_config.copy()
+            download_config.disable_tqdm = True
+            download_func = partial(self._download_single, download_config=download_config)
+
+            fs: fsspec.AbstractFileSystem
+            fs, path = url_to_fs(url_or_filenames[0], **download_config.storage_options)
+            size = 0
+            try:
+                size = fs.info(path).get("size", 0)
+            except Exception:
+                pass
+            max_workers = (
+                config.HF_DATASETS_MULTITHREADING_MAX_WORKERS if size < (20 << 20) else 1
+            )  # enable multithreading if files are small
+
+            return thread_map(
+                download_func,
+                url_or_filenames,
+                desc=download_config.download_desc or "Downloading",
+                unit="files",
+                position=multiprocessing.current_process()._identity[-1]  # contains the ranks of subprocesses
+                if os.environ.get("HF_DATASETS_STACK_MULTIPROCESSING_DOWNLOAD_PROGRESS_BARS") == "1"
+                and multiprocessing.current_process()._identity
+                else None,
+                max_workers=max_workers,
+                tqdm_class=tqdm,
+            )
+        else:
+            return [
+                self._download_single(url_or_filename, download_config=download_config)
+                for url_or_filename in url_or_filenames
+            ]
+
+    def _download_single(self, url_or_filename: str, download_config: DownloadConfig) -> str:
         url_or_filename = str(url_or_filename)
         if is_relative_path(url_or_filename):
             # append the relative path to the base_path

@@ -539,7 +585,7 @@ def extract(self, path_or_paths, num_proc="deprecated"):
         )
         download_config = self.download_config.copy()
         download_config.extract_compressed_file = True
-        extract_func = partial(self._download, download_config=download_config)
+        extract_func = partial(self._download_single, download_config=download_config)
         extracted_paths = map_nested(
             extract_func,
             path_or_paths,
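
Two details of `_download_batched` carry the performance reasoning. Batches smaller than 16 URLs fall through to a plain sequential loop, where a thread pool would not pay for its overhead. For larger batches, only the first file's size is probed, on the assumption that files in a batch are roughly homogeneous: below `20 << 20` bytes (20 MiB) the batch gets the full `HF_DATASETS_MULTITHREADING_MAX_WORKERS` threads, otherwise a single worker, so large files do not compete for bandwidth. A standalone sketch of that heuristic, with hypothetical names:

```python
from typing import List

from fsspec.core import url_to_fs

SMALL_FILE_CUTOFF = 20 << 20  # 20 MiB, the PR's threshold
MAX_WORKERS = 16              # stand-in for HF_DATASETS_MULTITHREADING_MAX_WORKERS

def choose_max_workers(urls: List[str]) -> int:
    # Probe only urls[0]: one metadata call decides for the whole batch,
    # assuming the files in a batch are similarly sized.
    fs, path = url_to_fs(urls[0])
    try:
        size = fs.info(path).get("size", 0) or 0
    except Exception:
        size = 0  # unknown size is treated as small, as in the PR
    return MAX_WORKERS if size < SMALL_FILE_CUTOFF else 1
```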

src/datasets/download/streaming_download_manager.py (2 additions, 2 deletions)

@@ -1002,10 +1002,10 @@ def download(self, url_or_urls):
         >>> downloaded_files = dl_manager.download('https://storage.googleapis.com/seldon-datasets/sentence_polarity_v1/rt-polaritydata.tar.gz')
         ```
         """
-        url_or_urls = map_nested(self._download, url_or_urls, map_tuple=True)
+        url_or_urls = map_nested(self._download_single, url_or_urls, map_tuple=True)
         return url_or_urls
 
-    def _download(self, urlpath: str) -> str:
+    def _download_single(self, urlpath: str) -> str:
         urlpath = str(urlpath)
         if is_relative_path(urlpath):
             # append the relative path to the base_path

src/datasets/parallel/parallel.py (14 additions, 7 deletions)

@@ -14,7 +14,7 @@ class ParallelBackendConfig:
 
 
 @experimental
-def parallel_map(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func):
+def parallel_map(function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func):
     """
     **Experimental.** Apply a function to iterable elements in parallel, where the implementation uses either
     multiprocessing.Pool or joblib for parallelization.

@@ -32,21 +32,25 @@ def parallel_map(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func):
     """
     if ParallelBackendConfig.backend_name is None:
         return _map_with_multiprocessing_pool(
-            function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func
+            function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func
         )
 
-    return _map_with_joblib(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func)
+    return _map_with_joblib(
+        function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func
+    )
 
 
-def _map_with_multiprocessing_pool(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func):
+def _map_with_multiprocessing_pool(
+    function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func
+):
     num_proc = num_proc if num_proc <= len(iterable) else len(iterable)
     split_kwds = []  # We organize the splits ourselve (contiguous splits)
     for index in range(num_proc):
         div = len(iterable) // num_proc
         mod = len(iterable) % num_proc
         start = div * index + min(index, mod)
         end = start + div + (1 if index < mod else 0)
-        split_kwds.append((function, iterable[start:end], types, index, disable_tqdm, desc))
+        split_kwds.append((function, iterable[start:end], batched, batch_size, types, index, disable_tqdm, desc))
 
     if len(iterable) != sum(len(i[1]) for i in split_kwds):
         raise ValueError(

@@ -70,14 +74,17 @@ def _map_with_multiprocessing_pool(
     return mapped
 
 
-def _map_with_joblib(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func):
+def _map_with_joblib(
+    function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func
+):
     # progress bar is not yet supported for _map_with_joblib, because tqdm couldn't accurately be applied to joblib,
     # and it requires monkey-patching joblib internal classes which is subject to change
     import joblib
 
     with joblib.parallel_backend(ParallelBackendConfig.backend_name, n_jobs=num_proc):
         return joblib.Parallel()(
-            joblib.delayed(single_map_nested_func)((function, obj, types, None, True, None)) for obj in iterable
+            joblib.delayed(single_map_nested_func)((function, obj, batched, batch_size, types, None, True, None))
+            for obj in iterable
         )
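
The `batched`/`batch_size` flags are threaded through both backends so the worker payload tuples stay in sync with what `single_map_nested_func` unpacks. The contiguous-split arithmetic itself is unchanged: `len(iterable)` items go to `num_proc` workers, with the first `len(iterable) % num_proc` workers taking one extra item. A self-contained check of that formula:

```python
from typing import List

def contiguous_splits(n_items: int, num_proc: int) -> List[range]:
    # Same div/mod arithmetic as _map_with_multiprocessing_pool: contiguous
    # slices, with the first (n_items % num_proc) workers one item larger.
    div, mod = divmod(n_items, num_proc)
    splits = []
    for index in range(num_proc):
        start = div * index + min(index, mod)
        end = start + div + (1 if index < mod else 0)
        splits.append(range(start, end))
    return splits

# 10 items over 3 workers -> sizes 4, 3, 3, covering every index exactly once
assert [len(r) for r in contiguous_splits(10, 3)] == [4, 3, 3]
```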

src/datasets/utils/file_utils.py (19 additions, 3 deletions)

@@ -201,6 +201,7 @@ def cached_path(
             ignore_url_params=download_config.ignore_url_params,
             storage_options=download_config.storage_options,
             download_desc=download_config.download_desc,
+            disable_tqdm=download_config.disable_tqdm,
         )
     elif os.path.exists(url_or_filename):
         # File, and it exists.

@@ -335,7 +336,7 @@ def __init__(self, tqdm_kwargs=None, *args, **kwargs):
         super().__init__(tqdm_kwargs, *args, **kwargs)
 
 
-def fsspec_get(url, temp_file, storage_options=None, desc=None):
+def fsspec_get(url, temp_file, storage_options=None, desc=None, disable_tqdm=False):
     _raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
     fs, path = url_to_fs(url, **(storage_options or {}))
     callback = TqdmCallback(

@@ -347,6 +348,7 @@ def fsspec_get(url, temp_file, storage_options=None, desc=None, disable_tqdm=False):
             if os.environ.get("HF_DATASETS_STACK_MULTIPROCESSING_DOWNLOAD_PROGRESS_BARS") == "1"
             and multiprocessing.current_process()._identity
             else None,
+            "disable": disable_tqdm,
         }
     )
     fs.get_file(path, temp_file.name, callback=callback)

@@ -373,7 +375,16 @@ def ftp_get(url, temp_file, timeout=10.0):
 
 
 def http_get(
-    url, temp_file, proxies=None, resume_size=0, headers=None, cookies=None, timeout=100.0, max_retries=0, desc=None
+    url,
+    temp_file,
+    proxies=None,
+    resume_size=0,
+    headers=None,
+    cookies=None,
+    timeout=100.0,
+    max_retries=0,
+    desc=None,
+    disable_tqdm=False,
 ) -> Optional[requests.Response]:
     headers = dict(headers) if headers is not None else {}
     headers["user-agent"] = get_datasets_user_agent(user_agent=headers.get("user-agent"))

@@ -405,6 +416,7 @@ def http_get(
         if os.environ.get("HF_DATASETS_STACK_MULTIPROCESSING_DOWNLOAD_PROGRESS_BARS") == "1"
         and multiprocessing.current_process()._identity
         else None,
+        disable=disable_tqdm,
     ) as progress:
         for chunk in response.iter_content(chunk_size=1024):
             progress.update(len(chunk))

@@ -464,6 +476,7 @@ def get_from_cache(
     ignore_url_params=False,
     storage_options=None,
     download_desc=None,
+    disable_tqdm=False,
 ) -> str:
     """
     Given a URL, look for the corresponding file in the local cache.

@@ -629,7 +642,9 @@ def temp_file_manager(mode="w+b"):
             if scheme == "ftp":
                 ftp_get(url, temp_file)
             elif scheme not in ("http", "https"):
-                fsspec_get(url, temp_file, storage_options=storage_options, desc=download_desc)
+                fsspec_get(
+                    url, temp_file, storage_options=storage_options, desc=download_desc, disable_tqdm=disable_tqdm
+                )
             else:
                 http_get(
                     url,

@@ -640,6 +655,7 @@ def temp_file_manager(mode="w+b"):
                     cookies=cookies,
                     max_retries=max_retries,
                     desc=download_desc,
+                    disable_tqdm=disable_tqdm,
                 )
 
         logger.info(f"storing {url} in cache at {cache_path}")
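
The flag travels the whole call chain: `DownloadConfig.disable_tqdm` → `cached_path` → `get_from_cache` → `fsspec_get`/`http_get`, where it finally lands on tqdm's standard `disable` argument, so the transfer loops themselves never branch on it. A minimal sketch of that last hop, assuming only `requests` and `tqdm`:

```python
import requests
from tqdm import tqdm

def http_get_sketch(url: str, dest: str, disable_tqdm: bool = False) -> None:
    # Stream the body; disable_tqdm maps straight onto tqdm's `disable`
    # flag, which silences rendering without changing the download logic.
    response = requests.get(url, stream=True, timeout=100.0)
    total = int(response.headers.get("content-length") or 0)
    with open(dest, "wb") as f, tqdm(
        total=total, unit="B", unit_scale=True, disable=disable_tqdm
    ) as progress:
        for chunk in response.iter_content(chunk_size=1024):
            progress.update(len(chunk))
            f.write(chunk)
```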