feat: add load_datasets
This function allows you to create a table from any dataset on the
Hugging Face Hub. It's an easy way to load data for testing and
experimentation. The hub contains 250k+ datasets.

We use the streaming version of the API for memory-efficient
processing and to handle the common case where you want only part of
a dataset. Because the data is streamed, on-disk caching should be
minimal.
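
Combined with the new `batch_size` and `max_batches` parameters, streaming
makes it cheap to pull in just a sample of a large dataset. A minimal sketch
(the same call appears in the docs added below):

```sql
-- Load only the first 100 rows of the 'squad' dataset as a quick sample.
select ai.load_dataset('squad', batch_size => 100, max_batches => 1);
```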
cevian committed Nov 25, 2024
1 parent 1051dcd commit 9b0a932
Showing 12 changed files with 518 additions and 26 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -54,6 +54,7 @@ For other use cases, first [Install pgai](#installation) in Timescale Cloud, a p
* [Anthropic](./docs/anthropic.md) - configure pgai for Anthropic, then use the model to generate content.
* [Cohere](./docs/cohere.md) - configure pgai for Cohere, then use the model to tokenize, embed, chat complete, classify, and rerank.
- Leverage LLMs for data processing tasks such as classification, summarization, and data enrichment ([see the OpenAI example](/docs/openai.md)).
- Load datasets from Hugging Face into your database with [ai.load_dataset](/docs/load_dataset_from_huggingface.md).



@@ -167,7 +168,7 @@ You can use pgai to integrate AI from the following providers:
- [Cohere](./docs/cohere.md)
- [Llama 3 (via Ollama)](/docs/ollama.md)
Learn how to [moderate](/docs/moderate.md) content directly in the database using triggers and background jobs.
Learn how to [moderate](/docs/moderate.md) content directly in the database using triggers and background jobs. To get started, [load datasets directly from Hugging Face](/docs/load_dataset_from_huggingface.md) into your database.
### Automatically create and sync LLM embeddings for your data
91 changes: 91 additions & 0 deletions docs/load_dataset_from_huggingface.md
@@ -0,0 +1,91 @@
# Load dataset from Hugging Face

The `ai.load_dataset` function allows you to load datasets from Hugging Face's datasets library directly into your PostgreSQL database.

## Example Usage

```sql
select ai.load_dataset('squad');

select * from squad limit 10;
```

## Parameters
| Name | Type | Default | Required | Description |
|-----------------|-------|----------|----------|-------------|
| name | text | - | ✔ | The name of the dataset on Hugging Face (e.g., 'squad', 'glue', etc.) |
| config_name | text | - | ✖ | The specific configuration of the dataset to load. See the [Hugging Face documentation](https://huggingface.co/docs/datasets/v2.20.0/en/load_hub#configurations) for more information. |
| split | text | - | ✖ | The split of the dataset to load (e.g., 'train', 'test', 'validation'). Defaults to all splits. |
| schema_name | text | 'public' | ✖ | The PostgreSQL schema where the table will be created. |
| table_name | text | - | ✖ | The name of the table to create. If null, the dataset name is used. |
| if_table_exists | text | 'error' | ✖ | Behavior when the table already exists: 'error' (raise an error), 'append' (add rows), 'drop' (drop the table and recreate it). |
| field_types | jsonb | - | ✖ | Custom PostgreSQL data types for columns, as a JSONB dictionary mapping column names to types. |
| batch_size | int | 5000 | ✖ | Number of rows to insert in each batch. |
| max_batches | int | null | ✖ | Maximum number of batches to load. Null means load all batches. |
| kwargs | jsonb | - | ✖ | Additional arguments passed to the Hugging Face dataset loading function. |

## Returns

Returns the number of rows loaded into the database (bigint).
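
The count can be read straight from the query result. A minimal sketch (the column alias is arbitrary):

```sql
-- Load one batch and report how many rows were inserted.
SELECT ai.load_dataset('squad', max_batches => 1) AS rows_loaded;
```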

## Examples

1. Basic usage - Load the entire 'squad' dataset:

```sql
SELECT ai.load_dataset('squad');
```

The data is loaded into a table named `squad`.

2. Load a small subset of the 'squad' dataset:

```sql
SELECT ai.load_dataset('squad', batch_size => 100, max_batches => 1);
```

3. Load specific configuration and split:

```sql
SELECT ai.load_dataset(
name => 'glue',
config_name => 'mrpc',
split => 'train'
);
```

4. Load with custom table name and field types:

```sql
SELECT ai.load_dataset(
name => 'glue',
config_name => 'mrpc',
table_name => 'mrpc',
field_types => '{"sentence1": "text", "sentence2": "text"}'::jsonb
);
```

5. Pre-create the table and load data into it:

```sql

CREATE TABLE squad (
id TEXT,
title TEXT,
context TEXT,
question TEXT,
answers JSONB
);

SELECT ai.load_dataset(
name => 'squad',
table_name => 'squad',
if_table_exists => 'append'
);
```

## Notes

- The function requires an active internet connection to download datasets from Hugging Face.
- Large datasets may take significant time to load depending on size and connection speed.
- The function automatically maps Hugging Face dataset types to appropriate PostgreSQL data types unless overridden by `field_types`.
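
The extension reads an optional cache directory for the Hugging Face datasets library from the `ai.dataset_cache_dir` setting (see `load_dataset.py` below). A sketch of pointing it at a custom location before loading, assuming the setting is writable in your session and the directory is writable by the PostgreSQL server process:

```sql
-- Assumption: ai.dataset_cache_dir may be set in this session and the path
-- is writable by the PostgreSQL server process.
SELECT set_config('ai.dataset_cache_dir', '/var/lib/postgresql/hf_cache', false);

SELECT ai.load_dataset('squad', max_batches => 1);
```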
201 changes: 201 additions & 0 deletions projects/extension/ai/load_dataset.py
@@ -0,0 +1,201 @@
import json
import datasets
from typing import Optional, Dict, Any

from .utils import get_guc_value

GUC_DATASET_CACHE_DIR = "ai.dataset_cache_dir"


def get_default_column_type(dtype):
    # Default type mapping from dtypes to PostgreSQL types
    type_mapping = {
        "string": "TEXT",
        "dict": "JSONB",
        "list": "JSONB",
        "int64": "INT8",
        "int32": "INT4",
        "int16": "INT2",
        "int8": "INT2",
        "float64": "FLOAT8",
        "float32": "FLOAT4",
        "float16": "FLOAT4",
        "bool": "BOOLEAN",
    }

    if dtype.startswith("timestamp"):
        return "TIMESTAMPTZ"
    else:
        return type_mapping.get(dtype.lower(), "TEXT")


def get_column_info(dataset, field_types):
    # Extract types from features
    column_dtypes = {name: feature.dtype for name, feature in dataset.features.items()}
    # Prepare column types, using field_types if provided, otherwise use inferred types
    column_pgtypes = {}
    for name, py_type in column_dtypes.items():
        # Use custom type if provided, otherwise map from python type
        column_pgtypes[name] = (
            field_types.get(name)
            if field_types and name in field_types
            else get_default_column_type(str(py_type))
        )
    column_names = ", ".join(f'"{name}"' for name in column_dtypes.keys())
    return column_pgtypes, column_dtypes, column_names


def create_table(
    plpy, name, config_name, schema, table_name, column_types, if_table_exists
):
    # Generate default table name if not provided
    if table_name is None:
        # Handle potential nested dataset names (e.g., "huggingface/dataset")
        base_name = name.split("/")[-1]
        # Add config name to table name if present
        if config_name:
            base_name = f"{base_name}_{config_name}"
        # Replace any non-alphanumeric characters with underscore
        table_name = "".join(c if c.isalnum() else "_" for c in base_name.lower())

    # Construct fully qualified table name
    qualified_table = f'"{schema}"."{table_name}"'

    # Check if table exists
    plan = plpy.prepare(
        """
        SELECT pg_catalog.to_regclass(pg_catalog.format('%I.%I', $1, $2)) is not null as exists
        """,
        ["text", "text"],
    )
    result = plan.execute([schema, table_name], 1)
    table_exists = result[0]["exists"]

    if table_exists:
        if if_table_exists == "drop":
            plpy.execute(f"DROP TABLE IF EXISTS {qualified_table}")
        elif if_table_exists == "error":
            plpy.error(
                f"Table {qualified_table} already exists. Set if_table_exists to 'drop' to replace it or 'append' to add to it."
            )
        elif if_table_exists == "append":
            return qualified_table
        else:
            plpy.error(f"Unsupported if_table_exists value: {if_table_exists}")

    column_type_def = ", ".join(
        f'"{name}" {col_type}' for name, col_type in column_types.items()
    )

    # Create table
    plpy.execute(f"CREATE TABLE {qualified_table} ({column_type_def})")
    return qualified_table


def load_dataset(
    plpy,
    # Dataset loading parameters
    name: str,
    config_name: Optional[str] = None,
    split: Optional[str] = None,
    # Database target parameters
    schema: str = "public",
    table_name: Optional[str] = None,
    if_table_exists: str = "error",
    # Advanced options
    field_types: Optional[Dict[str, str]] = None,
    batch_size: int = 5000,
    max_batches: Optional[int] = None,
    # Additional dataset loading options
    **kwargs: Dict[str, Any],
) -> int:
    """
    Load a dataset into a PostgreSQL database using plpy with batch UNNEST operations.

    Args:
        # Dataset loading parameters
        name: Name of the dataset
        config_name: Configuration name to load. Some datasets have multiple configurations
            (versions or subsets) available. See: https://huggingface.co/docs/datasets/v2.20.0/en/load_hub#configurations
        split: Dataset split to load (defaults to all splits)

        # Database target parameters
        schema: Target schema name (default: "public")
        table_name: Target table name (default: derived from dataset name)
        if_table_exists: Behavior when the table exists: "error", "append", or "drop" (default: "error")

        # Advanced options
        field_types: Optional dictionary of field names to PostgreSQL types
        batch_size: Number of rows to insert in each batch (default: 5000)
        max_batches: Maximum number of batches to load; None loads everything (default: None)

        # Additional dataset loading options
        **kwargs: Additional keyword arguments passed to datasets.load_dataset()

    Returns:
        Number of rows loaded
    """

    cache_dir = get_guc_value(plpy, GUC_DATASET_CACHE_DIR, None)

    # Load dataset using Hugging Face datasets library
    ds = datasets.load_dataset(
        name, config_name, split=split, cache_dir=cache_dir, streaming=True, **kwargs
    )
    if isinstance(ds, datasets.IterableDatasetDict):
        datasetdict = ds
    elif isinstance(ds, datasets.IterableDataset):
        datasetdict = {split: ds}
    else:
        plpy.error(
            f"Unsupported dataset type: {type(ds)}. Only datasets.IterableDatasetDict and datasets.IterableDataset are supported."
        )

    first_dataset = next(iter(datasetdict.values()))
    column_pgtypes, column_dtypes, column_names = get_column_info(
        first_dataset, field_types
    )
    qualified_table = create_table(
        plpy, name, config_name, schema, table_name, column_pgtypes, if_table_exists
    )

    # Prepare the UNNEST parameters and INSERT statement once
    unnest_params = []
    type_params = []
    for i, (col_name, col_type) in enumerate(column_pgtypes.items(), 1):
        unnest_params.append(f"${i}::{col_type}[]")
        type_params.append(f"{col_type}[]")

    insert_sql = f"""
        INSERT INTO {qualified_table} ({column_names})
        SELECT * FROM unnest({', '.join(unnest_params)})
    """
    insert_plan = plpy.prepare(insert_sql, type_params)

    num_rows = 0
    batch_count = 0
    for split, dataset in datasetdict.items():
        # Process data in batches using dataset iteration
        batched_dataset = dataset.batch(batch_size=batch_size)
        for batch in batched_dataset:
            if max_batches and batch_count >= max_batches:
                break

            batch_arrays = [[] for _ in column_dtypes]
            for i, (col_name, py_type) in enumerate(column_dtypes.items()):
                type_str = str(py_type).lower()
                array_values = batch[col_name]

                if type_str in ("dict", "list"):
                    batch_arrays[i] = [json.dumps(value) for value in array_values]
                elif type_str in ("int64", "int32", "int16", "int8"):
                    batch_arrays[i] = [int(value) for value in array_values]
                elif type_str in ("float64", "float32", "float16"):
                    batch_arrays[i] = [float(value) for value in array_values]
                else:
                    batch_arrays[i] = array_values

            num_rows += len(batch_arrays[0])
            batch_count += 1
            insert_plan.execute(batch_arrays)
    return num_rows
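
For reference, the prepared statement built above is a plain batched `INSERT ... SELECT FROM unnest(...)`, with one array parameter per column. A sketch of what it expands to for a hypothetical two-column table (table and column names here are illustrative only):

```sql
-- Hypothetical example: each parameter carries one batch worth of values
-- for a single column; unnest() turns the parallel arrays back into rows.
INSERT INTO "public"."squad_demo" ("question", "context")
SELECT * FROM unnest($1::TEXT[], $2::TEXT[]);
```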
13 changes: 2 additions & 11 deletions projects/extension/ai/secrets.py
@@ -6,6 +6,8 @@
import httpx
from backoff._typing import Details

from .utils import get_guc_value

GUC_SECRETS_MANAGER_URL = "ai.external_functions_executor_url"
GUC_SECRET_ENV_ENABLED = "ai.secret_env_enabled"

@@ -45,17 +47,6 @@ def get_secret(
return secret


def get_guc_value(plpy, setting: str, default: str) -> str:
    plan = plpy.prepare("select pg_catalog.current_setting($1, true) as val", ["text"])
    result = plan.execute([setting], 1)
    val: str | None = None
    if len(result) != 0:
        val = result[0]["val"]
    if val is None:
        val = default
    return val


def check_secret_permissions(plpy, secret_name: str) -> bool:
    # check if the user has access to all secrets
    plan = plpy.prepare(
9 changes: 9 additions & 0 deletions projects/extension/ai/utils.py
@@ -0,0 +1,9 @@
def get_guc_value(plpy, setting: str, default: str) -> str:
    plan = plpy.prepare("select pg_catalog.current_setting($1, true) as val", ["text"])
    result = plan.execute([setting], 1)
    val: str | None = None
    if len(result) != 0:
        val = result[0]["val"]
    if val is None:
        val = default
    return val
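
This helper just reads a PostgreSQL setting and falls back to a default when it is unset. The underlying call is equivalent to the following SQL (a sketch; `ai.dataset_cache_dir` is one of the settings read this way):

```sql
-- The second argument (missing_ok => true) makes this return NULL instead
-- of raising an error when the setting has never been defined.
SELECT pg_catalog.current_setting('ai.dataset_cache_dir', true);
```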
13 changes: 2 additions & 11 deletions projects/extension/ai/vectorizer.py
Expand Up @@ -5,24 +5,15 @@
import httpx
from backoff._typing import Details

from .utils import get_guc_value

GUC_VECTORIZER_URL = "ai.external_functions_executor_url"
DEFAULT_VECTORIZER_URL = "http://localhost:8000"

GUC_VECTORIZER_PATH = "ai.external_functions_executor_events_path"
DEFAULT_VECTORIZER_PATH = "/api/v1/events"


def get_guc_value(plpy, setting: str, default: str) -> str:
    plan = plpy.prepare("select pg_catalog.current_setting($1, true) as val", ["text"])
    result = plan.execute([setting], 1)
    val: str | None = None
    if len(result) != 0:
        val = result[0]["val"]
    if val is None:
        val = default
    return val


def execute_vectorizer(plpy, vectorizer_id: int) -> None:
    plan = plpy.prepare(
        """
2 changes: 1 addition & 1 deletion projects/extension/build.py
Expand Up @@ -376,7 +376,7 @@ def clean_sql() -> None:

def postgres_bin_dir() -> Path:
    bin_dir = os.getenv("PG_BIN")
    if bin_dir:
    if Path(bin_dir).is_dir():
        return Path(bin_dir).resolve()
    else:
        bin_dir = Path(f"/usr/lib/postgresql/{pg_major()}/bin")
2 changes: 1 addition & 1 deletion projects/extension/justfile
@@ -1,5 +1,5 @@
PG_MAJOR := env("PG_MAJOR", "17")
PG_BIN := "/usr/lib/postgresql/" + PG_MAJOR + "/bin"
PG_BIN := env("PG_BIN", "/usr/lib/postgresql/" + PG_MAJOR + "/bin")

# Show list of recipes
default:
3 changes: 2 additions & 1 deletion projects/extension/requirements.txt
@@ -3,4 +3,5 @@ tiktoken==0.7.0
ollama==0.2.1
anthropic==0.29.0
cohere==5.5.8
backoff==2.2.1
backoff==2.2.1
datasets==3.1.0