Skip to content

Commit

Permalink
Merge pull request #100 from kmnhan/fix-const-coords
Browse files Browse the repository at this point in the history
Support DataArray with constant coordinates in ImageTool
  • Loading branch information
kmnhan authored Feb 18, 2025
2 parents 6d9abca + f2c809b commit d58cfde
Show file tree
Hide file tree
Showing 11 changed files with 58 additions and 12 deletions.
1 change: 1 addition & 0 deletions docs/source/user-guide/io.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1397,6 +1397,7 @@
" coords=darr.coords,\n",
" dims=darr.dims,\n",
" attrs=darr.attrs,\n",
" name=darr.name,\n",
" )\n",
"\n",
" return darr\n",
Expand Down
12 changes: 8 additions & 4 deletions src/erlab/interactive/imagetool/fastslicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ def _transposed(arr: npt.NDArray[np.floating]) -> npt.NDArray[np.floating]:
@numba.njit(numba.boolean(numba.float64[::1]), cache=True)
def _is_uniform(arr: npt.NDArray[np.float64]) -> bool:
dif = np.diff(arr)
if dif[0] == 0.0:
# Treat constant coordinate array as non-uniform
return False
return np.allclose(dif, dif[0], rtol=3e-05, atol=3e-05, equal_nan=True)


Expand All @@ -81,15 +84,16 @@ def _is_uniform(arr: npt.NDArray[np.float64]) -> bool:
],
cache=True,
)
def _index_of_value_nonuniform(
arr: npt.NDArray[np.floating], val: np.floating
) -> np.int_:
def _index_of_value_nonuniform(arr: npt.NDArray[np.floating], val: np.floating) -> int:
return np.searchsorted((arr[:-1] + arr[1:]) / 2, val)


@numba.njit(
[numba.float64(numba.float32[::1]), numba.float64(numba.float64[::1])], cache=True
)
def _avg_nonzero_abs_diff(arr: npt.NDArray[np.floating]) -> np.floating:
def _avg_nonzero_abs_diff(arr: npt.NDArray[np.floating]) -> float:
diff = np.diff(arr)

if np.all(diff == 0.0): # Prevent division by zero
return 0.0
return np.mean(diff[diff != 0])
5 changes: 5 additions & 0 deletions src/erlab/interactive/imagetool/manager/_mainwindow.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@ def resizeEvent(self, event: QtGui.QResizeEvent | None) -> None:
super().resizeEvent(event)
self.fitInView(self._pixmapitem)

def wheelEvent(self, event: QtGui.QWheelEvent | None) -> None:
# Disable scrolling by ignoring wheel events
if event:
event.ignore()


class ImageToolManager(QtWidgets.QMainWindow):
"""The ImageToolManager window.
Expand Down
2 changes: 2 additions & 0 deletions src/erlab/interactive/imagetool/slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,8 @@ def get_significant(self, axis: int, uniform: bool = False) -> int:
step = self.incs_uniform[axis]
else:
step = self.incs[axis]
if step == 0:
return 3 # Default to 3 decimal places for zero step size
return erlab.utils.array.effective_decimals(step)

def add_cursor(self, like_cursor: int = -1, update: bool = True) -> None:
Expand Down
5 changes: 4 additions & 1 deletion src/erlab/io/plugins/maestro.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,10 @@ def _parse_attr(v) -> str | int | float:
data = groups["/2D_Data"]
data = data[next(iter(data.data_vars))]
data = xr.DataArray(
np.zeros(data.shape, dtype=np.uint8), dims=data.dims, attrs=data.attrs
np.zeros(data.shape, dtype=np.uint8),
dims=data.dims,
attrs=data.attrs,
name=data.name,
)
else:
# Create or load cache
Expand Down
14 changes: 13 additions & 1 deletion src/erlab/io/plugins/ssrl52.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,18 @@ class SSRL52Loader(LoaderBase):
"sample_workfunction": "WorkFunction",
}

coordinate_attrs = ("beta", "delta", "chi", "xi", "hv", "x", "y", "z")
coordinate_attrs = (
"beta",
"delta",
"chi",
"xi",
"hv",
"x",
"y",
"z",
"polarization",
"sample_temp",
)

additional_attrs: typing.ClassVar[dict] = {
"configuration": 3,
Expand Down Expand Up @@ -267,6 +278,7 @@ def load_single(
coords=darr.coords,
dims=darr.dims,
attrs=darr.attrs,
name=darr.name,
)

darr = darr.assign_attrs(attrs)
Expand Down
7 changes: 7 additions & 0 deletions src/erlab/utils/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def _wrapper(*args, **kwargs) -> typing.Any:
def is_uniform_spaced(arr: npt.NDArray, **kwargs) -> bool:
"""Check if the given array is uniformly spaced.
Constant arrays are also considered as uniformly spaced.
Parameters
----------
arr : array-like
Expand Down Expand Up @@ -285,11 +287,16 @@ def effective_decimals(step_or_coord: float | np.floating | npt.NDArray) -> int:
int
The effective number of decimal places, calculated as the order of magnitude of
``step`` plus one.
If the step size is zero, a default value of 3 is returned.
"""
if isinstance(step_or_coord, np.ndarray):
step = step_or_coord[1] - step_or_coord[0]
else:
step = step_or_coord

if step == 0.0:
return 3
return int(np.clip(np.ceil(-np.log10(np.abs(step)) + 1), a_min=0, a_max=None))


Expand Down
6 changes: 3 additions & 3 deletions src/erlab/utils/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,6 @@ def _format_array_values(val: npt.NDArray) -> str:
val = val.squeeze()

if val.ndim == 1:
if len(val) == 2:
return f"[{format_value(val[0])}, {format_value(val[1])}]"

if erlab.utils.array.is_uniform_spaced(val):
if val[0] == val[-1]:
return format_value(val[0])
Expand All @@ -305,6 +302,9 @@ def _format_array_values(val: npt.NDArray) -> str:
if val[0] == val[-1]:
return format_value(val[0])

if len(val) == 2:
return f"[{format_value(val[0])}, {format_value(val[1])}]"

return f"{format_value(val[0])} to {format_value(val[-1])}"

mn, mx = tuple(format_value(v) for v in (np.nanmin(val), np.nanmax(val)))
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
from erlab.interactive.utils import _WaitDialog
from erlab.io.exampledata import generate_data_angles, generate_gold_edge

DATA_COMMIT_HASH = "26535e727236f220a4424538ba00b4b7a2b9666a"
DATA_COMMIT_HASH = "549959fc88b4875863e2ec386ac7a95035a8c3ea"
"""The commit hash of the commit to retrieve from `kmnhan/erlabpy-data`."""

DATA_KNOWN_HASH = "f8d0a245747f6f899dc417db86f6ab64cf2c2dc7784aade8416c1c5d5c6c8ca4"
DATA_KNOWN_HASH = "ce9016b72f492f7d3aff5583256cc9d901de97481d5a141f124c35c26e2c7c77"
"""The SHA-256 checksum of the `.tar.gz` file."""

log = logging.getLogger(__name__)
Expand Down
13 changes: 12 additions & 1 deletion tests/interactive/test_imagetool.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,15 @@
"beta": np.arange(5),
},
),
"3D_const_nonuniform": xr.DataArray(
np.arange(125).reshape((5, 5, 5)),
dims=["x", "eV", "beta"],
coords={
"x": np.array([0.1, 0.1, 0.1, 0.1, 0.1]),
"eV": np.arange(5),
"beta": np.arange(5),
},
),
}


Expand Down Expand Up @@ -229,7 +238,9 @@ def test_itool_general(qtbot, move_and_compare_values) -> None:
win.close()


@pytest.mark.parametrize("test_data_type", ["2D", "3D", "3D_nonuniform"])
@pytest.mark.parametrize(
"test_data_type", ["2D", "3D", "3D_nonuniform", "3D_const_nonuniform"]
)
@pytest.mark.parametrize("condition", ["unbinned", "binned"])
def test_itool_tools(qtbot, test_data_type, condition) -> None:
data = _TEST_DATA[test_data_type]
Expand Down
1 change: 1 addition & 0 deletions tests/io/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def load_single(self, file_path, without_values=False):
coords=darr.coords,
dims=darr.dims,
attrs=darr.attrs,
name=darr.name,
)

return darr
Expand Down

0 comments on commit d58cfde

Please sign in to comment.