Commit
This function allows you to create a table from any dataset on the Hugging Face Hub. It's an easy way to load data for testing and experimentation; the Hub contains 250k+ datasets. We use the streaming version of the API for memory-efficient processing and to optimize the case where you want only part of a dataset. With streaming, on-disk caching should be minimal.
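For instance (mirroring the documentation added in this commit), a quick partial load that streams only a single small batch:

```sql
-- load just the first 100 rows of the 'squad' dataset
select ai.load_dataset('squad', batch_size => 100, max_batches => 1);
```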
Showing 12 changed files with 518 additions and 26 deletions.
@@ -0,0 +1,91 @@
# Load dataset from Hugging Face

The `ai.load_dataset` function allows you to load datasets from Hugging Face's datasets library directly into your PostgreSQL database.

## Example Usage

```sql
select ai.load_dataset('squad');

select * from squad limit 10;
```

## Parameters
| Name | Type | Default | Required | Description |
|-----------------|-------|----------|----------|-------------|
| name | text | - | ✔ | The name of the dataset on Hugging Face (e.g., 'squad', 'glue', etc.) |
| config_name | text | - | ✖ | The specific configuration of the dataset to load. See [Hugging Face documentation](https://huggingface.co/docs/datasets/v2.20.0/en/load_hub#configurations) for more information. |
| split | text | - | ✖ | The split of the dataset to load (e.g., 'train', 'test', 'validation'). Defaults to all splits. |
| schema_name | text | 'public' | ✖ | The PostgreSQL schema where the table will be created |
| table_name | text | - | ✖ | The name of the table to create. If null, will use the dataset name |
| if_table_exists | text | 'error' | ✖ | Behavior when table exists: 'error' (raise error), 'append' (add rows), 'drop' (drop table and recreate) |
| field_types | jsonb | - | ✖ | Custom PostgreSQL data types for columns as a JSONB dictionary from name to type. |
| batch_size | int | 5000 | ✖ | Number of rows to insert in each batch |
| max_batches | int | null | ✖ | Maximum number of batches to load. Null means load all |
| kwargs | jsonb | - | ✖ | Additional arguments passed to the Hugging Face dataset loading function |
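
For example, a minimal sketch of reloading a dataset from scratch: `if_table_exists => 'drop'` replaces any previously created table before loading:

```sql
SELECT ai.load_dataset('squad', if_table_exists => 'drop');
```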

## Returns

Returns the number of rows loaded into the database (bigint).

## Examples

1. Basic usage - Load the entire 'squad' dataset:

```sql
SELECT ai.load_dataset('squad');
```

The data is loaded into a table named `squad`.

2. Load a small subset of the 'squad' dataset:

```sql
SELECT ai.load_dataset('squad', batch_size => 100, max_batches => 1);
```

3. Load specific configuration and split:

```sql
SELECT ai.load_dataset(
    name => 'glue',
    config_name => 'mrpc',
    split => 'train'
);
```

4. Load with custom table name and field types:

```sql
SELECT ai.load_dataset(
    name => 'glue',
    config_name => 'mrpc',
    table_name => 'mrpc',
    field_types => '{"sentence1": "text", "sentence2": "text"}'::jsonb
);
```

5. Pre-create the table and load data into it:

```sql
CREATE TABLE squad (
    id TEXT,
    title TEXT,
    context TEXT,
    question TEXT,
    answers JSONB
);

SELECT ai.load_dataset(
    name => 'squad',
    table_name => 'squad',
    if_table_exists => 'append'
);
```
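
6. Pass additional loader arguments via `kwargs`, which is forwarded to `datasets.load_dataset()`. An illustrative sketch, assuming you want to pin the dataset revision (`revision` is a standard `load_dataset` argument; the value here is only an example):

```sql
SELECT ai.load_dataset(
    name => 'squad',
    kwargs => '{"revision": "main"}'::jsonb
);
```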

## Notes

- The function requires an active internet connection to download datasets from Hugging Face.
- Large datasets may take significant time to load depending on size and connection speed.
- The function automatically maps Hugging Face dataset types to appropriate PostgreSQL data types unless overridden by `field_types`.
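
The Python implementation below also reads the `ai.dataset_cache_dir` setting to choose where Hugging Face caches downloads. A minimal sketch of setting it for the current session (the path is illustrative):

```sql
SET ai.dataset_cache_dir = '/tmp/hf_datasets_cache';
SELECT ai.load_dataset('squad', batch_size => 100, max_batches => 1);
```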
@@ -0,0 +1,201 @@
import json
import datasets
from typing import Optional, Dict, Any

from .utils import get_guc_value

GUC_DATASET_CACHE_DIR = "ai.dataset_cache_dir"


def get_default_column_type(dtype):
    # Default type mapping from dtypes to PostgreSQL types
    type_mapping = {
        "string": "TEXT",
        "dict": "JSONB",
        "list": "JSONB",
        "int64": "INT8",
        "int32": "INT4",
        "int16": "INT2",
        "int8": "INT2",
        "float64": "FLOAT8",
        "float32": "FLOAT4",
        "float16": "FLOAT4",
        "bool": "BOOLEAN",
    }

    if dtype.startswith("timestamp"):
        return "TIMESTAMPTZ"
    else:
        return type_mapping.get(dtype.lower(), "TEXT")


def get_column_info(dataset, field_types):
    # Extract types from features
    column_dtypes = {name: feature.dtype for name, feature in dataset.features.items()}
    # Prepare column types, using field_types if provided, otherwise use inferred types
    column_pgtypes = {}
    for name, py_type in column_dtypes.items():
        # Use custom type if provided, otherwise map from python type
        column_pgtypes[name] = (
            field_types.get(name)
            if field_types and name in field_types
            else get_default_column_type(str(py_type))
        )
    column_names = ", ".join(f'"{name}"' for name in column_dtypes.keys())
    return column_pgtypes, column_dtypes, column_names


def create_table(
    plpy, name, config_name, schema, table_name, column_types, if_table_exists
):
    # Generate default table name if not provided
    if table_name is None:
        # Handle potential nested dataset names (e.g., "huggingface/dataset")
        base_name = name.split("/")[-1]
        # Add config name to table name if present
        if config_name:
            base_name = f"{base_name}_{config_name}"
        # Replace any non-alphanumeric characters with underscore
        table_name = "".join(c if c.isalnum() else "_" for c in base_name.lower())
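        # Illustrative results: name='squad' -> table "squad";
        # name='glue' with config_name='mrpc' -> table "glue_mrpc".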

    # Construct fully qualified table name
    qualified_table = f'"{schema}"."{table_name}"'

    # Check if table exists
    plan = plpy.prepare(
        """
        SELECT pg_catalog.to_regclass(pg_catalog.format('%I.%I', $1, $2)) is not null as exists
        """,
        ["text", "text"],
    )
    result = plan.execute([schema, table_name], 1)
    table_exists = result[0]["exists"]

    if table_exists:
        if if_table_exists == "drop":
            plpy.execute(f"DROP TABLE IF EXISTS {qualified_table}")
        elif if_table_exists == "error":
            plpy.error(
                f"Table {qualified_table} already exists. Set if_table_exists to 'drop' to replace it or 'append' to add to it."
            )
        elif if_table_exists == "append":
            return qualified_table
        else:
            plpy.error(f"Unsupported if_table_exists value: {if_table_exists}")

    column_type_def = ", ".join(
        f'"{name}" {col_type}' for name, col_type in column_types.items()
    )

    # Create table
    plpy.execute(f"CREATE TABLE {qualified_table} ({column_type_def})")
    return qualified_table


def load_dataset(
    plpy,
    # Dataset loading parameters
    name: str,
    config_name: Optional[str] = None,
    split: Optional[str] = None,
    # Database target parameters
    schema: str = "public",
    table_name: Optional[str] = None,
    if_table_exists: str = "error",
    # Advanced options
    field_types: Optional[Dict[str, str]] = None,
    batch_size: int = 5000,
    max_batches: Optional[int] = None,
    # Additional dataset loading options
    **kwargs: Dict[str, Any],
) -> int:
""" | ||
Load a dataset into PostgreSQL database using plpy with batch UNNEST operations. | ||
Args: | ||
# Dataset loading parameters | ||
name: Name of the dataset | ||
config_name: Configuration name to load. Some datasets have multiple configurations | ||
(versions or subsets) available. See: https://huggingface.co/docs/datasets/v2.20.0/en/load_hub#configurations | ||
split: Dataset split to load (defaults to all splits) | ||
cache_dir: Directory to cache downloaded datasets (default: None) | ||
# Database target parameters | ||
schema: Target schema name (default: "public") | ||
table_name: Target table name (default: derived from dataset name) | ||
drop_if_exists: If True, drop existing table; if False, error if table exists (default: False) | ||
# Advanced options | ||
field_types: Optional dictionary of field names to PostgreSQL types | ||
batch_size: Number of rows to insert in each batch (default: 5000) | ||
# Additional dataset loading options | ||
**kwargs: Additional keyword arguments passed to datasets.load_dataset() | ||
Returns: | ||
Number of rows loaded | ||
""" | ||

    cache_dir = get_guc_value(plpy, GUC_DATASET_CACHE_DIR, None)

    # Load dataset using Hugging Face datasets library
    ds = datasets.load_dataset(
        name, config_name, split=split, cache_dir=cache_dir, streaming=True, **kwargs
    )
    if isinstance(ds, datasets.IterableDatasetDict):
        datasetdict = ds
    elif isinstance(ds, datasets.IterableDataset):
        datasetdict = {split: ds}
    else:
        plpy.error(
            f"Unsupported dataset type: {type(ds)}. Only datasets.IterableDatasetDict and datasets.IterableDataset are supported."
        )

    first_dataset = next(iter(datasetdict.values()))
    column_pgtypes, column_dtypes, column_names = get_column_info(
        first_dataset, field_types
    )
    qualified_table = create_table(
        plpy, name, config_name, schema, table_name, column_pgtypes, if_table_exists
    )

    # Prepare the UNNEST parameters and INSERT statement once
    unnest_params = []
    type_params = []
    for i, (col_name, col_type) in enumerate(column_pgtypes.items(), 1):
        unnest_params.append(f"${i}::{col_type}[]")
        type_params.append(f"{col_type}[]")

    insert_sql = f"""
        INSERT INTO {qualified_table} ({column_names})
        SELECT * FROM unnest({', '.join(unnest_params)})
    """
    insert_plan = plpy.prepare(insert_sql, type_params)

    num_rows = 0
    batch_count = 0
    for split, dataset in datasetdict.items():
        # Process data in batches using dataset iteration
        batched_dataset = dataset.batch(batch_size=batch_size)
        for batch in batched_dataset:
            if max_batches and batch_count >= max_batches:
                break

            batch_arrays = [[] for _ in column_dtypes]
            for i, (col_name, py_type) in enumerate(column_dtypes.items()):
                type_str = str(py_type).lower()
                array_values = batch[col_name]

                if type_str in ("dict", "list"):
                    batch_arrays[i] = [json.dumps(value) for value in array_values]
                elif type_str in ("int64", "int32", "int16", "int8"):
                    batch_arrays[i] = [int(value) for value in array_values]
                elif type_str in ("float64", "float32", "float16"):
                    batch_arrays[i] = [float(value) for value in array_values]
                else:
                    batch_arrays[i] = array_values

            num_rows += len(batch_arrays[0])
            batch_count += 1
            insert_plan.execute(batch_arrays)
    return num_rows
@@ -0,0 +1,9 @@
def get_guc_value(plpy, setting: str, default: str) -> str:
    plan = plpy.prepare("select pg_catalog.current_setting($1, true) as val", ["text"])
    result = plan.execute([setting], 1)
    val: str | None = None
    if len(result) != 0:
        val = result[0]["val"]
    if val is None:
        val = default
    return val
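
# Example of how the dataset loader uses this helper (see load_dataset above):
#   cache_dir = get_guc_value(plpy, "ai.dataset_cache_dir", None)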
@@ -3,4 +3,5 @@ tiktoken==0.7.0
ollama==0.2.1
anthropic==0.29.0
cohere==5.5.8
backoff==2.2.1
datasets==3.1.0