feat: enable composable and customizable sampler in PyTorch data loader (#1900)

* Provide a set of composable Samplers that work with Lance datasets
* The new `ruff` release made a number of formatting changes
eddyxu authored Feb 2, 2024
1 parent a0104bd commit 5407db8
Showing 27 changed files with 657 additions and 453 deletions.
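
The commit message summarizes the feature, but the excerpt below mostly shows formatting and deprecation fallout. As a rough orientation, here is a minimal sketch of how a composable sampler might plug into a PyTorch data loader. The module paths `lance.sampler` and `lance.torch.data`, the names `ShardedBatchSampler` and `LanceDataset`, and the `sampler=` parameter are assumptions inferred from the commit title and the deprecation notice in sharded_batch_iterator.py further down; they are not confirmed by this diff.

    # Hedged sketch: every name marked "assumed" is inferred, not confirmed here.
    import lance
    from lance.sampler import ShardedBatchSampler  # assumed module/class name
    from lance.torch.data import LanceDataset      # assumed torch wrapper name

    ds = lance.dataset("./my_dataset.lance")

    # Shard the scan so each distributed worker reads only its slice of batches.
    sampler = ShardedBatchSampler(rank=0, world_size=4)  # assumed signature
    loader = LanceDataset(ds, batch_size=1024, sampler=sampler)  # sampler= assumed

    for batch in loader:
        ...  # each batch arrives as Arrow data ready for the training loop

Because the sampling strategy is a separate object rather than behavior baked into the iterator, sharding, shuffling, or full scans can be swapped without touching the loader itself — which appears to be the "composable" part of the commit title.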
4 changes: 2 additions & 2 deletions .github/workflows/python.yml
@@ -50,8 +50,8 @@ jobs:
       pip install torch --index-url https://download.pytorch.org/whl/cpu
     - name: Lint Python
       run: |
-        ruff format --check python
-        ruff check python
+        ruff format --preview --check python
+        ruff check --preview python
     - name: Install dependencies
       run: |
         sudo apt update
10 changes: 5 additions & 5 deletions docs/conf.py
@@ -18,7 +18,7 @@ def setup(app):
 # -- Project information -----------------------------------------------------
 
 project = "Lance"
-copyright = "2023, Lance Developer"
+copyright = "2024, Lance Developer"
 author = "Lance Developer"
 
 
@@ -61,10 +61,10 @@ def setup(app):
 # so a file named "default.css" will overwrite the builtin "default.css".
 html_static_path = ["_static"]
 
-html_favicon = '_static/favicon_64x64.png'
+html_favicon = "_static/favicon_64x64.png"
 # html_logo = "_static/high-res-icon.png"
 html_theme_options = {
-    "source_url": 'https://github.com/lancedb/lance',
-    "source_icon": "github"
+    "source_url": "https://github.com/lancedb/lance",
+    "source_icon": "github",
 }
-html_css_files = ['custom.css']
+html_css_files = ["custom.css"]
3 changes: 2 additions & 1 deletion docs/integrations/integrations.rst
@@ -2,6 +2,7 @@ Integrations
 ------------
 
 .. toctree::
+   :maxdepth: 2
 
    Huggingface <./huggingface>
-   Tensorflow <./tensorflow>
+   Tensorflow <./tensorflow>
2 changes: 1 addition & 1 deletion python/pyproject.toml
@@ -60,7 +60,7 @@ benchmarks = ["pytest-benchmark"]
 torch = ["torch"]
 
 [tool.ruff]
-select = ["F", "E", "W", "I", "G", "TCH", "PERF", "CPY001", "B019"]
+lint.select = ["F", "E", "W", "I", "G", "TCH", "PERF", "CPY001", "B019"]
 
 [tool.mypy]
 python_version = "3.11"
1 change: 1 addition & 0 deletions python/python/benchmarks/conftest.py
@@ -16,6 +16,7 @@
 For configuration that is shared between tests and benchmarks, see ../conftest.py
 """
+
 from pathlib import Path
 
 import pytest
50 changes: 23 additions & 27 deletions python/python/benchmarks/test_bulk_write.py
@@ -27,40 +27,36 @@
 
 # Mostly vector data. One id column, a caption, and an embedding vector
 def create_captioned_image_data(num_bytes):
-    schema = pa.schema(
-        [
-            pa.field("int32", pa.int32()),
-            pa.field("text", pa.utf8()),
-            pa.field("vector", pa.list_(pa.float32(), N_DIMS)),
-        ]
-    )
+    schema = pa.schema([
+        pa.field("int32", pa.int32()),
+        pa.field("text", pa.utf8()),
+        pa.field("vector", pa.list_(pa.float32(), N_DIMS)),
+    ])
     return schema, rand_batches(
         schema, num_batches=8, batch_size_bytes=int(num_bytes / 8)
     )
 
 
 # Purely scalar data (schema based on TPC-H lineitem table)
 def create_scalar_data(num_bytes):
-    schema = pa.schema(
-        [
-            pa.field("l_orderkey", pa.int64()),
-            pa.field("l_partkey", pa.int64()),
-            pa.field("l_suppkey", pa.int64()),
-            pa.field("l_linenumber", pa.int64()),
-            pa.field("l_quantity", pa.float64()),
-            pa.field("l_extendedprice", pa.float64()),
-            pa.field("l_discount", pa.float64()),
-            pa.field("l_tax", pa.float64()),
-            pa.field("l_returnflag", pa.utf8()),
-            pa.field("l_linestatus", pa.utf8()),
-            pa.field("l_shipdate", pa.date32()),
-            pa.field("l_commitdate", pa.date32()),
-            pa.field("l_receiptdate", pa.date32()),
-            pa.field("l_shipinstruct", pa.utf8()),
-            pa.field("l_shipmode", pa.utf8()),
-            pa.field("l_comment", pa.utf8()),
-        ]
-    )
+    schema = pa.schema([
+        pa.field("l_orderkey", pa.int64()),
+        pa.field("l_partkey", pa.int64()),
+        pa.field("l_suppkey", pa.int64()),
+        pa.field("l_linenumber", pa.int64()),
+        pa.field("l_quantity", pa.float64()),
+        pa.field("l_extendedprice", pa.float64()),
+        pa.field("l_discount", pa.float64()),
+        pa.field("l_tax", pa.float64()),
+        pa.field("l_returnflag", pa.utf8()),
+        pa.field("l_linestatus", pa.utf8()),
+        pa.field("l_shipdate", pa.date32()),
+        pa.field("l_commitdate", pa.date32()),
+        pa.field("l_receiptdate", pa.date32()),
+        pa.field("l_shipinstruct", pa.utf8()),
+        pa.field("l_shipmode", pa.utf8()),
+        pa.field("l_comment", pa.utf8()),
+    ])
     return schema, rand_batches(
        schema, num_batches=8, batch_size_bytes=int(num_bytes / 8)
     )
48 changes: 22 additions & 26 deletions python/python/benchmarks/test_scan.py
@@ -57,32 +57,28 @@ def test_scan_integer(tmp_path: Path, benchmark, array_factory):
 @pytest.fixture(scope="module")
 def sample_dataset(tmpdir_factory):
     tmp_path = Path(tmpdir_factory.mktemp("data"))
-    table = pa.table(
-        {
-            "i": pa.array(range(NUM_ROWS), type=pa.int32()),
-            "f": pc.random(NUM_ROWS).cast(pa.float32()),
-            "s": pa.array(
-                [random.choice(["hello", "world", "today"]) for _ in range(NUM_ROWS)],
-                type=pa.string(),
-            ),
-            "fsl": pa.FixedSizeListArray.from_arrays(
-                pc.random(NUM_ROWS * 128).cast(pa.float32()), 128
-            ),
-            "blob": pa.array(
-                [
-                    random.choice(
-                        [
-                            random.randbytes(100 * 1024),
-                            random.randbytes(100 * 1024),
-                            random.randbytes(100 * 1024),
-                        ]
-                    )
-                    for _ in range(NUM_ROWS)
-                ],
-                type=pa.binary(),
-            ),
-        }
-    )
+    table = pa.table({
+        "i": pa.array(range(NUM_ROWS), type=pa.int32()),
+        "f": pc.random(NUM_ROWS).cast(pa.float32()),
+        "s": pa.array(
+            [random.choice(["hello", "world", "today"]) for _ in range(NUM_ROWS)],
+            type=pa.string(),
+        ),
+        "fsl": pa.FixedSizeListArray.from_arrays(
+            pc.random(NUM_ROWS * 128).cast(pa.float32()), 128
+        ),
+        "blob": pa.array(
+            [
+                random.choice([
+                    random.randbytes(100 * 1024),
+                    random.randbytes(100 * 1024),
+                    random.randbytes(100 * 1024),
+                ])
+                for _ in range(NUM_ROWS)
+            ],
+            type=pa.binary(),
+        ),
+    })
 
     return lance.write_dataset(table, tmp_path)
1 change: 1 addition & 0 deletions python/python/lance/_datagen.py
@@ -15,6 +15,7 @@
 """
 An internal module for generating Arrow data for use in testing and benchmarking.
 """
+
 import pyarrow as pa
 
 from .lance import datagen
4 changes: 4 additions & 0 deletions python/python/lance/_dataset/sharded_batch_iterator.py
@@ -14,6 +14,7 @@
 
 from __future__ import annotations
 
+import warnings
 from typing import TYPE_CHECKING, Generator, List, Literal, Union
 
 import lance
@@ -74,6 +75,9 @@ def __init__(
         batch_readahead: int = 8,
         with_row_id: bool = False,
     ):
+        warnings.warn(
+            "ShardedBatchIterator is deprecated, use :class:`Sampler` instead",
+        )
         self._rank = rank
         self._world_size = world_size
         self._batch_size = batch_size
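
For code still on the deprecated iterator, a hypothetical migration under the same naming assumptions as the sketch near the top of this page; the rank/world_size/batch_size parameters are visible in the diff above, but the replacement names are not confirmed here.

    import lance

    # Before (deprecated; exact argument order of ShardedBatchIterator assumed):
    #   batches = ShardedBatchIterator("./my_dataset.lance", rank=0,
    #                                  world_size=4, batch_size=1024)

    # After (assumed replacement; names unconfirmed by this diff):
    from lance.sampler import ShardedBatchSampler  # assumed name
    from lance.torch.data import LanceDataset      # assumed name

    ds = lance.dataset("./my_dataset.lance")
    loader = LanceDataset(
        ds,
        batch_size=1024,
        sampler=ShardedBatchSampler(rank=0, world_size=4),  # assumed signature
    )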
10 changes: 4 additions & 6 deletions python/python/lance/dataset.py
@@ -991,12 +991,10 @@ def create_scalar_index(
 
         index_type = index_type.upper()
         if index_type != "BTREE":
-            raise NotImplementedError(
-                (
-                    'Only "BTREE" is supported for ',
-                    f"index_type. Received {index_type}",
-                )
-            )
+            raise NotImplementedError((
+                'Only "BTREE" is supported for ',
+                f"index_type. Received {index_type}",
+            ))
 
         self._ds.create_index([column], index_type, name, replace)
10 changes: 4 additions & 6 deletions python/python/lance/progress.py
@@ -210,12 +210,10 @@ def cleanup_partial_writes(self, dataset_uri: str) -> int:
                     fragment_metadata = FragmentMetadata.from_json(
                         f.read().decode("utf-8")
                     )
-                    objects.append(
-                        (
-                            fragment_metadata.data_files()[0].path(),
-                            progress_data["multipart_id"],
-                        )
-                    )
+                    objects.append((
+                        fragment_metadata.data_files()[0].path(),
+                        progress_data["multipart_id"],
+                    ))
 
         _cleanup_partial_writes(dataset_uri, objects)