Skip to content

Commit

Permalink
Merge pull request #169 from geo-engine/stream_bands
Browse files Browse the repository at this point in the history
stream raster bands
  • Loading branch information
ChristianBeilschmidt authored Feb 6, 2024
2 parents 4d6f2a3 + 1ae7af3 commit e632cf1
Show file tree
Hide file tree
Showing 9 changed files with 1,379 additions and 55 deletions.
1,202 changes: 1,202 additions & 0 deletions examples/multiband_raster_stream.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions geoengine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
RasterSymbology, VectorSymbology, VectorDataType, VectorResultDescriptor, VectorColumnInfo, \
FeatureDataType, RasterBandDescriptor, DEFAULT_ISO_TIME_FORMAT, RasterColorizer, SingleBandRasterColorizer \

from .util import clamp_datetime_ms_ns
from .workflow import WorkflowId, Workflow, workflow_by_id, register_workflow, get_quota, update_quota
from .raster import RasterTile2D

Expand Down
28 changes: 22 additions & 6 deletions geoengine/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import xarray as xr
import geoengine_openapi_client
import geoengine.types as gety
from geoengine.util import clamp_datetime_ms_ns


class RasterTile2D:
Expand All @@ -17,6 +18,7 @@ class RasterTile2D:
geo_transform: gety.GeoTransform
crs: str
time: gety.TimeInterval
band: int

# pylint: disable=too-many-arguments
def __init__(
Expand All @@ -25,14 +27,16 @@ def __init__(
data: pa.Array,
geo_transform: gety.GeoTransform,
crs: str,
time: gety.TimeInterval
time: gety.TimeInterval,
band: int,
):
'''Create a RasterTile2D object'''
self.size_y, self.size_x = shape
self.data = data
self.geo_transform = geo_transform
self.crs = crs
self.time = time
self.band = band

@property
def shape(self) -> Tuple[int, int]:
Expand Down Expand Up @@ -142,22 +146,31 @@ def coords_y(self, pixel_center=False) -> np.ndarray:
def to_xarray(self, clip_with_bounds: Optional[gety.SpatialBounds] = None) -> xr.DataArray:
'''
Return the raster tile as an xarray.DataArray.
Xarray does not support masked arrays.
Masked pixels are converted to NaNs and the nodata value is set to NaN as well.
Note:
- Xarray does not support masked arrays.
- Masked pixels are converted to NaNs and the nodata value is set to NaN as well.
- Xarray uses numpy's datetime64[ns] which only covers the years from 1678 to 2262.
- Date times that are outside of the defined range are clipped to the limits of the range.
'''

# clamp the dates to the min and max range
clamped_date = clamp_datetime_ms_ns(self.time_start_ms)

array = xr.DataArray(
self.to_numpy_masked_array(),
dims=["y", "x"],
coords={
'x': self.coords_x(pixel_center=True),
'y': self.coords_y(pixel_center=True),
'time': self.time_start_ms, # TODO: incorporate time end?
'time': clamped_date, # TODO: incorporate time end?
'band': self.band,
}
)
array.rio.write_crs(self.crs, inplace=True)

if clip_with_bounds is not None:
array = array.rio.clip_box(*clip_with_bounds.as_bbox_tuple())
array = array.rio.clip_box(*clip_with_bounds.as_bbox_tuple(), auto_expand=True)
array = cast(xr.DataArray, array)

return array
Expand Down Expand Up @@ -189,10 +202,13 @@ def from_ge_record_batch(record_batch: pa.RecordBatch) -> RasterTile2D:

time = gety.TimeInterval.from_response(json.loads(metadata[b'time']))

band = int(metadata[b'band'])

return RasterTile2D(
(y_size, x_size),
arrow_array,
geo_transform,
spatial_reference,
time
time,
band,
)
4 changes: 2 additions & 2 deletions geoengine/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,8 @@ def from_response(response: Any) -> TimeInterval:
end = cast(int, response['end']) if 'end' in response and response['end'] is not None else None

return TimeInterval(
datetime.fromtimestamp(start / 1000),
datetime.fromtimestamp(end / 1000) if end is not None else None,
np.datetime64(start, 'ms'),
np.datetime64(end, 'ms') if end is not None else None,
)

start_str = cast(str, response['start'])
Expand Down
24 changes: 24 additions & 0 deletions geoengine/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
'''
Module for utility functions
'''


import numpy as np


def clamp_datetime_ms_ns(value: np.datetime64) -> np.datetime64:
'''Clamp a datetime64[ms] to the range of datetime64[ns] used by xarray'''

min_date = np.datetime64('1678-09-21 00:12:43.145224192', 'ns')
max_date = np.datetime64('2262-04-11 23:47:16.854775807', 'ns')

min_date_ms = min_date.astype('datetime64[ms]')
max_date_ms = max_date.astype('datetime64[ms]')

if value < min_date_ms:
return min_date

if value > max_date_ms:
return max_date

return value.astype('datetime64[ns]')
131 changes: 89 additions & 42 deletions geoengine/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from __future__ import annotations

import asyncio
from collections import defaultdict
import json
from io import BytesIO
from logging import debug
Expand Down Expand Up @@ -34,7 +35,7 @@
from geoengine.auth import get_session
from geoengine.colorizer import Colorizer
from geoengine.error import GeoEngineException, InputException, MethodNotCalledOnPlotException, \
MethodNotCalledOnRasterException, MethodNotCalledOnVectorException, TypeException
MethodNotCalledOnRasterException, MethodNotCalledOnVectorException
from geoengine import backports
from geoengine.types import ProvenanceEntry, QueryRectangle, ResultDescriptor
from geoengine.tasks import Task, TaskId
Expand Down Expand Up @@ -81,6 +82,61 @@ def __repr__(self) -> str:
return str(self)


class RasterStreamProcessing:
'''
Helper class to process raster stream data
'''

@classmethod
def read_arrow_ipc(cls, arrow_ipc: bytes) -> pa.RecordBatch:
'''Read an Arrow IPC file from a byte array'''

reader = pa.ipc.open_file(arrow_ipc)
# We know from the backend that there is only one record batch
record_batch = reader.get_record_batch(0)
return record_batch

@classmethod
def process_bytes(cls, tile_bytes: Optional[bytes]) -> Optional[RasterTile2D]:
'''Process a tile from a byte array'''

if tile_bytes is None:
return None

# process the received data
record_batch = RasterStreamProcessing.read_arrow_ipc(tile_bytes)
tile = RasterTile2D.from_ge_record_batch(record_batch)

return tile

@classmethod
def merge_tiles(cls, tiles: List[xr.DataArray]) -> Optional[xr.DataArray]:
'''Merge a list of tiles into a single xarray'''

if len(tiles) == 0:
return None

# group the tiles by band
tiles_by_band: Dict[int, List[xr.DataArray]] = defaultdict(list)
for tile in tiles:
band = tile.band.item() # assuming 'band' is a coordinate with a single value
tiles_by_band[band].append(tile)

# build one spatial tile per band
combined_by_band = []
for band_tiles in tiles_by_band.values():
combined = xr.combine_by_coords(band_tiles)
# `combine_by_coords` always returns a `DataArray` for single variable input arrays.
# This assertion verifies this for mypy
assert isinstance(combined, xr.DataArray)
combined_by_band.append(combined)

# build one array with all bands and geo coordinates
combined_tile = xr.concat(combined_by_band, dim='band')

return combined_tile


class Workflow:
'''
Holds a workflow id and allows querying data
Expand Down Expand Up @@ -465,26 +521,18 @@ def save_as_dataset(
return Task(TaskId.from_response(response))

async def raster_stream(
self,
query_rectangle: QueryRectangle,
open_timeout: int = 60) -> AsyncIterator[RasterTile2D]:
self,
query_rectangle: QueryRectangle,
open_timeout: int = 60,
bands: Optional[List[int]] = None # TODO: move into query rectangle?
) -> AsyncIterator[RasterTile2D]:
'''Stream the workflow result as series of RasterTile2D (transformable to numpy and xarray)'''

def read_arrow_ipc(arrow_ipc: bytes) -> pa.RecordBatch:
reader = pa.ipc.open_file(arrow_ipc)
# We know from the backend that there is only one record batch
record_batch = reader.get_record_batch(0)
return record_batch
if bands is None:
bands = [0]

def process_bytes(tile_bytes: Optional[bytes]) -> Optional[RasterTile2D]:
if tile_bytes is None:
return None

# process the received data
record_batch = read_arrow_ipc(tile_bytes)
tile = RasterTile2D.from_ge_record_batch(record_batch)

return tile
if len(bands) == 0:
raise InputException('At least one band must be specified')

# Currently, it only works for raster results
if not self.__result_descriptor.is_raster_result():
Expand All @@ -500,6 +548,7 @@ def process_bytes(tile_bytes: Optional[bytes]) -> Optional[RasterTile2D]:
'spatialBounds': query_rectangle.bbox_str,
'timeInterval': query_rectangle.time_str,
'spatialResolution': str(query_rectangle.spatial_resolution),
'attributes': ','.join(map(str, bands))
},
).prepare().url

Expand Down Expand Up @@ -539,35 +588,44 @@ async def read_new_bytes() -> Optional[bytes]:
(tile_bytes, tile) = await asyncio.gather(
read_new_bytes(),
# asyncio.to_thread(process_bytes, tile_bytes), # TODO: use this when min Python version is 3.9
backports.to_thread(process_bytes, tile_bytes),
backports.to_thread(RasterStreamProcessing.process_bytes, tile_bytes),
)

if tile is not None:
yield tile

# process the last tile
tile = process_bytes(tile_bytes)
tile = RasterStreamProcessing.process_bytes(tile_bytes)

if tile is not None:
yield tile

async def raster_stream_into_xarray(
self,
query_rectangle: QueryRectangle,
clip_to_query_rectangle: bool = False,
open_timeout: int = 60) -> xr.DataArray:
self,
query_rectangle: QueryRectangle,
clip_to_query_rectangle: bool = False,
open_timeout: int = 60,
bands: Optional[List[int]] = None # TODO: move into query rectangle?
) -> xr.DataArray:
'''
Stream the workflow result into memory and output a single xarray.
NOTE: You can run out of memory if the query rectangle is too large.
'''

if bands is None:
bands = [0]

if len(bands) == 0:
raise InputException('At least one band must be specified')

tile_stream = self.raster_stream(
query_rectangle,
open_timeout=open_timeout
open_timeout=open_timeout,
bands=bands
)

timesteps: List[xr.DataArray] = []
timestep_xarrays: List[xr.DataArray] = []

spatial_clip_bounds = query_rectangle.spatial_bounds if clip_to_query_rectangle else None

Expand Down Expand Up @@ -595,39 +653,28 @@ async def read_tiles(
# this seems to be the last time step, so just return tiles
return tiles, None

def merge_tiles(tiles: List[xr.DataArray]) -> Optional[xr.DataArray]:
if len(tiles) == 0:
return None

combined_tiles = xr.combine_by_coords(tiles)

if isinstance(combined_tiles, xr.Dataset):
raise TypeException('Internal error: Merging data arrays should result in a data array.')

return combined_tiles

(tiles, remainder_tile) = await read_tiles(None)

while len(tiles):
((new_tiles, new_remainder_tile), new_timestep) = await asyncio.gather(
((new_tiles, new_remainder_tile), new_timestep_xarray) = await asyncio.gather(
read_tiles(remainder_tile),
backports.to_thread(merge_tiles, tiles)
backports.to_thread(RasterStreamProcessing.merge_tiles, tiles)
# asyncio.to_thread(merge_tiles, tiles), # TODO: use this when min Python version is 3.9
)

tiles = new_tiles
remainder_tile = new_remainder_tile

if new_timestep is not None:
timesteps.append(new_timestep)
if new_timestep_xarray is not None:
timestep_xarrays.append(new_timestep_xarray)

output: xr.DataArray = cast(
xr.DataArray,
# await asyncio.to_thread( # TODO: use this when min Python version is 3.9
await backports.to_thread(
xr.concat,
# TODO: This is a typings error, since the method accepts also a `xr.DataArray` and returns one
cast(List[xr.Dataset], timesteps),
cast(List[xr.Dataset], timestep_xarrays),
dim='time'
)
)
Expand Down
2 changes: 2 additions & 0 deletions tests/test_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def setUp(self) -> None:
time=ge.TimeInterval(
start=time
),
band=0,
)

def test_shape(self) -> None:
Expand Down Expand Up @@ -119,6 +120,7 @@ def test_from_ge_record_batch(self) -> None:
"time": json.dumps({
"start": time
}),
"band": "0",
}

batch = pa.RecordBatch.from_arrays([array], names=['data'], metadata=metadata)
Expand Down
32 changes: 32 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
'''Test for utility functions'''

import unittest
import numpy as np
import geoengine as ge


class TypesTests(unittest.TestCase):
"""Types test runner."""

def test_clamp(self):
self.assertEqual(
ge.clamp_datetime_ms_ns(np.datetime64('1500-09-21', 'ms')),
np.datetime64('1678-09-21 00:12:43.145224192', 'ns')
)
self.assertEqual(
ge.clamp_datetime_ms_ns(np.datetime64('-11500-09-21', 'ms')),
np.datetime64('1678-09-21 00:12:43.145224192', 'ns')
)
self.assertEqual(
ge.clamp_datetime_ms_ns(np.datetime64('3000-09-21', 'ms')),
np.datetime64('2262-04-11 23:47:16.854775807', 'ns')
)

self.assertEqual(
ge.clamp_datetime_ms_ns(np.datetime64('2000-01-02 11:22:33.44', 'ms')),
np.datetime64('2000-01-02 11:22:33.44', 'ns')
)


if __name__ == '__main__':
unittest.main()
Loading

0 comments on commit e632cf1

Please sign in to comment.