Skip to content

Commit

Permalink
add patch url arg
Browse files Browse the repository at this point in the history
  • Loading branch information
floriscalkoen committed Feb 25, 2025
1 parent b64c738 commit 204904d
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions src/coastpy/eo/typology.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import warnings
from collections.abc import Callable
from functools import partial
from typing import Literal

import dask.bag as db
Expand Down Expand Up @@ -371,15 +373,17 @@ def chip_from_transect(
return dataset


def load_stac_gpq_item_xr(stac_gpq_item: gpd.GeoDataFrame) -> xr.Dataset | None:
def load_stac_gpq_item_xr(
stac_gpq_item: gpd.GeoDataFrame, patch_url: Callable | None = None
) -> xr.Dataset | None:
"""Converts a STAC GeoParquet row into an Xarray dataset."""
if len(stac_gpq_item) != 1:
raise ValueError("Expected a single STAC item, but got multiple.")

try:
ds = odc.stac.load(stac_geoparquet.to_item_collection(stac_gpq_item)).squeeze(
drop=True
)
ds = odc.stac.load(
stac_geoparquet.to_item_collection(stac_gpq_item), patch_url=patch_url
).squeeze(drop=True)
ds = ds.drop_vars("spatial_ref", errors="ignore")

# These are STAC GeoParquet columns that cannot be added as coordinates to the dataset
Expand Down Expand Up @@ -408,19 +412,21 @@ def load_stac_gpq_item_xr(stac_gpq_item: gpd.GeoDataFrame) -> xr.Dataset | None:
return None


def load_stac_xr(df: gpd.GeoDataFrame, use_dask=False) -> xr.Dataset:
def load_stac_xr(
df: gpd.GeoDataFrame, use_dask=False, patch_url: Callable | None = None
) -> xr.Dataset:
"""Loads STAC GeoParquet training items into an Xarray dataset, optionally using Dask Bag for efficiency."""

if use_dask:
bag = db.from_sequence([df.iloc[[i]] for i in range(len(df))])
delayed_datasets = bag.map(load_stac_gpq_item_xr)
delayed_datasets = bag.map(partial(load_stac_gpq_item_xr, patch_url=patch_url))
datasets = delayed_datasets.compute()

else:
datasets = []
for i in range(len(df)):
sample = df.iloc[[i]]
ds = load_stac_gpq_item_xr(sample)
ds = load_stac_gpq_item_xr(sample, patch_url=patch_url)
datasets.append(ds)

if not datasets:
Expand Down

0 comments on commit 204904d

Please sign in to comment.