Use Arrow stream interface for public API #69

Merged · 6 commits · Jun 25, 2024
82 changes: 63 additions & 19 deletions stac_geoparquet/arrow/_api.py
@@ -1,15 +1,17 @@
from __future__ import annotations

import itertools
import os
from pathlib import Path
from typing import Any, Iterable, Iterator
from typing import Any, Iterable

import pyarrow as pa

from stac_geoparquet.arrow._batch import StacArrowBatch, StacJsonBatch
from stac_geoparquet.arrow._constants import DEFAULT_JSON_CHUNK_SIZE
from stac_geoparquet.arrow._schema.models import InferredSchema
from stac_geoparquet.arrow._util import batched_iter
from stac_geoparquet.arrow.types import ArrowStreamExportable
from stac_geoparquet.json_reader import read_json_chunked


@@ -18,7 +20,7 @@ def parse_stac_items_to_arrow(
*,
chunk_size: int = 8192,
schema: pa.Schema | InferredSchema | None = None,
) -> Iterable[pa.RecordBatch]:
) -> pa.RecordBatchReader:
"""
Parse a collection of STAC Items to an iterable of
[`pyarrow.RecordBatch`][pyarrow.RecordBatch].
@@ -37,23 +39,27 @@ def parse_stac_items_to_arrow(
inference. Defaults to None.

Returns:
an iterable of pyarrow RecordBatches with the STAC-GeoParquet representation of items.
pyarrow RecordBatchReader with a stream of STAC Arrow RecordBatches.
"""
if schema is not None:
if isinstance(schema, InferredSchema):
schema = schema.inner

# If schema is provided, then for better memory usage we parse input STAC items
# to Arrow batches in chunks.
for chunk in batched_iter(items, chunk_size):
yield stac_items_to_arrow(chunk, schema=schema)
batches = (
stac_items_to_arrow(batch, schema=schema)
for batch in batched_iter(items, chunk_size)
)
return pa.RecordBatchReader.from_batches(schema, batches)

else:
# If schema is _not_ provided, then we must convert to Arrow all at once, or
# else it would be possible for a STAC item late in the collection (after the
# first chunk) to have a different schema and not match the schema inferred for
# the first chunk.
yield stac_items_to_arrow(items)
batch = stac_items_to_arrow(items)
return pa.RecordBatchReader.from_batches(batch.schema, [batch])


def parse_stac_ndjson_to_arrow(
@@ -62,7 +68,7 @@
chunk_size: int = DEFAULT_JSON_CHUNK_SIZE,
schema: pa.Schema | None = None,
limit: int | None = None,
) -> Iterator[pa.RecordBatch]:
) -> pa.RecordBatchReader:
"""
Convert one or more newline-delimited JSON STAC files to a generator of Arrow
RecordBatches.
@@ -81,39 +87,77 @@ def parse_stac_ndjson_to_arrow(
Keyword Args:
limit: The maximum number of JSON Items to use for schema inference

Yields:
Arrow RecordBatch with a single chunk of Item data.
Returns:
pyarrow RecordBatchReader with a stream of STAC Arrow RecordBatches.
"""
# If the schema was not provided, then we need to load all data into memory at once
# to perform schema resolution.
if schema is None:
inferred_schema = InferredSchema()
inferred_schema.update_from_json(path, chunk_size=chunk_size, limit=limit)
inferred_schema.manual_updates()
yield from parse_stac_ndjson_to_arrow(
return parse_stac_ndjson_to_arrow(
path, chunk_size=chunk_size, schema=inferred_schema
)
return

if isinstance(schema, InferredSchema):
schema = schema.inner

for batch in read_json_chunked(path, chunk_size=chunk_size):
yield stac_items_to_arrow(batch, schema=schema)
batches_iter = (

**Collaborator** commented:

This is probably fine, but I wanted to clarify a couple things:

1. Do we have 2 schemas here: an input schema (named `schema`) and an output schema (named `resolved_schema`)?
2. Could we somehow derive `resolved_schema` from just the input schema, and not from actual data? Something like `stac_items_to_arrow([], schema=schema)`? Or do we need actual items to figure out the resolved schema?

**Collaborator Author** replied:

1. Yes. It's the difference between `StacJsonBatch`:

       class StacJsonBatch:
           """
           An Arrow RecordBatch of STAC Items that has been **minimally converted** to Arrow.
           That is, it aligns as much as possible to the raw STAC JSON representation.

           The **only** transformations that have already been applied here are those that are
           necessary to represent the core STAC items in Arrow.

           - `geometry` has been converted to WKB binary
           - `properties.proj:geometry`, if it exists, has been converted to WKB binary
           - The `proj:geometry` in any asset properties, if it exists, has been converted to
             WKB binary.

           No other transformations have yet been applied. I.e. all properties are still in a
           top-level `properties` struct column.
           """

   and `StacArrowBatch`:

       class StacArrowBatch:
           """
           An Arrow RecordBatch of STAC Items that has been processed to match the
           STAC-GeoParquet specification.
           """

   One is the schema of the input data, which is as close to the original STAC JSON as possible (only with geometry pre-coerced to WKB), and the other is the schema of the output data, after any STAC-GeoParquet transformations.

2. No, not as it stands, and it's very annoying. I think the main blocker to this is that we transform the bounding box from an Arrow List to an Arrow Struct (which we do to take advantage of GeoParquet 1.1, which defines a bounding box struct column that can be used for predicate pushdown). However, we don't know in advance whether the bounding box of each Item is 2D or 3D, and so we don't know in advance how many struct fields to create.

   This also means that STAC conversion will fail on mixed 2D/3D input. Are there any real-world STAC collections that have mixed 2D/3D bounding boxes?

**Collaborator** replied:

OK thanks. We might consider revisiting this later but that all makes sense for now.

> Are there any real-world STAC collections that have mixed 2D/3D bounding boxes?

I think we only have one collection with 3D bounding boxes, and that should have 3D for each item.

**Collaborator Author** replied:

> However we don't know in advance whether the bounding box of each Item is 2D or 3D, and so we don't know in advance how many struct fields to create.

I suppose this is something we could keep track of in our `InferredSchema` class while we're doing a scan of the input data: keep track of whether bounding boxes are only 2D, only 3D, or a mix of the two.

Though if a user passed in their own schema (which describes the input, not the output data, and so describes `bbox` as a List), they'd also need to pass in whether the bbox is 2D or 3D. (A rough sketch of the 2D list-to-struct conversion appears after this function's diff below.)

stac_items_to_arrow(batch, schema=schema)
for batch in read_json_chunked(path, chunk_size=chunk_size)
)
first_batch = next(batches_iter)
# Need to take this schema from the iterator; the existing `schema` is the schema of
# the JSON batch
resolved_schema = first_batch.schema
return pa.RecordBatchReader.from_batches(
resolved_schema, itertools.chain([first_batch], batches_iter)
)
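
To make the list-to-struct `bbox` conversion discussed in the review thread above concrete, here is a rough sketch of the 2D case. This is not code from this PR: the helper name is hypothetical, and it assumes every row is a 2D `[xmin, ymin, xmax, ymax]` list; mixed 2D/3D input is exactly the case the thread flags as unresolved. The resulting struct (with GeoParquet 1.1's `xmin`/`ymin`/`xmax`/`ymax` field names) is what enables bbox predicate pushdown.

```python
import pyarrow as pa
import pyarrow.compute as pc


def bbox_list_to_struct_2d(batch: pa.RecordBatch) -> pa.RecordBatch:
    """Rebuild a 2D ``bbox`` list column as a GeoParquet 1.1-style struct column."""
    idx = batch.schema.get_field_index("bbox")
    # Flatten the list<double> values and view them as one row of 4 coordinates each.
    flat = pc.list_flatten(batch.column(idx)).to_numpy(zero_copy_only=False)
    coords = flat.reshape(-1, 4)
    bbox_struct = pa.StructArray.from_arrays(
        [pa.array(coords[:, i].copy()) for i in range(4)],
        names=["xmin", "ymin", "xmax", "ymax"],
    )
    # Swap in the struct column, leaving every other field of the batch unchanged.
    arrays = list(batch.columns)
    fields = list(batch.schema)
    arrays[idx] = bbox_struct
    fields[idx] = pa.field("bbox", bbox_struct.type)
    return pa.RecordBatch.from_arrays(arrays, schema=pa.schema(fields))
```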


def stac_table_to_items(
table: pa.Table | pa.RecordBatchReader | ArrowStreamExportable,
) -> Iterable[dict]:
"""Convert STAC Arrow to a generator of STAC Item `dict`s.

Args:
table: STAC in Arrow form. This can be a pyarrow Table, a pyarrow
RecordBatchReader, or any other Arrow stream object exposed through the
[Arrow PyCapsule
Interface](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html).
A RecordBatchReader or stream object will not be materialized in memory.

Yields:
A STAC `dict` for each input row.
"""
# Coerce to record batch reader to avoid materializing entire stream
reader = pa.RecordBatchReader.from_stream(table)

def stac_table_to_items(table: pa.Table) -> Iterable[dict]:
"""Convert a STAC Table to a generator of STAC Item `dict`s"""
for batch in table.to_batches():
for batch in reader:
clean_batch = StacArrowBatch(batch)
yield from clean_batch.to_json_batch().iter_dicts()


def stac_table_to_ndjson(
table: pa.Table, dest: str | Path | os.PathLike[bytes]
table: pa.Table | pa.RecordBatchReader | ArrowStreamExportable,
dest: str | Path | os.PathLike[bytes],
) -> None:
"""Write a STAC Table to a newline-delimited JSON file."""
for batch in table.to_batches():
"""Write STAC Arrow to a newline-delimited JSON file.

Args:
table: STAC in Arrow form. This can be a pyarrow Table, a pyarrow
RecordBatchReader, or any other Arrow stream object exposed through the
[Arrow PyCapsule
Interface](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html).
A RecordBatchReader or stream object will not be materialized in memory.
dest: The destination where newline-delimited JSON should be written.
"""

# Coerce to record batch reader to avoid materializing entire stream
reader = pa.RecordBatchReader.from_stream(table)

for batch in reader:
clean_batch = StacArrowBatch(batch)
clean_batch.to_json_batch().to_ndjson(dest)
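
For orientation, here is a hedged end-to-end sketch of the reshaped public API in this file: `parse_stac_items_to_arrow` now returns a `pyarrow.RecordBatchReader`, and that reader can be fed straight back into `stac_table_to_items` (or any other Arrow stream consumer). The import path and the abbreviated item dict are assumptions for illustration; real items would carry more fields.

```python
import pyarrow as pa

# Assumed top-level import path; the functions themselves live in
# stac_geoparquet/arrow/_api.py as shown in this diff.
from stac_geoparquet.arrow import parse_stac_items_to_arrow, stac_table_to_items

# A hypothetical, heavily abbreviated STAC Item.
items = [
    {
        "type": "Feature",
        "stac_version": "1.0.0",
        "id": "item-0",
        "geometry": {"type": "Point", "coordinates": [0.0, 0.0]},
        "bbox": [0.0, 0.0, 0.0, 0.0],
        "properties": {"datetime": "2024-06-25T00:00:00Z"},
        "assets": {"data": {"href": "https://example.com/item-0.tif"}},
        "links": [{"rel": "self", "href": "https://example.com/item-0.json"}],
    }
]

reader = parse_stac_items_to_arrow(items)
assert isinstance(reader, pa.RecordBatchReader)

# The reader is itself an Arrow stream, so it can be handed back to the API
# (or to any PyCapsule-aware library) without first materializing a Table.
for item in stac_table_to_items(reader):
    print(item["id"])
```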

13 changes: 6 additions & 7 deletions stac_geoparquet/arrow/_delta_lake.py
@@ -1,6 +1,5 @@
from __future__ import annotations

import itertools
from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterable

@@ -47,14 +46,14 @@ def parse_stac_ndjson_to_delta_lake(
schema_version: GeoParquet specification version; if not provided will default
to latest supported version.
"""
batches_iter = parse_stac_ndjson_to_arrow(
record_batch_reader = parse_stac_ndjson_to_arrow(
input_path, chunk_size=chunk_size, schema=schema, limit=limit
)
first_batch = next(batches_iter)
schema = first_batch.schema.with_metadata(
schema = record_batch_reader.schema.with_metadata(
create_geoparquet_metadata(
pa.Table.from_batches([first_batch]), schema_version=schema_version
record_batch_reader.schema, schema_version=schema_version
)
)
combined_iter = itertools.chain([first_batch], batches_iter)
write_deltalake(table_or_uri, combined_iter, schema=schema, engine="rust", **kwargs)
write_deltalake(
table_or_uri, record_batch_reader, schema=schema, engine="rust", **kwargs
)
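
A brief, hedged usage sketch of the streaming Delta Lake path above. The import path and file locations are placeholders; the positional parameters (`input_path`, then `table_or_uri`) follow the call sites visible in this diff.

```python
# Assumed import path; the function is defined in stac_geoparquet/arrow/_delta_lake.py.
from stac_geoparquet.arrow import parse_stac_ndjson_to_delta_lake

# Newline-delimited STAC Items are converted to a RecordBatchReader and the
# reader is passed to write_deltalake, so batches stream to the Delta table
# rather than being collected into one large in-memory table (schema inference
# may still require an initial pass over the input).
parse_stac_ndjson_to_delta_lake(
    "items.ndjson",        # hypothetical NDJSON input
    "./stac_delta_table",  # hypothetical Delta Lake table location
)
```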
52 changes: 31 additions & 21 deletions stac_geoparquet/arrow/_to_parquet.py
@@ -15,6 +15,7 @@
)
from stac_geoparquet.arrow._crs import WGS84_CRS_JSON
from stac_geoparquet.arrow._schema.models import InferredSchema
from stac_geoparquet.arrow.types import ArrowStreamExportable


def parse_stac_ndjson_to_parquet(
@@ -43,26 +44,24 @@ def parse_stac_ndjson_to_parquet(
limit: The maximum number of JSON records to convert.
schema_version: GeoParquet specification version; if not provided will default
to latest supported version.
"""

batches_iter = parse_stac_ndjson_to_arrow(
All other keyword args are passed on to
[`pyarrow.parquet.ParquetWriter`][pyarrow.parquet.ParquetWriter].
"""
record_batch_reader = parse_stac_ndjson_to_arrow(
input_path, chunk_size=chunk_size, schema=schema, limit=limit
)
first_batch = next(batches_iter)
schema = first_batch.schema.with_metadata(
create_geoparquet_metadata(
pa.Table.from_batches([first_batch]), schema_version=schema_version
)
to_parquet(
record_batch_reader,
output_path=output_path,
schema_version=schema_version,
**kwargs,
)
with pq.ParquetWriter(output_path, schema, **kwargs) as writer:
writer.write_batch(first_batch)
for batch in batches_iter:
writer.write_batch(batch)


def to_parquet(
table: pa.Table,
where: Any,
table: pa.Table | pa.RecordBatchReader | ArrowStreamExportable,
output_path: str | Path,
*,
schema_version: SUPPORTED_PARQUET_SCHEMA_VERSIONS = DEFAULT_PARQUET_SCHEMA_VERSION,
**kwargs: Any,
@@ -72,22 +71,33 @@ def to_parquet(
This writes metadata compliant with either GeoParquet 1.0 or 1.1.

Args:
table: The table to write to Parquet
where: The destination for saving.
table: STAC in Arrow form. This can be a pyarrow Table, a pyarrow
RecordBatchReader, or any other Arrow stream object exposed through the
[Arrow PyCapsule
Interface](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html).
A RecordBatchReader or stream object will not be materialized in memory.
output_path: The destination for saving.

Keyword Args:
schema_version: GeoParquet specification version; if not provided will default
to latest supported version.

All other keyword args are passed on to
[`pyarrow.parquet.ParquetWriter`][pyarrow.parquet.ParquetWriter].
"""
metadata = table.schema.metadata or {}
metadata.update(create_geoparquet_metadata(table, schema_version=schema_version))
table = table.replace_schema_metadata(metadata)
# Coerce to record batch reader to avoid materializing entire stream
reader = pa.RecordBatchReader.from_stream(table)

pq.write_table(table, where, **kwargs)
schema = reader.schema.with_metadata(
create_geoparquet_metadata(reader.schema, schema_version=schema_version)
)
with pq.ParquetWriter(output_path, schema, **kwargs) as writer:
for batch in reader:
writer.write_batch(batch)
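
To show the widened `to_parquet` entry point in use, here is a small sketch that chains it with `parse_stac_ndjson_to_arrow` (the same pairing `parse_stac_ndjson_to_parquet` uses internally). The import path and file names are assumptions for illustration.

```python
# Assumed top-level import path; both functions appear in this PR's diff.
from stac_geoparquet.arrow import parse_stac_ndjson_to_arrow, to_parquet

# Build a RecordBatchReader over newline-delimited STAC Items (hypothetical
# path), then stream it into GeoParquet. Because to_parquet now accepts any
# Arrow stream, batches are written one at a time through ParquetWriter
# instead of being materialized as a single Table first.
reader = parse_stac_ndjson_to_arrow("items.ndjson")
to_parquet(reader, "items.parquet")
```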


def create_geoparquet_metadata(
table: pa.Table,
schema: pa.Schema,
*,
schema_version: SUPPORTED_PARQUET_SCHEMA_VERSIONS,
) -> dict[bytes, bytes]:
@@ -116,7 +126,7 @@
"primary_column": "geometry",
}

if "proj:geometry" in table.schema.names:
if "proj:geometry" in schema.names:
# Note we don't include proj:bbox as a covering here for a couple different
# reasons. For one, it's very common for the projected geometries to have a
# different CRS in each row, so having statistics for proj:bbox wouldn't be
5 changes: 5 additions & 0 deletions stac_geoparquet/arrow/types.py
@@ -0,0 +1,5 @@
from typing import Protocol


class ArrowStreamExportable(Protocol):
def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: ... # noqa
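
The `ArrowStreamExportable` protocol is what lets the functions above accept not just pyarrow objects but anything implementing the Arrow PyCapsule stream interface. A minimal sketch of how such an object is consumed, assuming a pyarrow version recent enough to provide `RecordBatchReader.from_stream` (which the code in this PR already relies on):

```python
import pyarrow as pa

# A pyarrow Table implements __arrow_c_stream__, so it satisfies
# ArrowStreamExportable; objects from other Arrow-compatible libraries that
# export the same capsule interface work the same way.
table = pa.table({"id": ["item-0", "item-1"]})

reader = pa.RecordBatchReader.from_stream(table)
for batch in reader:
    print(batch.num_rows)
```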