Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Broadcasting Fixes #4

Merged
merged 16 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 29 additions & 24 deletions regridding/_conservative_ramshaw.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
# cache=True,
)
def _conservative_ramshaw(
# values_input: np.ndarray,
# values_output: np.ndarray,
grid_input: tuple[np.ndarray, np.ndarray],
grid_output: tuple[np.ndarray, np.ndarray],
epsilon: float = 1e-10,
Expand All @@ -27,14 +25,9 @@ def _conservative_ramshaw(
weights.append((0., 0., 0.))

input_x, input_y = grid_input
# output_x, output_y = grid_output

shape_input = input_x.shape
# shape_output = np.broadcast_shapes(output_x.shape, output_y.shape)

grids_sweep = grid_input, grid_output
grids_static = grid_output, grid_input
grids_input = "sweep", "static"
axes = 0, 1

# k = slice(None, -1)
Expand All @@ -55,23 +48,35 @@ def _conservative_ramshaw(

# values_input = values_input / area_input

for grid_sweep, grid_static, grid_input in zip(grids_sweep, grids_static, grids_input):
grid_static_x, grid_static_y = grid_static
grid_sweep_x, grid_sweep_y = grid_sweep
for axis in axes:
_sweep_axis(
# values_input=values_input,
# values_output=values_output,
area_input=area_input,
grid_sweep_x=grid_sweep_x,
grid_sweep_y=grid_sweep_y,
grid_static_x=grid_static_x,
grid_static_y=grid_static_y,
axis=axis,
grid_input=grid_input,
epsilon=epsilon,
weights=weights,
)
grid_static_x, grid_static_y = grid_output
grid_sweep_x, grid_sweep_y = grid_input
for axis in axes:
_sweep_axis(
area_input=area_input,
grid_sweep_x=grid_sweep_x,
grid_sweep_y=grid_sweep_y,
grid_static_x=grid_static_x,
grid_static_y=grid_static_y,
axis=axis,
grid_input="sweep",
epsilon=epsilon,
weights=weights,
)

grid_static_x, grid_static_y = grid_input
grid_sweep_x, grid_sweep_y = grid_output
for axis in axes:
_sweep_axis(
area_input=area_input,
grid_sweep_x=grid_sweep_x,
grid_sweep_y=grid_sweep_y,
grid_static_x=grid_static_x,
grid_static_y=grid_static_y,
axis=axis,
grid_input="static",
epsilon=epsilon,
weights=weights,
)

# return values_output
return weights
Expand Down
17 changes: 12 additions & 5 deletions regridding/_regrid/_regrid_from_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,17 @@ def regrid_from_weights(
:func:`regridding.regrid_from_weights`
"""

shape_input = np.broadcast_shapes(shape_input, values_input.shape)
values_input = np.broadcast_to(values_input, shape=shape_input, subok=True)
jacobdparker marked this conversation as resolved.
Show resolved Hide resolved
shape_input = np.broadcast_shapes(values_input.shape, shape_input)

ndim_input = len(shape_input)
axis_input = _util._normalize_axis(axis_input, ndim=ndim_input)

shape_orthogonal = (
1 if i in axis_input else shape_input[i] for i in range(-len(shape_input), 0)
)
weights = np.broadcast_to(np.array(weights), shape_orthogonal)
values_input = np.broadcast_to(values_input, shape_input)

if values_output is None:
shape_output = np.broadcast_shapes(
shape_output,
Expand Down Expand Up @@ -85,10 +91,11 @@ def regrid_from_weights(

shape_output_tmp = values_output.shape

weights = numba.typed.List(weights.reshape(-1))
values_input = values_input.reshape(-1, *shape_input_numba)
values_output = values_output.reshape(-1, *shape_output_numba)

weights = numba.typed.List(weights.reshape(-1))

values_input = np.ascontiguousarray(values_input)
values_output = np.ascontiguousarray(values_output)

Expand All @@ -105,17 +112,17 @@ def regrid_from_weights(
return values_output


@numba.njit()
@numba.njit(parallel=True)
def _regrid_from_weights(
weights: numba.typed.List,
values_input: np.ndarray,
values_output: np.ndarray,
) -> None:

for d in numba.prange(len(weights)):
weights_d = weights[d]
values_input_d = values_input[d].reshape(-1)
values_output_d = values_output[d].reshape(-1)

for w in range(len(weights_d)):
i_input, i_output, weight = weights_d[w]
values_output_d[int(i_output)] += weight * values_input_d[int(i_input)]
95 changes: 76 additions & 19 deletions regridding/_regrid/_tests/test_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,29 @@
import numpy as np
import regridding

x = np.linspace(-1, 1, num=10)
y = np.linspace(-1, 1, num=11)
x_broadcasted, y_broadcasted = np.meshgrid(
x,
y,
indexing="ij",
)

new_y = np.linspace(-1, 1, num=5)
new_x = np.linspace(-1, 1, num=6)

new_x_broadcasted, new_y_broadcasted = np.meshgrid(
x,
new_y,
indexing="ij",
)

new_x_broadcasted_2, new_y_broadcasted_2 = np.meshgrid(
new_x,
y,
indexing="ij",
)


@pytest.mark.parametrize(
argnames="coordinates_input,coordinates_output,values_input,values_output,axis_input,axis_output,result_expected",
Expand All @@ -24,6 +47,33 @@
None,
np.square(np.linspace(-1, 1, num=11)),
),
(
(y,),
(new_y,),
x_broadcasted + y_broadcasted,
None,
(~0,),
(~0,),
new_x_broadcasted + new_y_broadcasted,
),
(
(x[..., np.newaxis],),
(new_x[..., np.newaxis],),
x_broadcasted + y_broadcasted,
None,
(0,),
(0,),
new_x_broadcasted_2 + new_y_broadcasted_2,
),
(
(x[..., np.newaxis],),
(0.1 * new_x[..., np.newaxis] + 0.001 * new_y,),
x[..., np.newaxis],
None,
(0,),
(0,),
0.1 * new_x[..., np.newaxis] + 0.001 * new_y,
),
],
)
def test_regrid_multilinear_1d(
Expand All @@ -46,35 +96,34 @@ def test_regrid_multilinear_1d(
)
assert isinstance(result, np.ndarray)
assert np.issubdtype(result.dtype, float)
assert np.all(result == result_expected)
assert np.allclose(result, result_expected)


@pytest.mark.parametrize(
argnames="coordinates_input, values_input, axis_input",
argnames="coordinates_input, values_input, axis_input, coordinates_output, values_output, axis_output",
argvalues=[
(
np.meshgrid(
np.linspace(-1, 1, num=10),
np.linspace(-1, 1, num=11),
indexing="ij",
),
(x_broadcasted, y_broadcasted),
np.random.normal(size=(10 - 1, 11 - 1)),
None,
(1.1 * x_broadcasted + 0.01, 1.2 * y_broadcasted + 0.01),
None,
None,
),
],
)
@pytest.mark.parametrize(
argnames="coordinates_output, values_output, axis_output",
argvalues=[
(
np.meshgrid(
1.1 * np.linspace(-1, 1, num=10) + 0.001,
1.2 * np.linspace(-1, 1, num=11) + 0.001,
indexing="ij",
(
x_broadcasted[..., np.newaxis] + np.array([0, 0.001]),
y_broadcasted[..., np.newaxis] + np.array([0, 0.001]),
),
np.random.normal(size=(x.shape[0] - 1, y.shape[0] - 1, 2)),
(0, 1),
(
1.1 * (x_broadcasted[..., np.newaxis] + np.array([0, 0.001])) + 0.01,
1.2 * (y_broadcasted[..., np.newaxis] + np.array([0, 0.01])) + 0.001,
),
None,
None,
)
(0, 1),
),
],
)
def test_regrid_conservative_2d(
Expand All @@ -95,6 +144,14 @@ def test_regrid_conservative_2d(
method="conservative",
)

result_shape = np.array(np.broadcast(*coordinates_output).shape)

if axis_input is None:
result_shape = result_shape - 1
else:
for ax in axis_input:
result_shape[ax] = result_shape[ax] - 1

assert np.issubdtype(result.dtype, float)
assert result.shape == tuple(np.array(np.broadcast(*coordinates_output).shape) - 1)
assert result.shape == tuple(result_shape)
assert np.isclose(result.sum(), values_input.sum())
14 changes: 7 additions & 7 deletions regridding/_weights/_weights_conservative.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,23 @@ def _weights_conservative(
weights = np.empty(shape_orthogonal, dtype=numba.typed.List)

for index in np.ndindex(*shape_orthogonal):
index_vertices_input = list(index)
index_vertices_input = list(reversed(index))

for ax in axis_input:
index_vertices_input.insert(ax, slice(None))
index_vertices_input = tuple(index_vertices_input)
index_vertices_input.insert(~ax, slice(None))
index_vertices_input = tuple(reversed(index_vertices_input))

index_vertices_output = list(index)
index_vertices_output = list(reversed(index))
for ax in axis_output:
index_vertices_output.insert(ax, slice(None))
index_vertices_output = tuple(index_vertices_output)
index_vertices_output.insert(~ax, slice(None))
index_vertices_output = tuple(reversed(index_vertices_output))

if len(axis_input) == 1:
raise NotImplementedError("1D regridding not supported")

elif len(axis_input) == 2:
coordinates_input_x, coordinates_input_y = coordinates_input
coordinates_output_x, coordinates_output_y = coordinates_output

weights[index] = _conservative_ramshaw(
grid_input=(
coordinates_input_x[index_vertices_input],
Expand Down
Loading