Better tqdm wrapper #6433

Merged
12 commits merged on Nov 22, 2023
1 change: 1 addition & 0 deletions docs/source/_redirects.yml
@@ -10,4 +10,5 @@ faiss_and_ea: faiss_es
features: about_dataset_features
using_metrics: how_to_metrics
exploring: access
package_reference/logging_methods: package_reference/utilities
# end of first_section
4 changes: 2 additions & 2 deletions docs/source/_toctree.yml
@@ -121,8 +121,8 @@
title: Loading methods
- local: package_reference/table_classes
title: Table Classes
- local: package_reference/logging_methods
title: Logging methods
- local: package_reference/utilities
title: Utilities
- local: package_reference/task_templates
title: Task templates
title: "Reference"
@@ -1,4 +1,6 @@
# Logging methods
# Utilities

## Configure logging

🤗 Datasets strives to be transparent and explicit about how it works, but this can be quite verbose at times. We have included a series of logging methods which allow you to easily adjust the level of verbosity of the entire library. Currently the default verbosity of the library is set to `WARNING`.

@@ -28,10 +30,6 @@ In order from the least to the most verbose (with their corresponding `int` valu
4. `logging.INFO` (int value, 20): reports errors, warnings, and basic information.
5. `logging.DEBUG` (int value, 10): reports all information.
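
As an illustration (not part of this diff), the verbosity can be adjusted from Python with the `datasets.logging` helpers documented below:

```python
import datasets

# Show only errors and critical messages
datasets.logging.set_verbosity_error()

# Equivalently, pass one of the integer levels listed above
datasets.logging.set_verbosity(datasets.logging.ERROR)

# Restore the library default
datasets.logging.set_verbosity_warning()
```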

By default, `tqdm` progress bars will be displayed during dataset download and preprocessing. [`logging.disable_progress_bar`] and [`logging.enable_progress_bar`] can be used to suppress or unsuppress this behavior.

## Functions

[[autodoc]] datasets.logging.get_verbosity

[[autodoc]] datasets.logging.set_verbosity
@@ -48,44 +46,13 @@ By default, `tqdm` progress bars will be displayed during dataset download and p

[[autodoc]] datasets.logging.enable_propagation

[[autodoc]] datasets.logging.get_logger

[[autodoc]] datasets.logging.enable_progress_bar

[[autodoc]] datasets.logging.disable_progress_bar

[[autodoc]] datasets.is_progress_bar_enabled

## Levels

### datasets.logging.CRITICAL

datasets.logging.CRITICAL = 50

### datasets.logging.DEBUG

datasets.logging.DEBUG = 10

### datasets.logging.ERROR

datasets.logging.ERROR = 40

### datasets.logging.FATAL

datasets.logging.FATAL = 50

### datasets.logging.INFO

datasets.logging.INFO = 20

### datasets.logging.NOTSET

datasets.logging.NOTSET = 0
### datasets.logging.WARN

datasets.logging.WARN = 30

### datasets.logging.WARNING

datasets.logging.WARNING = 30

## Configure progress bars

By default, `tqdm` progress bars will be displayed during dataset download and preprocessing. You can disable them globally by setting the `HF_DATASETS_DISABLE_PROGRESS_BARS` environment variable. You can also enable/disable them with [`~utils.enable_progress_bars`] and [`~utils.disable_progress_bars`]. If set, the environment variable takes priority over the helpers.

[[autodoc]] datasets.utils.enable_progress_bars

[[autodoc]] datasets.utils.disable_progress_bars

[[autodoc]] datasets.utils.are_progress_bars_disabled
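
As an illustration (not part of this diff), the helpers introduced by this PR would be used roughly like this:

```python
from datasets.utils import (
    are_progress_bars_disabled,
    disable_progress_bars,
    enable_progress_bars,
)

# Turn off all 🤗 Datasets progress bars for this process
disable_progress_bars()
assert are_progress_bars_disabled()

# Turn them back on later
enable_progress_bars()
```

Setting `HF_DATASETS_DISABLE_PROGRESS_BARS=1` (or `=0`) in the environment overrides both calls, as described above.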
13 changes: 5 additions & 8 deletions src/datasets/arrow_dataset.py
@@ -111,6 +111,7 @@
)
from .tasks import TaskTemplate
from .utils import logging
from .utils import tqdm as hf_tqdm
from .utils.deprecation_utils import deprecated
from .utils.file_utils import _retry, estimate_dataset_size
from .utils.info_utils import is_small_dataset
@@ -1494,8 +1495,7 @@ def save_to_disk(
dataset_info = asdict(self._info)

shards_done = 0
pbar = logging.tqdm(
disable=not logging.is_progress_bar_enabled(),
pbar = hf_tqdm(
unit=" examples",
total=len(self),
desc=f"Saving the dataset ({shards_done}/{num_shards} shards)",
@@ -3080,8 +3080,7 @@ def load_processed_shard_from_cache(shard_kwargs):
except NonExistentDatasetError:
pass
if transformed_dataset is None:
with logging.tqdm(
disable=not logging.is_progress_bar_enabled(),
with hf_tqdm(
unit=" examples",
total=pbar_total,
desc=desc or "Map",
@@ -3173,8 +3172,7 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
with Pool(len(kwargs_per_job)) as pool:
os.environ = prev_env
logger.info(f"Spawning {num_proc} processes")
with logging.tqdm(
disable=not logging.is_progress_bar_enabled(),
with hf_tqdm(
unit=" examples",
total=pbar_total,
desc=(desc or "Map") + f" (num_proc={num_proc})",
@@ -5195,11 +5193,10 @@ def shards_with_embedded_external_files(shards):

uploaded_size = 0
additions = []
for index, shard in logging.tqdm(
for index, shard in hf_tqdm(
enumerate(shards),
desc="Uploading the dataset shards",
total=num_shards,
disable=not logging.is_progress_bar_enabled(),
):
shard_path_in_repo = f"{data_dir}/{split}-{index:05d}-of-{num_shards:05d}.parquet"
buffer = BytesIO()
10 changes: 4 additions & 6 deletions src/datasets/arrow_writer.py
@@ -41,6 +41,7 @@
from .keyhash import DuplicatedKeysError, KeyHasher
from .table import array_cast, array_concat, cast_array_to_feature, embed_table_storage, table_cast
from .utils import logging
from .utils import tqdm as hf_tqdm
from .utils.file_utils import hash_url_to_filename
from .utils.py_utils import asdict, first_non_null_value

@@ -689,9 +690,8 @@ def finalize(self, metrics_query_result: dict):
for metadata in beam.io.filesystems.FileSystems.match([parquet_path + "*.parquet"])[0].metadata_list
]
try: # stream conversion
disable = not logging.is_progress_bar_enabled()
num_bytes = 0
for shard in logging.tqdm(shards, unit="shards", disable=disable):
for shard in hf_tqdm(shards, unit="shards"):
with beam.io.filesystems.FileSystems.open(shard) as source:
with beam.io.filesystems.FileSystems.create(
shard.replace(".parquet", ".arrow")
@@ -706,9 +706,8 @@ def finalize(self, metrics_query_result: dict):
)
local_convert_dir = os.path.join(self._cache_dir, "beam_convert")
os.makedirs(local_convert_dir, exist_ok=True)
disable = not logging.is_progress_bar_enabled()
num_bytes = 0
for shard in logging.tqdm(shards, unit="shards", disable=disable):
for shard in hf_tqdm(shards, unit="shards"):
local_parquet_path = os.path.join(local_convert_dir, hash_url_to_filename(shard) + ".parquet")
beam_utils.download_remote_to_local(shard, local_parquet_path)
local_arrow_path = local_parquet_path.replace(".parquet", ".arrow")
@@ -727,8 +726,7 @@ def finalize(self, metrics_query_result: dict):

def get_parquet_lengths(sources) -> List[int]:
shard_lengths = []
disable = not logging.is_progress_bar_enabled()
for source in logging.tqdm(sources, unit="parquet files", disable=disable):
for source in hf_tqdm(sources, unit="parquet files"):
parquet_file = pa.parquet.ParquetFile(source)
shard_lengths.append(parquet_file.metadata.num_rows)
return shard_lengths
7 changes: 3 additions & 4 deletions src/datasets/builder.py
@@ -65,6 +65,7 @@
from .splits import Split, SplitDict, SplitGenerator, SplitInfo
from .streaming import extend_dataset_builder_for_streaming
from .utils import logging
from .utils import tqdm as hf_tqdm
from .utils.file_utils import cached_path, is_remote_url
from .utils.filelock import FileLock
from .utils.info_utils import VerificationMode, get_size_checksum_dict, verify_checksums, verify_splits
@@ -1526,8 +1527,7 @@ def _prepare_split(
)
num_proc = num_input_shards

pbar = logging.tqdm(
disable=not logging.is_progress_bar_enabled(),
pbar = hf_tqdm(
unit=" examples",
total=split_info.num_examples,
desc=f"Generating {split_info.name} split",
@@ -1784,8 +1784,7 @@ def _prepare_split(
)
num_proc = num_input_shards

pbar = logging.tqdm(
disable=not logging.is_progress_bar_enabled(),
pbar = hf_tqdm(
unit=" examples",
total=split_info.num_examples,
desc=f"Generating {split_info.name} split",
18 changes: 15 additions & 3 deletions src/datasets/config.py
@@ -1,15 +1,15 @@
import importlib
import importlib.metadata
import logging
import os
import platform
from pathlib import Path
from typing import Optional

from packaging import version

from .utils.logging import get_logger


logger = get_logger(__name__)
logger = logging.getLogger(__name__.split(".", 1)[0]) # to avoid circular import from .utils.logging

# Datasets
S3_DATASETS_BUCKET_PREFIX = "https://s3.amazonaws.com/datasets.huggingface.co/datasets/datasets"
@@ -192,6 +192,18 @@
# Offline mode
HF_DATASETS_OFFLINE = os.environ.get("HF_DATASETS_OFFLINE", "AUTO").upper() in ENV_VARS_TRUE_VALUES

# Here, `True` will disable progress bars globally without possibility of enabling it
# programmatically. `False` will enable them without possibility of disabling them.
# If environment variable is not set (None), then the user is free to enable/disable
# them programmatically.
# TL;DR: env variable has priority over code
__HF_DATASETS_DISABLE_PROGRESS_BARS = os.environ.get("HF_DATASETS_DISABLE_PROGRESS_BARS")
HF_DATASETS_DISABLE_PROGRESS_BARS: Optional[bool] = (
__HF_DATASETS_DISABLE_PROGRESS_BARS.upper() in ENV_VARS_TRUE_VALUES
if __HF_DATASETS_DISABLE_PROGRESS_BARS is not None
else None
)

# In-memory
DEFAULT_IN_MEMORY_MAX_SIZE = 0 # Disabled
IN_MEMORY_MAX_SIZE = float(os.environ.get("HF_DATASETS_IN_MEMORY_MAX_SIZE", DEFAULT_IN_MEMORY_MAX_SIZE))
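
A rough sketch (hypothetical, not the code in this PR) of how the tri-state value above can gate the enable/disable helpers, with the environment variable always taking priority over code:

```python
from typing import Optional

# Resolved as in config.py above: True, False, or None when the env var is unset
HF_DATASETS_DISABLE_PROGRESS_BARS: Optional[bool] = None

_progress_bars_disabled = bool(HF_DATASETS_DISABLE_PROGRESS_BARS)


def disable_progress_bars() -> None:
    """Disable progress bars, unless the env var explicitly enabled them."""
    global _progress_bars_disabled
    if HF_DATASETS_DISABLE_PROGRESS_BARS is False:
        return  # env var forces progress bars on; code cannot disable them
    _progress_bars_disabled = True


def enable_progress_bars() -> None:
    """Enable progress bars, unless the env var explicitly disabled them."""
    global _progress_bars_disabled
    if HF_DATASETS_DISABLE_PROGRESS_BARS is True:
        return  # env var forces progress bars off; code cannot enable them
    _progress_bars_disabled = False


def are_progress_bars_disabled() -> bool:
    return _progress_bars_disabled
```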
5 changes: 3 additions & 2 deletions src/datasets/data_files.py
@@ -17,6 +17,7 @@
from .download.streaming_download_manager import _prepare_path_and_storage_options, xbasename, xjoin
from .splits import Split
from .utils import logging
from .utils import tqdm as hf_tqdm
from .utils.file_utils import is_local_path, is_relative_path
from .utils.py_utils import glob_pattern_to_regex, string_to_dict

@@ -515,9 +516,9 @@ def _get_origin_metadata(
partial(_get_single_origin_metadata, download_config=download_config),
data_files,
max_workers=max_workers,
tqdm_class=logging.tqdm,
tqdm_class=hf_tqdm,
desc="Resolving data files",
disable=len(data_files) <= 16 or not logging.is_progress_bar_enabled(),
disable=len(data_files) <= 16,
)


13 changes: 4 additions & 9 deletions src/datasets/download/download_manager.py
@@ -28,10 +28,11 @@
from typing import Callable, Dict, Generator, Iterable, List, Optional, Tuple, Union

from .. import config
from ..utils import tqdm as hf_tqdm
from ..utils.deprecation_utils import DeprecatedEnum, deprecated
from ..utils.file_utils import cached_path, get_from_cache, hash_url_to_filename, is_relative_path, url_or_path_join
from ..utils.info_utils import get_size_checksum_dict
from ..utils.logging import get_logger, is_progress_bar_enabled, tqdm
from ..utils.logging import get_logger
from ..utils.py_utils import NestedDataStructure, map_nested, size_str
from .download_config import DownloadConfig

@@ -327,18 +328,16 @@ def upload(local_file_path):
uploaded_path_or_paths = map_nested(
lambda local_file_path: upload(local_file_path),
downloaded_path_or_paths,
disable_tqdm=not is_progress_bar_enabled(),
)
return uploaded_path_or_paths

def _record_sizes_checksums(self, url_or_urls: NestedDataStructure, downloaded_path_or_paths: NestedDataStructure):
"""Record size/checksum of downloaded files."""
delay = 5
for url, path in tqdm(
for url, path in hf_tqdm(
list(zip(url_or_urls.flatten(), downloaded_path_or_paths.flatten())),
delay=delay,
desc="Computing checksums",
disable=not is_progress_bar_enabled(),
):
# call str to support PathLike objects
self._recorded_sizes_checksums[str(url)] = get_size_checksum_dict(
@@ -373,9 +372,7 @@ def download_custom(self, url_or_urls, custom_download):
def url_to_downloaded_path(url):
return os.path.join(cache_dir, hash_url_to_filename(url))

downloaded_path_or_paths = map_nested(
url_to_downloaded_path, url_or_urls, disable_tqdm=not is_progress_bar_enabled()
)
downloaded_path_or_paths = map_nested(url_to_downloaded_path, url_or_urls)
url_or_urls = NestedDataStructure(url_or_urls)
downloaded_path_or_paths = NestedDataStructure(downloaded_path_or_paths)
for url, path in zip(url_or_urls.flatten(), downloaded_path_or_paths.flatten()):
@@ -426,7 +423,6 @@ def download(self, url_or_urls):
url_or_urls,
map_tuple=True,
num_proc=download_config.num_proc,
disable_tqdm=not is_progress_bar_enabled(),
desc="Downloading data files",
)
duration = datetime.now() - start_time
@@ -534,7 +530,6 @@ def extract(self, path_or_paths, num_proc="deprecated"):
partial(cached_path, download_config=download_config),
path_or_paths,
num_proc=download_config.num_proc,
disable_tqdm=not is_progress_bar_enabled(),
desc="Extracting data files",
)
path_or_paths = NestedDataStructure(path_or_paths)
8 changes: 3 additions & 5 deletions src/datasets/io/csv.py
@@ -5,7 +5,7 @@
from .. import Dataset, Features, NamedSplit, config
from ..formatting import query_table
from ..packaged_modules.csv.csv import Csv
from ..utils import logging
from ..utils import tqdm as hf_tqdm
from ..utils.typing import NestedDataStructureLike, PathLike
from .abc import AbstractDatasetReader

@@ -117,10 +117,9 @@ def _write(self, file_obj: BinaryIO, header, index, **to_csv_kwargs) -> int:
written = 0

if self.num_proc is None or self.num_proc == 1:
for offset in logging.tqdm(
for offset in hf_tqdm(
range(0, len(self.dataset), self.batch_size),
unit="ba",
disable=not logging.is_progress_bar_enabled(),
desc="Creating CSV from Arrow format",
):
csv_str = self._batch_csv((offset, header, index, to_csv_kwargs))
@@ -129,14 +128,13 @@ def _write(self, file_obj: BinaryIO, header, index, **to_csv_kwargs) -> int:
else:
num_rows, batch_size = len(self.dataset), self.batch_size
with multiprocessing.Pool(self.num_proc) as pool:
for csv_str in logging.tqdm(
for csv_str in hf_tqdm(
pool.imap(
self._batch_csv,
[(offset, header, index, to_csv_kwargs) for offset in range(0, num_rows, batch_size)],
),
total=(num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size,
unit="ba",
disable=not logging.is_progress_bar_enabled(),
desc="Creating CSV from Arrow format",
):
written += file_obj.write(csv_str)
8 changes: 3 additions & 5 deletions src/datasets/io/json.py
@@ -7,7 +7,7 @@
from .. import Dataset, Features, NamedSplit, config
from ..formatting import query_table
from ..packaged_modules.json.json import Json
from ..utils import logging
from ..utils import tqdm as hf_tqdm
from ..utils.typing import NestedDataStructureLike, PathLike
from .abc import AbstractDatasetReader

@@ -139,25 +139,23 @@ def _write(
written = 0

if self.num_proc is None or self.num_proc == 1:
for offset in logging.tqdm(
for offset in hf_tqdm(
range(0, len(self.dataset), self.batch_size),
unit="ba",
disable=not logging.is_progress_bar_enabled(),
desc="Creating json from Arrow format",
):
json_str = self._batch_json((offset, orient, lines, to_json_kwargs))
written += file_obj.write(json_str)
else:
num_rows, batch_size = len(self.dataset), self.batch_size
with multiprocessing.Pool(self.num_proc) as pool:
for json_str in logging.tqdm(
for json_str in hf_tqdm(
pool.imap(
self._batch_json,
[(offset, orient, lines, to_json_kwargs) for offset in range(0, num_rows, batch_size)],
),
total=(num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size,
unit="ba",
disable=not logging.is_progress_bar_enabled(),
desc="Creating json from Arrow format",
):
written += file_obj.write(json_str)