From 1e4d737e3898779af24cd159eec493c27dac10f9 Mon Sep 17 00:00:00 2001 From: Scott Henderson Date: Fri, 5 Mar 2021 17:40:20 -0800 Subject: [PATCH 1/3] progress to item stacking --- intake_stac/__init__.py | 5 ++ intake_stac/catalog.py | 176 ++++++++++++++++++++++-------------- intake_stac/drivers.py | 193 ++++++++++++++++++++++++++++++++++++++++ setup.py | 1 + 4 files changed, 309 insertions(+), 66 deletions(-) create mode 100644 intake_stac/drivers.py diff --git a/intake_stac/__init__.py b/intake_stac/__init__.py index 139f1d1..6c7f71a 100644 --- a/intake_stac/__init__.py +++ b/intake_stac/__init__.py @@ -2,6 +2,11 @@ from pkg_resources import DistributionNotFound, get_distribution from .catalog import StacCatalog, StacCollection, StacItem, StacItemCollection # noqa: F401 +from .drivers import RioxarraySource # noqa: F401 + +# NOTE: "The drivers ['rioxarray'] do not specify entry_points" but in setup.py... +# intake.register_driver('rioxarray', RioxarraySource) +# register_container('xarray', RioxarraySource) (override intake_xarray?) try: __version__ = get_distribution(__name__).version diff --git a/intake_stac/catalog.py b/intake_stac/catalog.py index 8ed746f..4b05050 100644 --- a/intake_stac/catalog.py +++ b/intake_stac/catalog.py @@ -11,7 +11,7 @@ # STAC catalog asset 'type' determines intake driver: # https://github.com/radiantearth/stac-spec/blob/master/item-spec/item-spec.md#media-types default_type = 'application/rasterio' -default_driver = 'rasterio' +default_driver = 'rioxarray' drivers = { 'application/netcdf': 'netcdf', @@ -20,13 +20,13 @@ 'application/x-parquet': 'parquet', 'application/x-hdf': 'netcdf', 'application/x-hdf5': 'netcdf', - 'application/rasterio': 'rasterio', - 'image/vnd.stac.geotiff': 'rasterio', - 'image/vnd.stac.geotiff; cloud-optimized=true': 'rasterio', - 'image/x.geotiff': 'rasterio', - 'image/tiff; application=geotiff': 'rasterio', - 'image/tiff; application=geotiff; profile=cloud-optimized': 'rasterio', # noqa: E501 - 'image/jp2': 'rasterio', + 'application/rasterio': 'rioxarray', + 'image/vnd.stac.geotiff': 'rioxarray', + 'image/vnd.stac.geotiff; cloud-optimized=true': 'rioxarray', + 'image/x.geotiff': 'rioxarray', + 'image/tiff; application=geotiff': 'rioxarray', + 'image/tiff; application=geotiff; profile=cloud-optimized': 'rioxarray', # noqa: E501 + 'image/jp2': 'rioxarray', 'image/png': 'xarray_image', 'image/jpg': 'xarray_image', 'image/jpeg': 'xarray_image', @@ -101,6 +101,60 @@ def serialize(self): """ return self.yaml() + def stack_items( + self, items, bands, path_as_pattern=None, concat_dim='band', override_coords=None + ): + """ + Experimental. create an xarray dataarray from a bunch of items + + currently not aware of CRS + + Get time coordinate from: + item._stac_obj.datetime + item.metadata['datetime'] + + probably not very efficient / doesn't scale b/c opens files to read metadata in serial + """ + # 1. iterate over items, call stack_bands() + # 2. merge datasets + # common name mapping from first item + common2band = self[items[0]]._get_band_name_mapping(bands) + hrefs = [] + for item in items: + # NOTE: how to speed up this iteration? + # https://github.com/intake/intake-stac/issues/66 + # print(self[item].metadata['proj:epsg']) + # source = self[item].stack_bands(bands=bands) + # call to_dask() within function and xarray merge? 
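+            # The loop below resolves each requested band either directly by
+            # asset key or via the eo:bands common_name mapping, and collects
+            # the matching hrefs into a single stacked catalog entry.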
+ assets = self[item]._stac_obj.assets + for band in bands: + # same as stack_bands + if band in assets: + asset = assets.get(band) + elif band in common2band: + asset = assets.get(common2band[band]) + else: + raise ValueError( + f'Band "{band}" not found in asset ids\n({common2band.values()})\n \ + or common_names\n({common2band.keys()})' + ) + print(asset.href) + hrefs.append(asset.href) + + configDict = {} + configDict['name'] = 'stack' + configDict['description'] = 'stack of assets from multiple items' + configDict['args'] = dict( + chunks={}, + concat_dim=concat_dim, + path_as_pattern=path_as_pattern, + urlpath=hrefs, + override_coords=override_coords, + ) + configDict['metadata'] = {'items': items, 'bands': bands} + + return CombinedAssets(configDict) + class StacCatalog(AbstractStacCatalog): """ @@ -108,6 +162,7 @@ class StacCatalog(AbstractStacCatalog): https://pystac.readthedocs.io/en/latest/api.html?#catalog-spec """ + # NOTE: name must match driver in setup.py entrypoints name = 'stac_catalog' _stac_cls = pystac.Catalog @@ -172,7 +227,7 @@ class StacItemCollection(AbstractStacCatalog): https://pystac.readthedocs.io/en/latest/api.html?#single-file-stac-extension """ - name = 'stac_itemcollection' + name = 'stac_item_collection' _stac_cls = pystac.Catalog def _load(self): @@ -241,33 +296,35 @@ def _get_metadata(self, **kwargs): metadata.update(kwargs) return metadata - def _get_band_info(self): + def _get_band_name_mapping(self, bands): """ - Return list of band info dictionaries (name, common_name, etc.)... + Return dictionary mapping common name to asset name + eo:bands extension has [{'name': 'B01', 'common_name': 'coastal'] + return {'coastal':'B01'} """ - band_info = [] - try: - # NOTE: ensure we test these scenarios - # FileNotFoundError: [Errno 2] No such file or directory: '/catalog.json' + common2band = {} + # 1. try to get directly from item metadata + if 'eo' in self._stac_obj.stac_extensions: + eo = self._stac_obj.ext['eo'] + for band in eo.bands: + common2band[band.common_name] = band.name + + # 2. go a level up to collection metadata + if common2band == {}: collection = self._stac_obj.get_collection() - if 'item-assets' in collection.stac_extensions: - for val in collection.ext['item_assets']: - if 'eo:bands' in val: - band_info.append(val.get('eo:bands')[0]) - else: - band_info = collection.summaries['eo:bands'] + # Can simplify after item-assets extension implemented in Pystac + # https://github.com/stac-utils/pystac/issues/132 + for asset, meta in collection.extra_fields['item_assets'].items(): + eo = meta.get('eo:bands') + if eo: + for entry in eo: + common_name = entry.get('common_name') + if common_name: + common2band[common_name] = asset - except Exception: - for band in self._stac_obj.ext['eo'].get_bands(): - band_info.append(band.to_dict()) - finally: - if not band_info: - raise ValueError( - 'Unable to parse "eo:bands" information from STAC Collection or Item Assets' - ) - return band_info + return common2band - def stack_bands(self, bands, path_as_pattern=None, concat_dim='band'): + def stack_bands(self, bands, path_as_pattern=None, concat_dim='band', override_coords=None): """ Stack the listed bands over the ``band`` dimension. @@ -284,7 +341,7 @@ def stack_bands(self, bands, path_as_pattern=None, concat_dim='band'): Parameters ---------- bands : list of strings representing the different bands - (e.g. ['B4', B5'], ['red', 'nir']). + (assset id or eo:bands "common_name" e.g. 
['B4', B5'], ['red', 'nir']) Returns ------- @@ -298,50 +355,37 @@ def stack_bands(self, bands, path_as_pattern=None, concat_dim='band'): stack = item.stack_bands(['B4','B5'], path_as_pattern='{band}.TIF') da = stack(chunks=dict(band=1, x=2048, y=2048)).to_dask() """ - if 'eo' not in self._stac_obj.stac_extensions: - raise ValueError('STAC Item must implement "eo" extension to use this method') - - band_info = self._get_band_info() configDict = {} metadatas = {} - titles = [] + item_metadata = self._stac_obj.properties hrefs = [] - types = [] + common2band = self._get_band_name_mapping(bands) assets = self._stac_obj.assets for band in bands: - # band can be band id, name or common_name if band in assets: - info = next((b for b in band_info if b.get('id', b.get('name')) == band), None,) + asset = assets.get(band) + elif band in common2band: + asset = assets.get(common2band[band]) else: - info = next((b for b in band_info if b.get('common_name') == band), None) - if info is not None: - band = info.get('id', info.get('name')) - - if band not in assets or info is None: - valid_band_names = [] - for b in band_info: - valid_band_names.append(b.get('id', b.get('name'))) - valid_band_names.append(b.get('common_name')) raise ValueError( - f'{band} not found in list of eo:bands in collection.' - f'Valid values: {sorted(list(set(valid_band_names)))}' + f'Band "{band}" not found in asset ids\n({common2band.values()})\n \ + or common_names\n({common2band.keys()})' ) - asset = assets.get(band) - metadatas[band] = asset.to_dict() - titles.append(band) - types.append(asset.media_type) - hrefs.append(asset.href) - unique_types = set(types) - if len(unique_types) != 1: - raise ValueError( - f'Stacking failed: bands must have type, multiple found: {unique_types}' - ) + # map *HREF* to metadata to do fancy things when opening it + asset_metadata = asset.properties + metadatas[asset.href] = {**item_metadata, **asset_metadata} + hrefs.append(asset.href) configDict['name'] = '_'.join(bands) - configDict['description'] = ', '.join(titles) + configDict['description'] = ', '.join(bands) + # NOTE: these are args for driver __init__ method configDict['args'] = dict( - chunks={}, concat_dim=concat_dim, path_as_pattern=path_as_pattern, urlpath=hrefs + chunks={}, + concat_dim=concat_dim, + path_as_pattern=path_as_pattern, + urlpath=hrefs, + override_coords=override_coords, ) configDict['metadata'] = metadatas @@ -443,7 +487,7 @@ def _get_driver(self, asset): ) entry_type = asset.media_type - # if mimetype not registered try rasterio driver + # if mimetype not registered try rioxarray driver driver = drivers.get(entry_type, default_driver) return driver @@ -453,7 +497,7 @@ def _get_args(self, asset, driver): Optional keyword arguments to pass to intake driver """ args = {'urlpath': asset.href} - if driver in ['netcdf', 'rasterio', 'xarray_image']: + if driver in ['netcdf', 'rasterio', 'rioxarray', 'xarray_image']: # NOTE: force using dask? args.update(chunks={}) @@ -472,7 +516,7 @@ def __init__(self, configDict): super().__init__( name=configDict['name'], description=configDict['description'], - driver='rasterio', # stack_bands only relevant to rasterio driver? 
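+            # stacked assets are always opened through the new rioxarray driver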
+            driver='rioxarray',
             direct_access=True,
             args=configDict['args'],
             metadata=configDict['metadata'],
diff --git a/intake_stac/drivers.py b/intake_stac/drivers.py
new file mode 100644
index 0000000..93da728
--- /dev/null
+++ b/intake_stac/drivers.py
@@ -0,0 +1,193 @@
+from intake.source.base import DataSource, PatternMixin, Schema
+from intake.source.utils import reverse_formats
+from pkg_resources import get_distribution
+
+__version__ = get_distribution('intake_stac').version
+
+
+class RioxarraySource(DataSource, PatternMixin):
+    """Open an xarray dataset via rioxarray.
+    This creates an xarray.DataArray; see https://github.com/corteva/rioxarray
+
+    Parameters
+    ----------
+    urlpath: str or iterable, location of data
+        May be a local path, or a remote path if it includes a protocol specifier
+        such as ``'s3://'``. May include glob wildcards or format pattern strings.
+        Must be a format supported by rasterio (normally GeoTIFF).
+        Some examples:
+            - ``{{ CATALOG_DIR }}data/RGB.tif``
+            - ``s3://data/*.tif``
+            - ``s3://data/landsat8_band{band}.tif``
+            - ``s3://data/{location}/landsat8_band{band}.tif``
+            - ``{{ CATALOG_DIR }}data/landsat8_{start_date:%Y%m%d}_band{band}.tif``
+    chunks: None or int or dict, optional
+        Used to load the dataset into dask arrays. ``chunks={}`` loads the
+        dataset with dask using a single chunk for all arrays. The default
+        ``None`` loads numpy arrays.
+    path_as_pattern: bool or str, optional
+        Whether to treat the path as a pattern (i.e. ``data_{field}.tif``)
+        and create new coordinates in the output corresponding to pattern
+        fields. If str, it is treated as the pattern to match on. Default is True.
+    """
+
+    name = 'rioxarray'
+    version = __version__
+    container = 'xarray'
+    partition_access = True
+
+    def __init__(
+        self,
+        urlpath,
+        chunks=None,
+        concat_dim='concat_dim',
+        override_coords=None,
+        xarray_kwargs=None,
+        metadata=None,
+        path_as_pattern=True,
+        storage_options=None,
+        **kwargs,
+    ):
+        self.path_as_pattern = path_as_pattern
+        self.urlpath = urlpath
+        self.chunks = chunks
+        self.dim = concat_dim
+        # self.storage_options = storage_options or {}  # only relevant to fsspec?
+        self.override_coords = override_coords
+        self._kwargs = xarray_kwargs or {}
+        self._ds = None
+        # if isinstance(self.urlpath, list):
+        #     self._can_be_local = fsspec.utils.can_be_local(self.urlpath[0])
+        # else:
+        #     self._can_be_local = fsspec.utils.can_be_local(self.urlpath)
+
+        # Why is this necessary?
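+        # (Answer: the base DataSource.__init__ records ``metadata`` and sets up
+        # the shared source state that _load_metadata()/_get_schema() rely on.)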
+ super(RioxarraySource, self).__init__(metadata=metadata) + + def _open_files(self, files): + """ + basically open_mfrasterio() + """ + import rioxarray + import xarray as xr + + # not metadata-aware, so this assigns band=1 regardless of true band# + das = [rioxarray.open_rasterio(f, chunks=self.chunks, **self._kwargs) for f in files] + out = xr.concat(das, dim=self.dim) + + # by default map band names to coordinates instead of band=1,1,1 + coords = {} + # NOTE very robust, and requires potentially really long names + # coords = {self.dim: self.name.split('_')} + # band = 1,2,3 instead of band=1,1,1 + # coords = {self.dim: range(1, len(out.coords[self.dim])+1)} + # NOTE that we have all the STAC metadata at our disposal here: + # coords = dict(time = ('time', [self.metadata[f]['datetime'] for f in files])) + + if self.pattern: + pattern_matches = reverse_formats(self.pattern, files) + coords = {self.dim: pattern_matches[self.dim]} + + if self.override_coords: + coords = {self.dim: self.override_coords} + + return out.assign_coords(**coords).chunk(self.chunks) + + def _open_dataset(self): + import rioxarray + + # if self._can_be_local: + # files = fsspec.open_local(self.urlpath, **self.storage_options) + # else: + # pass URLs to delegate remote opening to rasterio library + # files = self.urlpath + # files = fsspec.open(self.urlpath, **self.storage_options).open() + files = self.urlpath + if isinstance(files, list): + self._ds = self._open_files(files) + else: + self._ds = rioxarray.open_rasterio(files, chunks=self.chunks, **self._kwargs) + + # NOTE: don't know what's going on here + # seems overly complicated... + # https://github.com/intake/intake-xarray/issues/20#issuecomment-432782846 + def _get_schema(self): + """Make schema object, which embeds xarray object and some details""" + # from .xarray_container import serialize_zarr_ds + import msgpack + import xarray as xr + + self.urlpath, *_ = self._get_cache(self.urlpath) + + if self._ds is None: + self._open_dataset() + + ds2 = xr.Dataset({'raster': self._ds}) + metadata = { + 'dims': dict(ds2.dims), + 'data_vars': {k: list(ds2[k].coords) for k in ds2.data_vars.keys()}, + 'coords': tuple(ds2.coords.keys()), + 'array': 'raster', + } + # if getattr(self, 'on_server', False): + # metadata['internal'] = serialize_zarr_ds(ds2) + for k, v in self._ds.attrs.items(): + try: + msgpack.packb(v) + metadata[k] = v + except TypeError: + pass + + if hasattr(self._ds.data, 'npartitions'): + npart = self._ds.data.npartitions + else: + npart = None + + self._schema = Schema( + datashape=None, + dtype=str(self._ds.dtype), + shape=self._ds.shape, + npartitions=npart, + extra_metadata=metadata, + ) + + return self._schema + + def read(self): + """Return a version of the xarray with all the data in memory""" + self._load_metadata() + return self._ds.load() + + def read_chunked(self): + """Return xarray object (which will have chunks)""" + self._load_metadata() + return self._ds + + def read_partition(self, i): + """Fetch one chunk of data at tuple index i + """ + import numpy as np + + self._load_metadata() + if not isinstance(i, (tuple, list)): + raise TypeError('For Xarray sources, must specify partition as ' 'tuple') + if isinstance(i, list): + i = tuple(i) + if hasattr(self._ds, 'variables') or i[0] in self._ds.coords: + arr = self._ds[i[0]].data + i = i[1:] + else: + arr = self._ds.data + if isinstance(arr, np.ndarray): + return arr + # dask array + return arr.blocks[i].compute() + + def to_dask(self): + """Return xarray object where variables are 
dask arrays""" + return self.read_chunked() + + def close(self): + """Delete open file from memory""" + self._ds = None + self._schema = None diff --git a/setup.py b/setup.py index 5963dc9..f194ce9 100644 --- a/setup.py +++ b/setup.py @@ -21,6 +21,7 @@ TESTS_REQUIRE = ['pytest >= 2.7.1'] ENTRY_POINTS = { 'intake.drivers': [ + 'rioxarray = intake_stac.drivers:RioxarraySource', 'stac_catalog = intake_stac.catalog:StacCatalog', 'stac_collection = intake_stac.catalog:StacCollection', 'stac_item_collection = intake_stac.catalog:StacItemCollection', From 1059c6b489e04bf523f39ccce44bb4ceb1d996fe Mon Sep 17 00:00:00 2001 From: Scott Henderson Date: Tue, 9 Mar 2021 16:51:43 -0800 Subject: [PATCH 2/3] preliminary stack_items --- intake_stac/catalog.py | 142 +++++++++++++++++------------- intake_stac/drivers.py | 62 +++++++++---- intake_stac/tests/test_catalog.py | 85 +++++++++++++----- 3 files changed, 185 insertions(+), 104 deletions(-) diff --git a/intake_stac/catalog.py b/intake_stac/catalog.py index 4b05050..7252201 100644 --- a/intake_stac/catalog.py +++ b/intake_stac/catalog.py @@ -102,47 +102,55 @@ def serialize(self): return self.yaml() def stack_items( - self, items, bands, path_as_pattern=None, concat_dim='band', override_coords=None + self, items, assets, path_as_pattern=None, concat_dim='band', override_coords=None ): """ - Experimental. create an xarray dataarray from a bunch of items + Experimental. Create an xarray.DataArray from a bunch of STAC Items - currently not aware of CRS + Parameters + ---------- + items: list of STAC item id strings + assets : list of strings representing the different assets + (assset key or eo:bands "common_name" e.g. ['B4', B5'], ['red', 'nir']) + path_as_pattern : pattern string to extract coordinates from asset href + concat_dim : name of concatenation dimension for xarray.DataArray + override_coords : list of custom coordinate names - Get time coordinate from: - item._stac_obj.datetime - item.metadata['datetime'] + Returns + ------- + CombinedAssets instance with mapping of asset names to xarray coordinates - probably not very efficient / doesn't scale b/c opens files to read metadata in serial + Examples + ------- + source = cat.stack_items(['S2A_36MYB_20200814_0_L2A','S2A_36MYB_20200811_0_L2A'], + ['red', 'nir']) + da = stack().to_dask() """ - # 1. iterate over items, call stack_bands() - # 2. merge datasets - # common name mapping from first item - common2band = self[items[0]]._get_band_name_mapping(bands) + common2band = self[items[0]]._get_band_name_mapping() + metadatas = {'items': {}} hrefs = [] for item in items: - # NOTE: how to speed up this iteration? - # https://github.com/intake/intake-stac/issues/66 - # print(self[item].metadata['proj:epsg']) - # source = self[item].stack_bands(bands=bands) - # call to_dask() within function and xarray merge? 
- assets = self[item]._stac_obj.assets - for band in bands: - # same as stack_bands - if band in assets: - asset = assets.get(band) - elif band in common2band: - asset = assets.get(common2band[band]) + metadatas['items'][item] = {'STAC': self[item].metadata, 'assets': {}} + stac_assets = self[item]._stac_obj.assets + for key in assets: + + if key in stac_assets: + asset = stac_assets.get(key) + elif key in common2band: + asset = stac_assets.get(common2band[key]) else: raise ValueError( - f'Band "{band}" not found in asset ids\n({common2band.values()})\n \ - or common_names\n({common2band.keys()})' + f'Asset "{key}" not found in asset keys {list(common2band.values())}' + f' or eo:bands common_names {list(common2band.keys())}' ) - print(asset.href) + + asset_metadata = asset.properties + asset_metadata['key'] = key + metadatas['items'][item]['assets'][asset.href] = asset_metadata hrefs.append(asset.href) configDict = {} - configDict['name'] = 'stack' + configDict['name'] = 'item_stack' configDict['description'] = 'stack of assets from multiple items' configDict['args'] = dict( chunks={}, @@ -151,9 +159,11 @@ def stack_items( urlpath=hrefs, override_coords=override_coords, ) - configDict['metadata'] = {'items': items, 'bands': bands} + configDict['metadata'] = metadatas + + stack = CombinedAssets(configDict) - return CombinedAssets(configDict) + return stack class StacCatalog(AbstractStacCatalog): @@ -296,12 +306,13 @@ def _get_metadata(self, **kwargs): metadata.update(kwargs) return metadata - def _get_band_name_mapping(self, bands): + def _get_band_name_mapping(self): """ Return dictionary mapping common name to asset name eo:bands extension has [{'name': 'B01', 'common_name': 'coastal'] return {'coastal':'B01'} """ + # NOTE: maybe return entire dataframe w/ central_wavelength etc? common2band = {} # 1. try to get directly from item metadata if 'eo' in self._stac_obj.stac_extensions: @@ -324,61 +335,63 @@ def _get_band_name_mapping(self, bands): return common2band - def stack_bands(self, bands, path_as_pattern=None, concat_dim='band', override_coords=None): + def stack_assets(self, assets, path_as_pattern=None, concat_dim='band', override_coords=None): """ - Stack the listed bands over the ``band`` dimension. + Stack the listed assets over the ``band`` dimension. - This method only works for STAC Items using the 'eo' Extension - https://github.com/radiantearth/stac-spec/tree/master/extensions/eo - - NOTE: This method is not aware of geotransform information. It *assumes* - bands for a given STAC Item have the same coordinate reference system (CRS). + WARNING: This method is not aware of geotransform information. It *assumes* + assets for a given STAC Item have the same coordinate reference system (CRS). This is usually the case for a given multi-band satellite acquisition. - Coordinate alignment is performed automatically upon calling the - `to_dask()` method to load into an Xarray DataArray if bands have diffent - ground sample distance (gsd) or array shapes. + + See the following documentation for dealing with different CRS or GSD: + http://xarray.pydata.org/en/stable/interpolation.html + https://corteva.github.io/rioxarray/stable/examples/reproject_match.html Parameters ---------- - bands : list of strings representing the different bands - (assset id or eo:bands "common_name" e.g. ['B4', B5'], ['red', 'nir']) + assets : list of strings representing the different assets + (assset key or eo:bands "common_name" e.g. 
['B4', B5'], ['red', 'nir']) + path_as_pattern : pattern string to extract coordinates from asset href + concat_dim : name of concatenation dimension for xarray.DataArray + override_coords : list of custom coordinate names Returns ------- - StacAsset with mapping of Asset names to Xarray bands + CombinedAssets instance with mapping of asset names to xarray coordinates Examples ------- - stack = item.stack_bands(['nir','red']) - da = stack(chunks=dict(band=1, x=2048, y=2048)).to_dask() + stack = item.stack_assets(['nir','red']) + da = stack(chunks=True).to_dask() - stack = item.stack_bands(['B4','B5'], path_as_pattern='{band}.TIF') + stack = item.stack_assets(['B4','B5'], path_as_pattern='{band}.TIF') da = stack(chunks=dict(band=1, x=2048, y=2048)).to_dask() """ configDict = {} - metadatas = {} - item_metadata = self._stac_obj.properties + metadatas = {'items': {self.name: {'STAC': self.metadata, 'assets': {}}}} hrefs = [] - common2band = self._get_band_name_mapping(bands) - assets = self._stac_obj.assets - for band in bands: - if band in assets: - asset = assets.get(band) - elif band in common2band: - asset = assets.get(common2band[band]) + common2band = self._get_band_name_mapping() + stac_assets = self._stac_obj.assets + for key in assets: + if key in stac_assets: + asset = stac_assets.get(key) + elif key in common2band: + asset = stac_assets.get(common2band[key]) else: raise ValueError( - f'Band "{band}" not found in asset ids\n({common2band.values()})\n \ - or common_names\n({common2band.keys()})' + f'Asset "{key}" not found in asset keys {list(common2band.values())}' + f' or eo:bands common_names {list(common2band.keys())}' ) # map *HREF* to metadata to do fancy things when opening it asset_metadata = asset.properties - metadatas[asset.href] = {**item_metadata, **asset_metadata} + asset_metadata['key'] = key + asset_metadata['item'] = self.name + metadatas['items'][self.name]['assets'][asset.href] = asset_metadata hrefs.append(asset.href) - configDict['name'] = '_'.join(bands) - configDict['description'] = ', '.join(bands) + configDict['name'] = self.name + configDict['description'] = ', '.join(assets) # NOTE: these are args for driver __init__ method configDict['args'] = dict( chunks={}, @@ -389,7 +402,10 @@ def stack_bands(self, bands, path_as_pattern=None, concat_dim='band', override_c ) configDict['metadata'] = metadatas - return CombinedAssets(configDict) + # instantiate to allow item.stack_assets(['red','nir']).to_dask() ? + stack = CombinedAssets(configDict) + + return stack class StacAsset(LocalCatalogEntry): @@ -509,9 +525,11 @@ class CombinedAssets(LocalCatalogEntry): Maps multiple STAC Item Assets to 1 Intake Catalog Entry """ + _stac_cls = None + def __init__(self, configDict): """ - configDict = intake Entry dictionary from stack_bands() method + configDict = intake Entry intialization dictionary """ super().__init__( name=configDict['name'], diff --git a/intake_stac/drivers.py b/intake_stac/drivers.py index 93da728..4217f9e 100644 --- a/intake_stac/drivers.py +++ b/intake_stac/drivers.py @@ -64,34 +64,57 @@ def __init__( # Why is this necessary? 
super(RioxarraySource, self).__init__(metadata=metadata) - def _open_files(self, files): + def _open_items(self): """ - basically open_mfrasterio() + Use STAC metadata to intelligently concatenate multiple items """ + import xarray as xr + + data_arrays = [] + for item_id in self.metadata['items'].keys(): + files = self.metadata['items'][item_id]['assets'].keys() + data_arrays.append(self._open_assets(files, item_id)) + + # by default concatenate items in time + ds = xr.concat(data_arrays, dim='item').swap_dims({'item': 'time'}) + ds.name = None + + return ds + + def _open_assets(self, files, item_id=None): + """ + use STAC metadata to intelligently concatenate multiple assets + """ + import rioxarray import xarray as xr + # Re-arranged metadata from intake-stac CombinedAssets + metadata = self.metadata['items'][item_id] + # not metadata-aware, so this assigns band=1 regardless of true band# + # Note: wrap with dask.delayed for parallel loading? das = [rioxarray.open_rasterio(f, chunks=self.chunks, **self._kwargs) for f in files] out = xr.concat(das, dim=self.dim) - # by default map band names to coordinates instead of band=1,1,1 - coords = {} - # NOTE very robust, and requires potentially really long names - # coords = {self.dim: self.name.split('_')} - # band = 1,2,3 instead of band=1,1,1 - # coords = {self.dim: range(1, len(out.coords[self.dim])+1)} - # NOTE that we have all the STAC metadata at our disposal here: - # coords = dict(time = ('time', [self.metadata[f]['datetime'] for f in files])) + # NOTE: no time zone conversion logic (assume UTC) + coords = {'item': item_id, 'time': metadata['STAC']['datetime'].replace(tzinfo=None)} + + # by default assign asset keys as coordinate values + coords[self.dim] = [metadata['assets'][f]['key'] for f in files] if self.pattern: pattern_matches = reverse_formats(self.pattern, files) - coords = {self.dim: pattern_matches[self.dim]} + coords[self.dim] = pattern_matches[self.dim] if self.override_coords: - coords = {self.dim: self.override_coords} + coords[self.dim] = self.override_coords + + # copy item property metadata as attribute ? + # self._ds.attrs['STAC'] = self.metadata + out.name = item_id - return out.assign_coords(**coords).chunk(self.chunks) + return out.assign_coords(**coords) def _open_dataset(self): import rioxarray @@ -102,14 +125,15 @@ def _open_dataset(self): # pass URLs to delegate remote opening to rasterio library # files = self.urlpath # files = fsspec.open(self.urlpath, **self.storage_options).open() - files = self.urlpath - if isinstance(files, list): - self._ds = self._open_files(files) + if isinstance(self.urlpath, list): + if self.name == 'item_stack': + self._ds = self._open_items() + else: + self._ds = self._open_assets(self.urlpath, item_id=self.name) else: - self._ds = rioxarray.open_rasterio(files, chunks=self.chunks, **self._kwargs) + self._ds = rioxarray.open_rasterio(self.urlpath, chunks=self.chunks, **self._kwargs) - # NOTE: don't know what's going on here - # seems overly complicated... + # NOTE: don't know what's going on here... 
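+    # The schema construction below mirrors intake-xarray's approach; see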
# https://github.com/intake/intake-xarray/issues/20#issuecomment-432782846 def _get_schema(self): """Make schema object, which embeds xarray object and some details""" diff --git a/intake_stac/tests/test_catalog.py b/intake_stac/tests/test_catalog.py index e601612..5c779a2 100644 --- a/intake_stac/tests/test_catalog.py +++ b/intake_stac/tests/test_catalog.py @@ -1,5 +1,5 @@ import datetime -import os.path +import os import sys import intake @@ -12,6 +12,11 @@ from intake_stac.catalog import CombinedAssets, StacAsset here = os.path.dirname(__file__) +import os + +# Set environment variables for network tests +os.environ['GDAL_DISABLE_READDIR_ON_OPEN'] = 'EMPTY_DIR' +os.environ['AWS_NO_SIGN_REQUEST'] = 'YES' # sat-stac examples # ----- @@ -129,59 +134,61 @@ def test_cat_from_item(pystac_item): assert 'B5' in cat -def test_cat_item_stacking(pystac_item): +def test_stack_assets(pystac_item): item = StacItem(pystac_item) list_of_bands = ['B1', 'B2'] - new_entry = item.stack_bands(list_of_bands) + new_entry = item.stack_assets(list_of_bands) assert isinstance(new_entry, CombinedAssets) assert new_entry._description == 'B1, B2' - assert new_entry.name == 'B1_B2' + assert new_entry.name == 'LC08_L1TP_152038_20200611_20200611_01_RT' new_da = new_entry().to_dask() assert sorted([dim for dim in new_da.dims]) == ['band', 'x', 'y'] -def test_cat_item_stacking_using_common_name(pystac_item): +def test_stack_assets_using_common_name(pystac_item): item = StacItem(pystac_item) list_of_bands = ['coastal', 'blue'] - new_entry = item.stack_bands(list_of_bands) + new_entry = item.stack_assets(list_of_bands) assert isinstance(new_entry, CombinedAssets) - assert new_entry._description == 'B1, B2' - assert new_entry.name == 'coastal_blue' + assert new_entry._description == 'coastal, blue' + assert new_entry.name == 'LC08_L1TP_152038_20200611_20200611_01_RT' new_da = new_entry().to_dask() assert sorted([dim for dim in new_da.dims]) == ['band', 'x', 'y'] -def test_cat_item_stacking_path_as_pattern(pystac_item): +def test_stack_assets_path_as_pattern(pystac_item): item = StacItem(pystac_item) list_of_bands = ['B1', 'B2'] - new_entry = item.stack_bands(list_of_bands, path_as_pattern='{}{band:2}.TIF') + new_entry = item.stack_assets(list_of_bands, path_as_pattern='{}{band:2}.TIF') assert isinstance(new_entry, CombinedAssets) new_da = new_entry().to_dask() assert (new_da.band == ['B1', 'B2']).all() -def test_cat_item_stacking_dims_of_different_type_raises_error(pystac_item): - item = StacItem(pystac_item) - list_of_bands = ['B1', 'ANG'] - with pytest.raises(ValueError, match=('ANG not found in list of eo:bands in collection')): - item.stack_bands(list_of_bands) +# reconsider this test b/c you could have types= 'image/x.geotiff', 'image/tiff' +# def test_cat_item_stacking_dims_of_different_type_raises_error(pystac_item): +# item = StacItem(pystac_item) +# list_of_bands = ['B1', 'ANG'] +# with pytest.raises(ValueError, match=('ANG not found in list of eo:bands in collection')): +# item.stack_bands(list_of_bands) -def test_cat_item_stacking_dims_with_nonexistent_band_raises_error(pystac_item,): # noqa: E501 +def test_stack_assets_dims_with_nonexistent_band_raises_error(pystac_item): # noqa: E501 item = StacItem(pystac_item) list_of_bands = ['B1', 'foo'] - with pytest.raises(ValueError, match="'B8', 'B9', 'blue', 'cirrus'"): - item.stack_bands(list_of_bands) + with pytest.raises(ValueError, match='Asset "foo" not found in asset keys'): + item.stack_assets(list_of_bands) -def 
test_cat_item_stacking_dims_of_different_size_regrids(pystac_item): +# NOTE: this only works b/c CRS and grids are aligned! +def test_stack_assets_dims_of_different_size_regrids(pystac_item): item = StacItem(pystac_item) list_of_bands = ['B1', 'B8'] B1_da = item.B1.to_dask() assert B1_da.shape == (1, 7791, 7651) B8_da = item.B8.to_dask() assert B8_da.shape == (1, 15581, 15301) - new_entry = item.stack_bands(list_of_bands) + new_entry = item.stack_assets(list_of_bands) new_da = new_entry().to_dask() assert new_da.shape == (2, 15581, 15301) assert sorted([dim for dim in new_da.dims]) == ['band', 'x', 'y'] @@ -195,7 +202,7 @@ def test_asset_describe(pystac_item): assert d['name'] == key assert d['container'] == 'xarray' - assert d['plugin'] == ['rasterio'] + assert d['plugin'] == ['rioxarray'] assert d['args']['urlpath'] == asset.urlpath assert d['description'] == asset.description # NOTE: note sure why asset.metadata has 'catalog_dir' key ? @@ -213,7 +220,7 @@ def test_asset_missing_type(pystac_item): assert d['name'] == key assert d['metadata']['type'] == 'application/rasterio' # default_type assert d['container'] == 'xarray' - assert d['plugin'] == ['rasterio'] + assert d['plugin'] == ['rioxarray'] def test_asset_unknown_type(pystac_item): @@ -226,7 +233,7 @@ def test_asset_unknown_type(pystac_item): assert d['name'] == key assert d['metadata']['type'] == 'unrecognized' assert d['container'] == 'xarray' - assert d['plugin'] == ['rasterio'] + assert d['plugin'] == ['rioxarray'] def test_cat_to_geopandas(pystac_itemcol): @@ -281,3 +288,35 @@ def test_collection_of_collection(): result = StacCollection(parent) result._load() + + +def test_stack_items(): + test_file = os.path.join(here, 'data/1.0.0beta2/earthsearch/single-file-stac.json') + cat = intake.open_stac_item_collection(test_file) + items = ['S2A_36MYB_20200814_0_L2A', 'S2A_36MYB_20200811_0_L2A'] + assets = ['B04', 'B08'] + source = cat.stack_items(items, assets) + assert isinstance(source, CombinedAssets) + assert source.name == 'item_stack' + da = source().to_dask() + assert hasattr(da, 'rio') + assert (da.band == assets).all() + coords = ['band', 'y', 'x', 'spatial_ref', 'item', 'time'] + assert set(coords) == set(list(da.coords)) + assert sorted([dim for dim in da.dims]) == ['band', 'time', 'x', 'y'] + + +def test_stack_items_using_common_name(): + test_file = os.path.join(here, 'data/1.0.0beta2/earthsearch/single-file-stac.json') + cat = intake.open_stac_item_collection(test_file) + items = ['S2A_36MYB_20200814_0_L2A', 'S2A_36MYB_20200811_0_L2A'] + assets = ['red', 'nir'] + source = cat.stack_items(items, assets) + assert isinstance(source, CombinedAssets) + assert source.name == 'item_stack' + da = source().to_dask() + assert hasattr(da, 'rio') + assert len(da.coords['time']) == 2 + assert len(da.coords['item']) == 2 + assert len(da.coords['band']) == 2 + assert (da.band == assets).all() From 132edec748f2f3834d9e6b432018f737942a8399 Mon Sep 17 00:00:00 2001 From: Scott Henderson Date: Tue, 9 Mar 2021 18:52:06 -0800 Subject: [PATCH 3/3] add rioxarray to ci environments --- .github/workflows/binderbadge.yaml | 2 +- ci/environment-3.7.yml | 1 + ci/environment-3.8.yml | 1 + ci/environment-3.9.yml | 1 + ci/environment-dev.yml | 1 + ci/environment-docs.yml | 1 + ci/environment-gui.yml | 1 + ci/environment-unpinned.yml | 1 + ci/environment-upstream.yml | 3 ++- intake_stac/drivers.py | 1 + 10 files changed, 11 insertions(+), 2 deletions(-) diff --git a/.github/workflows/binderbadge.yaml b/.github/workflows/binderbadge.yaml index 
82f43ae..a5cd51e 100644 --- a/.github/workflows/binderbadge.yaml +++ b/.github/workflows/binderbadge.yaml @@ -1,7 +1,7 @@ # create a mybinder badge issue comment for testing PRs name: AddBinderBadge on: - pull_request: + pull_request_target: types: [opened, reopened] jobs: build-image-without-pushing: diff --git a/ci/environment-3.7.yml b/ci/environment-3.7.yml index 4bbb9cb..f5999cc 100644 --- a/ci/environment-3.7.yml +++ b/ci/environment-3.7.yml @@ -10,4 +10,5 @@ dependencies: - pystac - pytest-cov - rasterio + - rioxarray - xarray diff --git a/ci/environment-3.8.yml b/ci/environment-3.8.yml index 4aab824..f827c77 100644 --- a/ci/environment-3.8.yml +++ b/ci/environment-3.8.yml @@ -10,4 +10,5 @@ dependencies: - pystac - pytest-cov - rasterio + - rioxarray - xarray diff --git a/ci/environment-3.9.yml b/ci/environment-3.9.yml index 9bdcdff..52e167b 100644 --- a/ci/environment-3.9.yml +++ b/ci/environment-3.9.yml @@ -10,4 +10,5 @@ dependencies: - pystac - pytest-cov - rasterio + - rioxarray - xarray diff --git a/ci/environment-dev.yml b/ci/environment-dev.yml index 8e42c20..78df0da 100644 --- a/ci/environment-dev.yml +++ b/ci/environment-dev.yml @@ -34,6 +34,7 @@ dependencies: - pytoml - pyyaml - rasterio + - rioxarray - recommonmark - requests - sat-search>=0.3 diff --git a/ci/environment-docs.yml b/ci/environment-docs.yml index 7e1fe91..1a86481 100644 --- a/ci/environment-docs.yml +++ b/ci/environment-docs.yml @@ -17,6 +17,7 @@ dependencies: - pystac - pytest-cov - rasterio + - rioxarray - sat-search - scikit-image - sphinx diff --git a/ci/environment-gui.yml b/ci/environment-gui.yml index e8119d3..ecd6953 100644 --- a/ci/environment-gui.yml +++ b/ci/environment-gui.yml @@ -47,6 +47,7 @@ dependencies: - pytoml - pyyaml - rasterio + - rioxarray - recommonmark - requests - sat-search>=0.3 diff --git a/ci/environment-unpinned.yml b/ci/environment-unpinned.yml index fb8680c..5ce5362 100644 --- a/ci/environment-unpinned.yml +++ b/ci/environment-unpinned.yml @@ -10,4 +10,5 @@ dependencies: - pystac - pytest-cov - rasterio + - rioxarray - xarray diff --git a/ci/environment-upstream.yml b/ci/environment-upstream.yml index b870f67..07e8178 100644 --- a/ci/environment-upstream.yml +++ b/ci/environment-upstream.yml @@ -11,7 +11,7 @@ dependencies: - pystac - pytest-cov - rasterio - - xarray + - rioxarray - pip: - git+https://github.com/intake/filesystem_spec.git - git+https://github.com/stac-utils/pystac.git @@ -20,3 +20,4 @@ dependencies: - git+https://github.com/intake/intake-xarray.git - git+https://github.com/intake/intake_geopandas.git - git+https://github.com/intake/intake-parquet.git + - git+https://github.com/pydata/xarray.git diff --git a/intake_stac/drivers.py b/intake_stac/drivers.py index 4217f9e..c39924a 100644 --- a/intake_stac/drivers.py +++ b/intake_stac/drivers.py @@ -56,6 +56,7 @@ def __init__( self.override_coords = override_coords self._kwargs = xarray_kwargs or {} self._ds = None + print(self._kwargs) # if isinstance(self.urlpath, list): # self._can_be_local = fsspec.utils.can_be_local(self.urlpath[0]) # else:
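
For reference, a minimal usage sketch of the new item stacking, mirroring the tests
added in this series (the catalog path is the test fixture from this PR; any STAC
ItemCollection with eo:bands metadata should behave the same way):

    import intake

    # Open the single-file STAC ItemCollection used by the new tests
    cat = intake.open_stac_item_collection(
        'intake_stac/tests/data/1.0.0beta2/earthsearch/single-file-stac.json'
    )

    # Stack two Sentinel-2 items over time and two assets over the 'band' dimension
    items = ['S2A_36MYB_20200814_0_L2A', 'S2A_36MYB_20200811_0_L2A']
    source = cat.stack_items(items, ['red', 'nir'])  # returns a CombinedAssets entry

    da = source().to_dask()  # dask-backed xarray.DataArray via the rioxarray driver
    print(da.dims)           # dimensions include 'band', 'time', 'y', 'x'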