Skip to content

Commit

Permalink
Merge pull request #172 from geo-engine/ml_example_to_multi_band
Browse files Browse the repository at this point in the history
Ml example to multi band
  • Loading branch information
michaelmattig authored Mar 6, 2024
2 parents e632cf1 + db4aaac commit 81d2a49
Show file tree
Hide file tree
Showing 7 changed files with 3,148 additions and 3,011 deletions.
2,946 changes: 2,946 additions & 0 deletions examples/s2_field_combination_extern_rf_train/nrw_crop_extern_s2_workflow.ipynb

Large diffs are not rendered by default.

This file was deleted.

104 changes: 103 additions & 1 deletion geoengine/raster.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
'''Raster data types'''
from __future__ import annotations
import json
from typing import Optional, Tuple, Union, cast
from typing import AsyncIterator, List, Optional, Tuple, Union, cast
import numpy as np
import pyarrow as pa
import xarray as xr
Expand Down Expand Up @@ -212,3 +212,105 @@ def from_ge_record_batch(record_batch: pa.RecordBatch) -> RasterTile2D:
time,
band,
)


class RasterTileStack2D:
'''A stack of all the bands of a raster tile as produced by the Geo Engine'''
size_y: int
size_x: int
geo_transform: gety.GeoTransform
crs: str
time: gety.TimeInterval
data: List[pa.Array]
bands: List[int]

# pylint: disable=too-many-arguments
def __init__(
self,
tile_shape: Tuple[int, int],
data: List[pa.Array],
geo_transform: gety.GeoTransform,
crs: str,
time: gety.TimeInterval,
bands: List[int],
):
'''Create a RasterTileStack2D object'''
(self.size_y, self.size_x) = tile_shape
self.data = data
self.geo_transform = geo_transform
self.crs = crs
self.time = time
self.bands = bands

def single_band(self, index: int) -> RasterTile2D:
'''Return a single band from the stack'''
return RasterTile2D(
(self.size_y, self.size_x),
self.data[index],
self.geo_transform,
self.crs,
self.time,
self.bands[index],
)

def to_numpy_masked_array_stack(self) -> np.ma.MaskedArray:
'''Return the raster stack as a 3D masked numpy array'''
arrays = [self.single_band(i).to_numpy_masked_array() for i in range(0, len(self.data))]
stack = np.stack(arrays, axis=0)
return stack

def to_xarray(self, clip_with_bounds: Optional[gety.SpatialBounds] = None) -> xr.DataArray:
'''Return the raster stack as an xarray.DataArray'''
arrays = [self.single_band(i).to_xarray(clip_with_bounds) for i in range(0, len(self.data))]
stack = xr.concat(arrays, dim='band')
return stack


async def tile_stream_to_stack_stream(raster_stream: AsyncIterator[RasterTile2D]) -> AsyncIterator[RasterTileStack2D]:

''' Convert a stream of raster tiles to stream of stacked tiles '''
store: List[RasterTile2D] = []
first_band: int = -1

async for tile in raster_stream:
if len(store) == 0:
first_band = tile.band
store.append(tile)

else:
# check things that should be the same for all tiles
assert tile.shape == store[0].shape, 'Tile shapes do not match'
# TODO: geo transform should be the same for all tiles
# tiles should have a tile position or global pixel position

# assert tile.geo_transform == store[0].geo_transform, 'Tile geo_transforms do not match'
assert tile.crs == store[0].crs, 'Tile crs do not match'

if tile.band == first_band:
assert tile.time.start >= store[0].time.start, 'Tile time intervals must be equal or increasing'

stack = [tile.data for tile in store]
tile_shape = store[0].shape
bands = [tile.band for tile in store]
geo_transforms = store[0].geo_transform
crs = store[0].crs
time = store[0].time

store = [tile]
yield RasterTileStack2D(tile_shape, stack, geo_transforms, crs, time, bands)

else:
assert tile.time == store[0].time, 'Time missmatch. ' + str(store[0].time) + ' != ' + str(tile.time)
store.append(tile)

if len(store) > 0:
tile_shape = store[0].shape
stack = [tile.data for tile in store]
bands = [tile.band for tile in store]
geo_transforms = store[0].geo_transform
crs = store[0].crs
time = store[0].time

store = []

yield RasterTileStack2D(tile_shape, stack, geo_transforms, crs, time, bands)
14 changes: 14 additions & 0 deletions geoengine/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,12 @@ def to_api_dict(self) -> geoengine_openapi_client.TimeInterval:
def __datetime_to_iso_str(timestamp: np.datetime64) -> str:
return str(np.datetime_as_string(timestamp, unit='ms', timezone='UTC')).replace('Z', '+00:00')

def __eq__(self, other: Any) -> bool:
'''Check if two `TimeInterval` objects are equal.'''
if not isinstance(other, TimeInterval):
return False
return self.start == other.start and self.end == other.end


class SpatialResolution:
''''A spatial resolution.'''
Expand Down Expand Up @@ -1354,3 +1360,11 @@ def spatial_resolution(self) -> SpatialResolution:
x_resolution=abs(self.x_pixel_size),
y_resolution=abs(self.y_pixel_size)
)

def __eq__(self, other) -> bool:
'''Check if two geotransforms are equal'''
if not isinstance(other, GeoTransform):
return False

return self.x_min == other.x_min and self.y_max == other.y_max and \
self.x_pixel_size == other.x_pixel_size and self.y_pixel_size == other.y_pixel_size
15 changes: 13 additions & 2 deletions geoengine/workflow_builder/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,17 @@ def sentinel2_cloud_free_band_custom_input(band_dataset: str, scl_dataset: str):
scl_source = operators.GdalSource(
dataset=scl_dataset
)

scl_source_u16 = operators.RasterTypeConversion(
source=scl_source,
output_data_type="U16"
)

# [sen2_mask == 3 |sen2_mask == 7 |sen2_mask == 8 | sen2_mask == 9 |sen2_mask == 10 |sen2_mask == 11 ]
cloud_free = operators.Expression(
expression="if (B == 3 || (B >= 7 && B <= 11)) { NODATA } else { A }",
output_type="U16",
source=operators.RasterStacker([band_source, scl_source]),
source=operators.RasterStacker([band_source, scl_source_u16]),
)

return cloud_free
Expand All @@ -63,11 +69,16 @@ def sentinel2_cloud_free_ndvi_custom_input(nir_dataset: str, red_dataset: str, s
scl_source = operators.GdalSource(
dataset=scl_dataset
)
scl_source_u16 = operators.RasterTypeConversion(
source=scl_source,
output_data_type="U16"
)

# [sen2_mask == 3 |sen2_mask == 7 |sen2_mask == 8 | sen2_mask == 9 |sen2_mask == 10 |sen2_mask == 11 ]
cloud_free = operators.Expression(
expression="if (B == 3 || (B >= 7 && B <= 11)) { NODATA } else { (A - B) / (A + B) }",
output_type="F32",
source=operators.RasterStacker([nir_source, red_source, scl_source]),
source=operators.RasterStacker([nir_source, red_source, scl_source_u16]),
)

return cloud_free
Expand Down
4 changes: 2 additions & 2 deletions geoengine/workflow_builder/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,11 +462,11 @@ class RasterTypeConversion(RasterOperator):
'''A RasterTypeConversion operator.'''

source: RasterOperator
output_data_type: Literal["u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64", "f32", "f64"]
output_data_type: Literal["U8", "U16", "U32", "U64", "I8", "I16", "I32", "I64", "F32", "F64"]

def __init__(self,
source: RasterOperator,
output_data_type: Literal["u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64", "f32", "f64"]
output_data_type: Literal["U8", "U16", "U32", "U64", "I8", "I16", "I32", "I64", "F32", "F64"]
):
'''Creates a new RasterTypeConversion operator.'''
self.source = source
Expand Down
84 changes: 70 additions & 14 deletions tests/test_workflow_builder_blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,17 @@ def test_sentinel2_cloud_free_band(self):
"data": "_:5779494c-f3a2-48b3-8a2d-5fbba8c5b6c5:`UTM32N:B02`"
}
}, {
"type": "GdalSource",
"type": "RasterTypeConversion",
"params": {
"data": "_:5779494c-f3a2-48b3-8a2d-5fbba8c5b6c5:`UTM32N:SCL`"
"outputDataType": "U16"
},
"sources": {
"raster": {
"type": "GdalSource",
"params": {
"data": "_:5779494c-f3a2-48b3-8a2d-5fbba8c5b6c5:`UTM32N:SCL`"
}
}
}
}
]
Expand Down Expand Up @@ -77,9 +85,17 @@ def test_sentinel2_ndvi(self):
"data": "_:5779494c-f3a2-48b3-8a2d-5fbba8c5b6c5:`UTM32N:B04`"
}
}, {
"type": "GdalSource",
"type": "RasterTypeConversion",
"params": {
"data": "_:5779494c-f3a2-48b3-8a2d-5fbba8c5b6c5:`UTM32N:SCL`"
"outputDataType": "U16"
},
"sources": {
"raster": {
"type": "GdalSource",
"params": {
"data": "_:5779494c-f3a2-48b3-8a2d-5fbba8c5b6c5:`UTM32N:SCL`"
}
}
}
}
]
Expand Down Expand Up @@ -110,9 +126,17 @@ def test_sentinel2_cloud_free_band_custom_input(self):
"data": "data_band"
}
}, {
"type": "GdalSource",
"type": "RasterTypeConversion",
"params": {
"data": "scl_band"
"outputDataType": "U16"
},
"sources": {
"raster": {
"type": "GdalSource",
"params": {
"data": "scl_band"
}
}
}
}
]
Expand Down Expand Up @@ -148,9 +172,17 @@ def test_sentinel2_cloud_free_ndvi_custom_input(self):
"data": "band4"
}
}, {
"type": "GdalSource",
"type": "RasterTypeConversion",
"params": {
"data": "scl_band"
"outputDataType": "U16"
},
"sources": {
"raster": {
"type": "GdalSource",
"params": {
"data": "scl_band"
}
}
}
}
]
Expand Down Expand Up @@ -195,9 +227,17 @@ def test_s2_cloud_free_aggregated_band(self):
"data": "_:5779494c-f3a2-48b3-8a2d-5fbba8c5b6c5:`UTM32N:B04`"
}
}, {
"type": "GdalSource",
"type": "RasterTypeConversion",
"params": {
"data": "_:5779494c-f3a2-48b3-8a2d-5fbba8c5b6c5:`UTM32N:SCL`"
"outputDataType": "U16"
},
"sources": {
"raster": {
"type": "GdalSource",
"params": {
"data": "_:5779494c-f3a2-48b3-8a2d-5fbba8c5b6c5:`UTM32N:SCL`"
}
}
}
}
]
Expand Down Expand Up @@ -244,9 +284,17 @@ def test_s2_cloud_free_aggregated_band_custom_input(self):
"data": "band8"
}
}, {
"type": "GdalSource",
"type": "RasterTypeConversion",
"params": {
"data": "scl_band"
"outputDataType": "U16"
},
"sources": {
"raster": {
"type": "GdalSource",
"params": {
"data": "scl_band"
}
}
}
}
]
Expand Down Expand Up @@ -298,9 +346,17 @@ def test_s2_cloud_free_aggregated_ndvi_custom_input(self):
"data": "band4"
}
}, {
"type": "GdalSource",
"type": "RasterTypeConversion",
"params": {
"data": "scl_band"
"outputDataType": "U16"
},
"sources": {
"raster": {
"type": "GdalSource",
"params": {
"data": "scl_band"
}
}
}
}
]
Expand Down

0 comments on commit 81d2a49

Please sign in to comment.