Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add band argument for interpolation #113

Merged
merged 6 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions ocsmesh/hfun/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,8 +571,6 @@ def add_region_constraint(
Add fixed-value or fixed-matrix constraint.
add_topo_func_constraint :
Addint constraint based on function of topography
add_courant_num_constraint :
Add constraint based on approximated Courant number
"""

if crs is None:
Expand Down
20 changes: 13 additions & 7 deletions ocsmesh/mesh/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,8 @@ def interpolate(
method: Literal['spline', 'linear', 'nearest'] = 'spline',
nprocs: Optional[int] = None,
info_out_path: Union[pathlib.Path, str, None] = None,
filter_by_shape: bool = False
filter_by_shape: bool = False,
band: int = 1,
) -> None:
"""Interplate values from raster inputs to the mesh nodes.

Expand All @@ -359,8 +360,10 @@ def interpolate(
Number of workers to use when interpolating data.
info_out_path : pathlike or str or None
Path for the output node interpolation information file
filter_by_shape : bool
filter_by_shape : bool, default=False
Flag for node filtering based on raster bbox or shape
band : int, default=1
The band from rasters to use for interpolation

Returns
-------
Expand All @@ -382,15 +385,15 @@ def interpolate(
_mesh_interpolate_worker,
[(self.vert2['coord'], self.crs,
_raster.tmpfile, _raster.chunk_size,
method, filter_by_shape)
method, filter_by_shape, band)
for _raster in raster]
)
pool.join()
else:
res = [_mesh_interpolate_worker(
self.vert2['coord'], self.crs,
_raster.tmpfile, _raster.chunk_size,
method, filter_by_shape)
method, filter_by_shape, band)
for _raster in raster]

values = self.msh_t.value.flatten()
Expand Down Expand Up @@ -2234,7 +2237,8 @@ def _mesh_interpolate_worker(
raster_path: Union[str, Path],
chunk_size: Optional[int],
method: Literal['spline', 'linear', 'nearest'] = "spline",
filter_by_shape: bool = False):
filter_by_shape: bool = False,
band: int = 1):
"""Interpolator worker function to be used in parallel calls

Parameters
Expand All @@ -2249,8 +2253,10 @@ def _mesh_interpolate_worker(
Chunk size for windowing over the raster.
method : {'spline', 'linear', 'nearest'}, default='spline'
Method of interpolation.
filter_by_shape : bool
filter_by_shape : bool, default=False
Flag for node filtering based on raster bbox or shape
band : int, default=1
The band from rasters to use for interpolation

Returns
-------
Expand Down Expand Up @@ -2281,7 +2287,7 @@ def _mesh_interpolate_worker(
xi = raster.get_x(window)
yi = raster.get_y(window)
# Use masked array to ignore missing values from DEM
zi = raster.get_values(window=window, masked=True)
zi = raster.get_values(window=window, masked=True, band=band)

if not filter_by_shape:
_idxs = np.logical_and(
Expand Down
15 changes: 13 additions & 2 deletions ocsmesh/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2080,20 +2080,31 @@ def raster_from_numpy(
if not isinstance(crs, CRS):
crs = CRS.from_user_input(crs)

nbands = 1
if data.ndim == 3:
nbands = data.shape[2]
elif data.ndim != 2:
raise ValueError("Invalid data dimensions!")

with rio.open(
filename,
'w',
driver='GTiff',
height=data.shape[0],
width=data.shape[1],
count=1,
count=nbands,
dtype=data.dtype,
crs=crs,
transform=transform,
) as dst:
if isinstance(data, np.ma.MaskedArray):
dst.nodata = data.fill_value
dst.write(data, 1)

data = data.reshape(data.shape[0], data.shape[1], -1)
for i in range(nbands):
dst.write(data.take(i, axis=2), i + 1)




def msht_from_numpy(
Expand Down
65 changes: 64 additions & 1 deletion tests/api/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
import tempfile
import unittest
import warnings
import shutil
from pathlib import Path

import numpy as np
from jigsawpy import jigsaw_msh_t
from pyproj import CRS
from shapely import geometry

from ocsmesh import utils
from ocsmesh.mesh.mesh import Mesh
from ocsmesh.mesh.mesh import Mesh, Raster



Expand Down Expand Up @@ -317,5 +319,66 @@ def test_specify_boundary_on_mesh_with_no_boundary(self):
self.assertEqual(bdry.open().iloc[0]['index_id'], [1, 2, 3])


class RasterInterpolation(unittest.TestCase):

def setUp(self):
self.tdir = Path(tempfile.mkdtemp())

msht1 = utils.create_rectangle_mesh(
nx=13, ny=5, x_extent=(-73.9, -71.1), y_extent=(40.55, 40.85),
holes=[],
)
msht1.crs = CRS.from_user_input(4326)
msht2 = utils.create_rectangle_mesh(
nx=11, ny=7, x_extent=(-73.9, -71.1), y_extent=(40.55, 40.85),
holes=[],
)
msht2.crs = CRS.from_user_input(4326)
with warnings.catch_warnings():
warnings.filterwarnings(
'ignore', category=UserWarning,
message='Input mesh has no CRS information'
)
self.mesh1 = Mesh(msht1)
self.mesh2 = Mesh(msht2)

self.rast = self.tdir / 'rast.tif'

rast_xy = np.mgrid[-74:-71:0.1, 40.9:40.5:-0.01]
rast_z = np.ones((rast_xy.shape[1], rast_xy.shape[2], 2))
rast_z[:, :, 1] = 2
utils.raster_from_numpy(
self.rast, rast_z, rast_xy, 4326
)


def tearDown(self):
shutil.rmtree(self.tdir)


def test_interpolation_io(self):
rast = Raster(self.rast)

self.mesh1.interpolate(rast)
self.assertTrue(np.isclose(self.mesh1.value, 1).all())

# TODO: Improve the assertion!
with self.assertRaises(Exception):
self.mesh1.interpolate(self.mesh2)


def test_interpolation_band(self):
rast = Raster(self.rast)

self.mesh1.interpolate(rast)
self.assertTrue(np.isclose(self.mesh1.value, 1).all())

self.mesh1.interpolate(rast, band=2)
self.assertTrue(np.isclose(self.mesh1.value, 2).all())


# TODO Add more interpolation tests


if __name__ == '__main__':
unittest.main()
39 changes: 39 additions & 0 deletions tests/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,45 @@ def test_data_masking(self):
self.assertEqual(rast.src.nodata, fill_value)


def test_multiband_raster_data(self):
nbands = 5
in_data = np.ones((3, 4, nbands))
for i in range(nbands):
in_data[:, :, i] *= i
in_rast_xy = np.mgrid[-74:-71:1, 40.5:40.9:0.1]
with tempfile.NamedTemporaryFile(suffix='.tiff') as tf:
utils.raster_from_numpy(
tf.name,
data=in_data,
mgrid=in_rast_xy,
crs=4326
)
rast = Raster(tf.name)
self.assertEqual(rast.count, nbands)
for i in range(nbands):
with self.subTest(band_number=i):
self.assertTrue(
(rast.get_values(band=i+1) == i).all()
)


def test_multiband_raster_invalid_io(self):
in_data = np.ones((3, 4, 5, 6))
in_rast_xy = np.mgrid[-74:-71:1, 40.5:40.9:0.1]
with tempfile.NamedTemporaryFile(suffix='.tiff') as tf:
with self.assertRaises(ValueError) as cm:
utils.raster_from_numpy(
tf.name,
data=in_data,
mgrid=in_rast_xy,
crs=4326
)
exc = cm.exception
self.assertRegex(str(exc).lower(), '.*dimension.*')




class ShapeToMeshT(unittest.TestCase):

def setUp(self):
Expand Down