Skip to content

Commit

Permalink
feat(analysis.transform): add symmetrize (#97)
Browse files Browse the repository at this point in the history
Adds a new method `erlab.analysis.transform.symmetrize` for symmetrizing data across a single coordinate.
  • Loading branch information
kmnhan authored Feb 17, 2025
1 parent 567d989 commit aefb966
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 5 deletions.
124 changes: 120 additions & 4 deletions src/erlab/analysis/transform.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,29 @@
"""Transformations."""

__all__ = ["rotate", "rotateinplane", "rotatestackinplane", "shift"]
__all__ = ["rotate", "rotateinplane", "rotatestackinplane", "shift", "symmetrize"]

import itertools
import typing
import warnings
from collections.abc import Hashable, Iterable, Mapping

import numpy as np
import scipy.ndimage
import scipy.special
import scipy
import xarray as xr

import erlab

if typing.TYPE_CHECKING:
import scipy.ndimage
import scipy.special # noqa: TC004


def rotate(
darr: xr.DataArray,
angle: float,
axes: tuple[int, int] | tuple[Hashable, Hashable] = (0, 1),
center: tuple[float, float] | Mapping[Hashable, float] = (0.0, 0.0),
*,
reshape: bool = True,
order: int = 1,
mode="constant",
Expand Down Expand Up @@ -224,6 +228,7 @@ def shift(
darr: xr.DataArray,
shift: float | xr.DataArray,
along: str,
*,
shift_coords: bool = False,
**shift_kwargs,
) -> xr.DataArray:
Expand Down Expand Up @@ -270,6 +275,7 @@ def shift(
-------
>>> import xarray as xr
>>> import numpy as np
>>> import erlab.analysis as era
>>> darr = xr.DataArray(
... np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(float), dims=["x", "y"]
Expand All @@ -278,7 +284,9 @@ def shift(
>>> shifted = era.transform.shift(darr, shift_arr, along="y")
>>> print(shifted)
<xarray.DataArray (x: 3, y: 3)> Size: 72B
nan 1.0 2.0 4.0 5.0 6.0 nan nan 7.0
array([[nan, 1., 2.],
[ 4., 5., 6.],
[nan, nan, 7.]])
Dimensions without coordinates: x, y
"""
shift_kwargs.setdefault("order", 1)
Expand Down Expand Up @@ -351,6 +359,114 @@ def shift(
return out


def symmetrize(
darr: xr.DataArray,
dim: Hashable,
*,
center: float = 0.0,
part: typing.Literal["both", "below", "above"] = "both",
) -> xr.DataArray:
"""
Symmetrize a DataArray along a specified dimension around a given center.
This function takes an input DataArray and symmetrizes its values along the
specified dimension by reflecting and summing the data in regions below and above a
given center.
The operation assumes that the coordinate corresponding to the dimension is evenly
spaced. Internally, the function interpolates the data to a shifted coordinate grid
to align with the nearest grid point, performs the reflection, and concatenates the
resulting halves.
Parameters
----------
darr : DataArray
The input xarray DataArray to be symmetrized. Its coordinate along the specified
dimension must be uniformly spaced.
dim : Hashable
The dimension along which to perform the symmetrization.
center : float, optional
The central value about which the data is symmetrized (default is 0.0).
part : {'both', 'below', 'above'}, optional
The part of the symmetrized data to return. If 'both', the full symmetrized data
is returned. If 'below', only the part below the center is returned. If 'above',
only the part above the center is returned.
Returns
-------
DataArray
A symmetrized DataArray where each value is the sum of its original and
reflected counterpart.
Examples
--------
>>> import xarray as xr
>>> import numpy as np
>>> import erlab.analysis as era
>>> # Create a sample DataArray with uniform coordinates.
>>> da = xr.DataArray(
... np.array([1, 2, 3, 4, 5, 6]), dims="x", coords={"x": np.linspace(-2, 2, 6)}
... )
>>> sym_da = era.transform.symmetrize(da, dim="x", center=0.0)
>>> print(sym_da)
<xarray.DataArray (x: 6)> Size: 48B
array([2., 4., 6., 6., 4., 2.])
Coordinates:
* x (x) float64 48B -2.0 -1.2 -0.4 0.4 1.2 2.0
"""
if not erlab.utils.array.is_dims_uniform(darr, (dim,)):
raise ValueError(f"Coordinate along dimension {dim} must be uniformly spaced")

# Ensure coord is increasing
out = darr.copy().sortby(dim)

with xr.set_options(keep_attrs=True):
coord: xr.DataArray = darr[dim]

step = float(np.abs(coord[1] - coord[0]))
closest_val = (
float(typing.cast(xr.DataArray, np.abs(coord - center)).idxmin(dim))
- center
) # displacement relative to nearest grid point

shifted_coords = coord.values - closest_val - step / 2
shifted_coords = np.append(shifted_coords, shifted_coords[-1] + step)

out_shifted = out.interp({dim: shifted_coords}, assume_sorted=True).dropna(dim)

below = out_shifted.where(out_shifted[dim] < center, drop=True)
above = out_shifted.where(out_shifted[dim] > center, drop=True)

# Flip coord along center
above = above.assign_coords({dim: center - (above[dim] - center)}).sortby(dim)

# Ensure flipped coord matches exactly with original
above = above.assign_coords({dim: below[dim][-len(above) :]})

# Symmetrize
sym_below = below + above

# Retain coordinate attributes
sym_below = sym_below.assign_coords(
{dim: sym_below[dim].assign_attrs(coord.attrs)}
)

if part == "below":
return sym_below

# Flip symmetrized data
sym_above = (
sym_below.copy()
.assign_coords({dim: center - (sym_below[dim] - center)})
.sortby(dim)
)

if part == "above":
return sym_above

return xr.concat([sym_below, sym_above], dim=dim)


def rotateinplane(data: xr.DataArray, rotate, **interp_kwargs):
"""Rotate a 2D DataArray in the plane defined by the two dimensions.
Expand Down
52 changes: 51 additions & 1 deletion tests/analysis/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import xarray as xr
import xarray.testing

from erlab.analysis.transform import rotate, shift
from erlab.analysis.transform import rotate, shift, symmetrize


def test_rotate() -> None:
Expand Down Expand Up @@ -139,4 +139,54 @@ def test_shift() -> None:
)

# Check if the shifted array matches the expected result

assert np.allclose(shifted, expected, equal_nan=True)


def test_symmetrize_both() -> None:
# Test symmetrize returns full (both) symmetrized DataArray.
da = xr.DataArray(
np.array([1, 2, 3, 3, 2, 1], dtype=float),
dims="x",
coords={"x": np.linspace(-2, 2, 6)},
)
sym_da = symmetrize(da, "x", center=0.0, part="both")
expected = np.array([2, 4, 6, 6, 4, 2], dtype=float)
np.testing.assert_allclose(sym_da.values, expected, rtol=1e-5)


def test_symmetrize_below() -> None:
# Test symmetrize returns only the lower half.
da = xr.DataArray(
np.array([1, 2, 3, 3, 2, 1], dtype=float),
dims="x",
coords={"x": np.linspace(-2, 2, 6)},
)
sym_da = symmetrize(da, "x", center=0.0, part="below")
expected = np.array([2, 4, 6], dtype=float)
np.testing.assert_allclose(sym_da.values, expected, rtol=1e-5)


def test_symmetrize_above() -> None:
# Test symmetrize returns only the upper half (reflected).
da = xr.DataArray(
np.array([1, 2, 3, 3, 2, 1], dtype=float),
dims="x",
coords={"x": np.linspace(-2, 2, 6)},
)
sym_da = symmetrize(da, "x", center=0.0, part="above")
expected = np.array([6, 4, 2], dtype=float)
np.testing.assert_allclose(sym_da.values, expected, rtol=1e-5)


def test_symmetrize_non_uniform() -> None:
# Test that symmetrize raises an error when the coordinate is non-uniform.
da = xr.DataArray(
np.array([1, 2, 3, 4], dtype=float),
dims="x",
coords={"x": np.array([0.0, 1.0, 3.0, 6.0])}, # non-evenly spaced
)
with pytest.raises(
ValueError, match="Coordinate along dimension x must be uniformly spaced"
):
symmetrize(da, "x", center=0.0)

0 comments on commit aefb966

Please sign in to comment.