Skip to content

Commit

Permalink
Support geoarrow array input into viz() (#427)
Browse files Browse the repository at this point in the history
Closes #425. Blocked
on apache/arrow#38010 (comment).
The main issue is that we need a reliable way to maintain the geoarrow
extension metadata through FFI. The easiest way would be for `pa.field()`
to support `__arrow_c_schema__` input. Alternatively, we could use a
context manager of sorts to register global pyarrow geoarrow extension
arrays, and then deregister them after use.
  • Loading branch information
kylebarron authored Mar 25, 2024
1 parent 5bdd908 commit 1d6491f
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 1 deletion.
56 changes: 56 additions & 0 deletions lonboard/_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ class GeoInterfaceProtocol(Protocol):
@property
def __geo_interface__(self) -> dict: ...

class ArrowArrayExportable(Protocol):
    """Structural type for any object exposing ``__arrow_c_array__``."""

    def __arrow_c_array__(
        self,
        requested_schema: object | None = None,
    ) -> Tuple[object, object]:
        """Return a pair of objects; ``requested_schema`` is optional."""
        ...

class ArrowStreamExportable(Protocol):
def __arrow_c_stream__(
self, requested_schema: object | None = None
Expand Down Expand Up @@ -187,6 +192,11 @@ def create_layer_from_data_input(
if isinstance(data, shapely.geometry.base.BaseGeometry):
return _viz_shapely_scalar(data, **kwargs)

# Anything with __arrow_c_array__
if hasattr(data, "__arrow_c_array__"):
data = cast("ArrowArrayExportable", data)
return _viz_geoarrow_array(data, **kwargs)

# Anything with __arrow_c_stream__
if hasattr(data, "__arrow_c_stream__"):
data = cast("ArrowStreamExportable", data)
Expand Down Expand Up @@ -296,6 +306,52 @@ def _viz_geo_interface(
raise ValueError(f"type '{geo_interface_type}' not supported.")


def _viz_geoarrow_array(
    data: ArrowArrayExportable,
    **kwargs,
) -> Union[ScatterplotLayer, PathLayer, SolidPolygonLayer]:
    """Build a layer from a single array exposing ``__arrow_c_array__``.

    The array is wrapped into a two-column ``pa.Table`` (a ``"geometry"``
    column holding the array and a ``"row_index"`` column holding
    ``0..len(array) - 1``) and handed to ``_viz_geoarrow_table``.

    Args:
        data: Object implementing ``__arrow_c_array__``.
        **kwargs: Forwarded unchanged to ``_viz_geoarrow_table``.

    Returns:
        Whatever ``_viz_geoarrow_table`` returns for the constructed table.

    Raises:
        KeyError: If the installed pyarrow's ``pa.Field`` has no
            ``_import_from_c_capsule`` method.
    """
    # Fail fast, before exporting anything from ``data``, when pyarrow cannot
    # import a field from a schema capsule.
    if not hasattr(pa.Field, "_import_from_c_capsule"):
        raise KeyError(
            "Incompatible version of pyarrow: pa.Field does not have"
            " _import_from_c_capsule method"
        )

    schema_capsule, array_capsule = data.__arrow_c_array__()

    # If the user doesn't have pyarrow extension types registered for geoarrow types,
    # `pa.array()` will lose the extension metadata. Instead, we manually persist the
    # extension metadata by extracting both the field and the array.

    class ArrayHolder:
        """Tiny carrier handing a (schema, array) capsule pair to ``pa.array``."""

        schema_capsule: object
        array_capsule: object

        def __init__(self, schema_capsule, array_capsule) -> None:
            self.schema_capsule = schema_capsule
            self.array_capsule = array_capsule

        # ``requested_schema`` defaults to None so the signature matches the
        # ``ArrowArrayExportable`` protocol; the value itself is ignored.
        def __arrow_c_array__(self, requested_schema=None):
            return self.schema_capsule, self.array_capsule

    field = pa.Field._import_from_c_capsule(schema_capsule)
    # The schema is re-exported from ``field`` rather than reusing
    # ``schema_capsule`` — presumably because importing the field consumed the
    # original capsule (NOTE(review): confirm against the Arrow PyCapsule spec).
    array = pa.array(ArrayHolder(field.__arrow_c_schema__(), array_capsule))
    schema = pa.schema([field.with_name("geometry")])
    table = pa.Table.from_arrays([array], schema=schema)

    # Pick the narrowest unsigned dtype whose max is >= the row count,
    # falling back to uint64.
    num_rows = len(array)
    index_dtype = next(
        (
            candidate
            for candidate in (np.uint8, np.uint16, np.uint32)
            if num_rows <= np.iinfo(candidate).max
        ),
        np.uint64,
    )
    arange_col = np.arange(num_rows, dtype=index_dtype)

    table = table.append_column("row_index", pa.array(arange_col))
    return _viz_geoarrow_table(table, **kwargs)


def _viz_geoarrow_table(
table: pa.Table,
*,
Expand Down
23 changes: 22 additions & 1 deletion tests/test_viz.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import geodatasets
import geopandas as gpd
from geoarrow.rust.core import read_pyogrio
from pyogrio.raw import read_arrow

from lonboard import SolidPolygonLayer, viz


def test_viz_wkb_geoarrow():
def test_viz_wkb_pyarrow():
path = geodatasets.get_path("naturalearth.land")
meta, table = read_arrow(path)
map_ = viz(table)
Expand Down Expand Up @@ -37,3 +38,23 @@ def __geo_interface__(self):
map_ = viz(geo_interface_obj)

assert isinstance(map_.layers[0], SolidPolygonLayer)


def test_viz_geoarrow_rust_table():
    """viz() on a table read via ``read_pyogrio`` yields a SolidPolygonLayer."""
    land_path = geodatasets.get_path("naturalearth.land")
    rendered = viz(read_pyogrio(land_path))
    first_layer = rendered.layers[0]
    assert isinstance(first_layer, SolidPolygonLayer)


def test_viz_geoarrow_rust_array():
    """viz() on the first geometry chunk of the table yields a SolidPolygonLayer."""
    land_path = geodatasets.get_path("naturalearth.land")
    geometry_chunk = read_pyogrio(land_path).geometry.chunk(0)
    rendered = viz(geometry_chunk)
    first_layer = rendered.layers[0]
    assert isinstance(first_layer, SolidPolygonLayer)


def test_viz_geoarrow_rust_wkb_array():
    """viz() on the ``to_wkb()`` form of a geometry chunk yields a SolidPolygonLayer."""
    land_path = geodatasets.get_path("naturalearth.land")
    geometry_chunk = read_pyogrio(land_path).geometry.chunk(0)
    rendered = viz(geometry_chunk.to_wkb())
    first_layer = rendered.layers[0]
    assert isinstance(first_layer, SolidPolygonLayer)

0 comments on commit 1d6491f

Please sign in to comment.