feat: 🎸 use row groups to reduce the response time
Based on @lhoestq's implementation in
https://huggingface.co/spaces/lhoestq/datasets-explorer/blob/main/app.py

Still a POC. We are querying datasets-server.huggingface.co (hardcoded)
to get the list of parquet files.
severo committed Mar 1, 2023
1 parent 7184486 commit 16de983
Showing 1 changed file with 144 additions and 35 deletions.
services/api/src/api/routes/rows.py: 144 additions & 35 deletions
@@ -2,32 +2,154 @@
# Copyright 2022 The HuggingFace Authors.

import logging
from functools import lru_cache, partial

# from http import HTTPStatus
from typing import Any, List, Optional, Tuple

import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import requests
from hffs.fs import HfFileSystem
from libcommon.processing_graph import ProcessingStep
from libcommon.simple_cache import DoesNotExist, get_response

# from libcommon.simple_cache import DoesNotExist
from starlette.requests import Request
from starlette.responses import Response
from tqdm.contrib.concurrent import thread_map

from api.authentication import auth_check
from api.utils import ( # ResponseNotFoundError,
    ApiCustomError,
    Endpoint,
    InvalidParameterError,
    MissingRequiredParameterError,
UnexpectedError,
are_valid_parameters,
get_json_api_error_response,
get_json_error_response,
get_json_ok_response,
)

MAX_ROWS = 100

PARQUET_REVISION = "refs/convert/parquet"

# TODO: manage private/gated datasets


class FileSystemError(Exception):
pass


class ParquetResponseError(Exception):
pass


# TODO: how to invalidate the cache when the parquet branch is created or deleted?
@lru_cache(maxsize=128)
def get_parquet_fs(dataset: str) -> HfFileSystem:
"""Get the parquet filesystem for a dataset.
The parquet files are stored in a separate branch of the dataset repository (see PARQUET_REVISION)
Args:
dataset (str): The dataset name.
Returns:
HfFileSystem: The parquet filesystem.
"""
return HfFileSystem(dataset, repo_type="dataset", revision=PARQUET_REVISION)


# RowGroupReaderBase = Callable[[], pa.Table]
# RowGroupReader = Union[RowGroupReaderBase, partial[RowGroupReaderBase]]
RowGroupReader = partial[Any]


@lru_cache(maxsize=128)
def index(parquet_cache_kind: str, dataset: str, config: str, split: str) -> Tuple[Any, List[RowGroupReader]]:
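    # For one split, build an index of its parquet row groups: the cumulative row counts
    # (row_group_offsets) and one lazy reader per row group (row_group_readers), so that
    # query() can read only the row groups covering the requested slice.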
# get the list of parquet files
# try:
# result = get_response(kind=parquet_cache_kind, dataset=dataset)
# except DoesNotExist as e:
# # add "check_in_process" ...
# raise ResponseNotFoundError("Not found.") from e
try:
response = requests.get(f"https://datasets-server.huggingface.co/parquet?dataset={dataset}")
response.raise_for_status()
result = response.json()
except Exception as e:
raise ParquetResponseError("Could not get the list of parquet files.") from e
sources = sorted(
f"{config}/{parquet_file['filename']}"
for parquet_file in result["parquet_files"]
if parquet_file["split"] == split and parquet_file["config"] == config
)
logging.debug(f"Found {len(sources)} parquet files for split {split}: {sources}")
if not sources:
raise ParquetResponseError("No parquet files found.")
fs = get_parquet_fs(dataset)
desc = f"{dataset}/{config}/{split}"
try:
parquet_files: List[pq.ParquetFile] = thread_map(
partial(pq.ParquetFile, filesystem=fs), sources, desc=desc, unit="pq"
)
# parquet_files: List[pq.ParquetFile] = [pq.ParquetFile(source=source, filesystem=fs) for source in sources]
except Exception as e:
raise FileSystemError(f"Could not read the parquet files: {e}") from e
# features = Features.from_arrow_schema(all_pf[0].schema.to_arrow_schema())
# columns = [
# col
# for col in features
# if all(bad_type not in str(features[col]) for bad_type in ["Image(", "Audio(", "'binary'"])
# ]
columns = None
# info = (
# ""
# if len(columns) == len(features)
# else f"Some columns are not supported yet: {sorted(set(features) - set(columns))}"
# )
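    # row_group_offsets[i] is the total number of rows in row groups 0..i across all parquet
    # files of the split; query() uses it with np.searchsorted to find which row groups cover
    # a requested [offset, offset + length) range.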
row_group_offsets = np.cumsum(
[
parquet_file.metadata.row_group(group_id).num_rows
for parquet_file in parquet_files
for group_id in range(parquet_file.metadata.num_row_groups)
]
)
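    # One zero-argument reader per row group, in the same order as row_group_offsets; nothing
    # is read from the filesystem until a reader is actually called in query().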
row_group_readers = [
partial(parquet_file.read_row_group, i=group_id, columns=columns)
for parquet_file in parquet_files
for group_id in range(parquet_file.metadata.num_row_groups)
]
return row_group_offsets, row_group_readers


def query(offset: int, length: int, row_group_offsets: Any, row_group_readers: List[RowGroupReader]) -> pa.Table:
"""Query the parquet files
Note that this implementation will always read one row group, to get the list of columns and always have the same
schema, even if the requested rows are invalid (out of range).
Args:
offset (int): The first row to read.
length (int): The number of rows to read.
row_group_offsets (Any): The row group offsets. See index().
row_group_readers (List[RowGroupReader]): The row group readers. See index().
Returns:
pa.Table: The requested rows.
"""
if (len(row_group_offsets) == 0) or (len(row_group_readers) == 0):
raise ParquetResponseError("No parquet files found.")
last_row_in_parquet = row_group_offsets[-1] - 1
first_row = min(offset, last_row_in_parquet)
    last_row = min(offset + length - 1, last_row_in_parquet)
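    # side="right" because row_group_offsets[i] counts the rows up to and including row group i:
    # the first offset strictly greater than a row index identifies the group containing that row.
    # Illustration with hypothetical sizes: for row groups of 1000, 1000 and 500 rows,
    # row_group_offsets == [1000, 2000, 2500]; offset=1990, length=20 give first_row=1990 and
    # last_row=2009, so only row groups 1 and 2 are read instead of the whole split.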
first_row_group_id, last_row_group_id = np.searchsorted(row_group_offsets, [first_row, last_row], side="right")
pa_table = pa.concat_tables([row_group_readers[i]() for i in range(first_row_group_id, last_row_group_id + 1)])
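    # pa_table starts at the first row of row group first_row_group_id; convert the global
    # offset into an offset inside pa_table before slicing out `length` rows.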
first_row_in_pa_table = row_group_offsets[first_row_group_id - 1] if first_row_group_id > 0 else 0
return pa_table.slice(offset - first_row_in_pa_table, length)


def create_rows_endpoint(
parquet_processing_step: ProcessingStep,
@@ -49,37 +171,24 @@ async def rows_endpoint(request: Request) -> Response:
if length < 0:
raise InvalidParameterError("Length must be positive")
if length > MAX_ROWS:
raise InvalidParameterError("Length must be less than or equal to {MAX_ROWS}")
logging.info("/rows, dataset={dataset}, config={config}, split={split}, offset={offset}, length={length}}")
raise InvalidParameterError(f"Length must be less than or equal to {MAX_ROWS}")
logging.info(f"/rows, dataset={dataset}, config={config}, split={split}, offset={offset}, length={length}")

# if auth_check fails, it will raise an exception that will be caught below
auth_check(dataset, external_auth_url=external_auth_url, request=request)
try:
row_group_offsets, row_group_readers = index(
parquet_cache_kind=parquet_processing_step.cache_kind, dataset=dataset, config=config, split=split
)
pa_table = query(
offset=offset,
length=length,
row_group_offsets=row_group_offsets,
row_group_readers=row_group_readers,
)
# TODO: ignore, or transform, some of the cells (e.g. audio or image)
rows = pa_table.to_pylist()
# TODO: add features?
return get_json_ok_response(content={"rows": rows}, max_age=max_age_long)
except ApiCustomError as e:
return get_json_api_error_response(error=e, max_age=max_age_short)
except Exception as e:
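
For reference, a minimal sketch of how the new endpoint would be called, assuming the service is reachable at a datasets-server host and that dataset, config, split, offset and length are passed as query parameters. The host and dataset below are hypothetical; the route, parameter names and response shape come from the code above.

import requests

# Hypothetical call; the actual host depends on where the API service is deployed.
response = requests.get(
    "https://datasets-server.huggingface.co/rows",
    params={
        "dataset": "glue",
        "config": "cola",
        "split": "train",
        "offset": 0,
        "length": 10,  # must be >= 0 and <= MAX_ROWS (100)
    },
)
response.raise_for_status()
rows = response.json()["rows"]  # list of plain Python dicts built from the pyarrow table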