Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate axes types v0.4 #124

Merged
merged 17 commits into from
Jan 18, 2022
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions ome_zarr/axes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""Axes class for validating and transforming axes
"""
from typing import Any, Dict, List, Union

from .format import Format

KNOWN_AXES = {"x": "space", "y": "space", "z": "space", "c": "channel", "t": "time"}


class Axes:
def __init__(self, axes: Union[List[str], List[Dict[str, str]]]) -> None:
self.axes = self._axes_to_dicts(axes)

def validate(self, fmt: Format) -> None:
sbesson marked this conversation as resolved.
Show resolved Hide resolved

# check names (only enforced for version 0.3)
if fmt.version == "0.3":
self._validate_axes_03()
sbesson marked this conversation as resolved.
Show resolved Hide resolved
return

self._validate_axes_types()

def get_axes(self, fmt: Format) -> Union[List[str], List[Dict[str, str]]]:
sbesson marked this conversation as resolved.
Show resolved Hide resolved
if fmt.version == "0.3":
return self._get_names()
return self.axes

@staticmethod
def _axes_to_dicts(
axes: Union[List[str], List[Dict[str, str]]]
) -> List[Dict[str, str]]:
"""Returns a list of axis dicts with name and type"""
axes_dicts = []
for axis in axes:
if isinstance(axis, str):
axis_dict = {"name": axis}
if axis in KNOWN_AXES:
axis_dict["type"] = KNOWN_AXES[axis]
axes_dicts.append(axis_dict)
else:
axes_dicts.append(axis)
return axes_dicts

def _validate_axes_types(self) -> None:
"""
Validate the axes types according to the spec, version 0.4+
"""
axes_types = [axis.get("type") for axis in self.axes]
known_types = list(KNOWN_AXES.values())
unknown_types = [atype for atype in axes_types if atype not in known_types]
if len(unknown_types) > 1:
raise ValueError(
"Too many unknown axes types. 1 allowed, found: %s" % unknown_types
)

def _last_index(item: str, item_list: List[Any]) -> int:
return max(loc for loc, val in enumerate(item_list) if val == item)

if "time" in axes_types and _last_index("time", axes_types) > 0:
raise ValueError("'time' axis must be first dimension only")

if axes_types.count("channel") > 1:
raise ValueError("Only 1 axis can be type 'channel'")

if "channel" in axes_types and _last_index(
"channel", axes_types
) > axes_types.index("space"):
raise ValueError("'space' axes must come after 'channel'")

def _get_names(self) -> List[str]:
"""Returns a list of axis names"""
axes_names = []
for axis in self.axes:
if "name" not in axis:
raise ValueError("Axis Dict %s has no 'name'" % axis)
axes_names.append(axis["name"])
return axes_names

def _validate_axes_03(self) -> None:

val_axes = tuple(self._get_names())
if len(val_axes) == 2:
if val_axes != ("y", "x"):
raise ValueError(f"2D data must have axes ('y', 'x') {val_axes}")
elif len(val_axes) == 3:
if val_axes not in [("z", "y", "x"), ("c", "y", "x"), ("t", "y", "x")]:
raise ValueError(
"3D data must have axes ('z', 'y', 'x') or ('c', 'y', 'x')"
" or ('t', 'y', 'x'), not %s" % (val_axes,)
)
elif len(val_axes) == 4:
if val_axes not in [
("t", "z", "y", "x"),
("c", "z", "y", "x"),
("t", "c", "y", "x"),
]:
raise ValueError("4D data must have axes tzyx or czyx or tcyx")
else:
if val_axes != ("t", "c", "z", "y", "x"):
raise ValueError("5D data must have axes ('t', 'c', 'z', 'y', 'x')")
22 changes: 21 additions & 1 deletion ome_zarr/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,19 @@
LOGGER = logging.getLogger("ome_zarr.format")


def format_from_version(version: str) -> "Format":

for fmt in format_implementations():
if fmt.version == version:
return fmt
raise ValueError(f"Version {version} not recognized")


def format_implementations() -> Iterator["Format"]:
"""
Return an instance of each format implementation, newest to oldest.
"""
yield FormatV04()
yield FormatV03()
yield FormatV02()
yield FormatV01()
Expand Down Expand Up @@ -136,4 +145,15 @@ def version(self) -> str:
return "0.3"


CurrentFormat = FormatV03
class FormatV04(FormatV03):
"""
Changelog: axes is list of dicts,
introduce transformations in multiscales (Nov 2021)
"""

@property
def version(self) -> str:
return "0.4"


CurrentFormat = FormatV04
18 changes: 12 additions & 6 deletions ome_zarr/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
import numpy as np
from dask import delayed

from .format import format_from_version
from .io import ZarrLocation
from .types import JSONDict
from .writer import validate_axes
sbesson marked this conversation as resolved.
Show resolved Hide resolved

LOGGER = logging.getLogger("ome_zarr.reader")

Expand Down Expand Up @@ -275,17 +277,16 @@ def matches(zarr: ZarrLocation) -> bool:
def __init__(self, node: Node) -> None:
super().__init__(node)

axes_values = {"t", "c", "z", "y", "x"}
try:
multiscales = self.lookup("multiscales", [])
version = multiscales[0].get(
"version", "0.1"
) # should this be matched with Format.version?
datasets = multiscales[0]["datasets"]
# axes field was introduced in 0.3, before all data was 5d
axes = tuple(multiscales[0].get("axes", ["t", "c", "z", "y", "x"]))
if len(set(axes) - axes_values) > 0:
raise RuntimeError(f"Invalid axes names: {set(axes) - axes_values}")
axes = multiscales[0].get("axes")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the main implication of not setting a default value here?
For 0.1/0.2 data, this means, the node.metadata["axes"] might be None as opposed to ["t", "c", "z", "y", "x"] previously i.e. we are preserving the value stored in the metadata? Is there an impact on clients relying on node.metadata["axes"]?

fmt = format_from_version(version)
# Raises ValueError if not valid
validate_axes(None, axes, fmt)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we are not consuming the return values of validate_axes, the goal here is "only" to validate an axes but not modify it ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah - good point. Looking at this again, I realise that validate_axes(None, axes, fmt) is really designed for writing. So, what's returned will be valid, even if what it's passed isn't valid. E.g. it will convert "tczyx" to an axis array, and will allow v0.3 to have None if 2D or 5D.
What we really want here is Axes(axes, fmt).validate()

node.metadata["axes"] = axes
datasets = [d["path"] for d in datasets]
self.datasets: List[str] = datasets
Expand All @@ -301,7 +302,12 @@ def __init__(self, node: Node) -> None:
for c in data.chunks
]
LOGGER.info("resolution: %s", resolution)
LOGGER.info(" - shape %s = %s", axes, data.shape)
axes_names = None
if axes is not None:
axes_names = tuple(
axis if isinstance(axis, str) else axis["name"] for axis in axes
)
LOGGER.info(" - shape %s = %s", axes_names, data.shape)
LOGGER.info(" - chunks = %s", chunk_sizes)
LOGGER.info(" - dtype = %s", data.dtype)
node.data.append(data)
Expand Down
76 changes: 31 additions & 45 deletions ome_zarr/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,32 @@

"""
import logging
from typing import Any, List, Tuple, Union
from typing import Any, Dict, List, Tuple, Union

import numpy as np
import zarr

from .axes import Axes
from .format import CurrentFormat, Format
from .scale import Scaler
from .types import JSONDict

LOGGER = logging.getLogger("ome_zarr.writer")


def _validate_axes_names(
ndim: int, axes: Union[str, List[str]] = None, fmt: Format = CurrentFormat()
) -> Union[None, List[str]]:
"""Returns validated list of axes names or raise exception if invalid"""
def validate_axes(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At least I find the following lines

axes_obj = Axes(axes)
axes_obj.validate(fmt)

clarifies a lot of the logic happening here. It brings the question of whether additional logic should be moved to the constructor. Said otherwise, what is the added value of calling the validate_axes API vs the two-liner:

axes = Axes(axes, fmt=fmt, ndim=ddim)
axes.validate()

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the outcome of #124 (comment). Should the name of the method be updated to reflect this is a writer/constructor rather than a validator? Since this API is moved to be a public API, this is increasingly important. Alternatively, we can keep it prefixed with _ for now.

ndim: int = None,
axes: Union[str, List[str], List[Dict[str, str]]] = None,
fmt: Format = CurrentFormat(),
) -> Union[None, List[str], List[Dict[str, str]]]:
"""Returns list of axes valid for fmt.version or raise exception if invalid"""

if fmt.version in ("0.1", "0.2"):
if axes is not None:
LOGGER.info("axes ignored for version 0.1 or 0.2")
return None

# handle version 0.3...
# We can guess axes for 2D and 5D data
if axes is None:
if ndim == 2:
axes = ["y", "x"]
Expand All @@ -37,45 +40,27 @@ def _validate_axes_names(
"axes must be provided. Can't be guessed for 3D or 4D data"
)

# axes may be string e.g. "tczyx"
if isinstance(axes, str):
axes = list(axes)

if len(axes) != ndim:
raise ValueError("axes length must match number of dimensions")
_validate_axes(axes)
return axes

if ndim is not None and len(axes) != ndim:
raise ValueError(
f"axes length ({len(axes)}) must match number of dimensions ({ndim})"
)

def _validate_axes(axes: List[str], fmt: Format = CurrentFormat()) -> None:
axes_obj = Axes(axes)
axes_obj.validate(fmt)

val_axes = tuple(axes)
if len(val_axes) == 2:
if val_axes != ("y", "x"):
raise ValueError(f"2D data must have axes ('y', 'x') {val_axes}")
elif len(val_axes) == 3:
if val_axes not in [("z", "y", "x"), ("c", "y", "x"), ("t", "y", "x")]:
raise ValueError(
"3D data must have axes ('z', 'y', 'x') or ('c', 'y', 'x')"
" or ('t', 'y', 'x'), not %s" % (val_axes,)
)
elif len(val_axes) == 4:
if val_axes not in [
("t", "z", "y", "x"),
("c", "z", "y", "x"),
("t", "c", "y", "x"),
]:
raise ValueError("4D data must have axes tzyx or czyx or tcyx")
else:
if val_axes != ("t", "c", "z", "y", "x"):
raise ValueError("5D data must have axes ('t', 'c', 'z', 'y', 'x')")
return axes_obj.get_axes(fmt)


def write_multiscale(
pyramid: List,
group: zarr.Group,
chunks: Union[Tuple[Any, ...], int] = None,
fmt: Format = CurrentFormat(),
axes: Union[str, List[str]] = None,
axes: Union[str, List[str], List[Dict[str, str]]] = None,
) -> None:
"""
Write a pyramid with multiscale metadata to disk.
Expand All @@ -93,13 +78,13 @@ def write_multiscale(
fmt: Format
The format of the ome_zarr data which should be used.
Defaults to the most current.
axes: str or list of str
the names of the axes. e.g. "tczyx". Not needed for v0.1 or v0.2
or for v0.3 if 2D or 5D. Otherwise this must be provided
axes: str or list of str or list of dict
List of axes dicts, or names. Not needed for v0.1 or v0.2
or if 2D. Otherwise this must be provided
"""

dims = len(pyramid[0].shape)
axes = _validate_axes_names(dims, axes, fmt)
axes = validate_axes(dims, axes, fmt)

paths = []
for path, dataset in enumerate(pyramid):
Expand All @@ -113,7 +98,7 @@ def write_multiscales_metadata(
group: zarr.Group,
paths: List[str],
fmt: Format = CurrentFormat(),
axes: List[str] = None,
axes: Union[str, List[str], List[Dict[str, str]]] = None,
will-moore marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
"""
Write the multiscales metadata in the group.
Expand Down Expand Up @@ -142,8 +127,9 @@ def write_multiscales_metadata(
if fmt.version in ("0.1", "0.2"):
LOGGER.info("axes ignored for version 0.1 or 0.2")
else:
_validate_axes(axes, fmt)
multiscales[0]["axes"] = axes
axes = validate_axes(axes=axes, fmt=fmt)
if axes is not None:
multiscales[0]["axes"] = axes
group.attrs["multiscales"] = multiscales


Expand All @@ -154,7 +140,7 @@ def write_image(
byte_order: Union[str, List[str]] = "tczyx",
scaler: Scaler = Scaler(),
fmt: Format = CurrentFormat(),
axes: Union[str, List[str]] = None,
axes: Union[str, List[str], List[Dict[str, str]]] = None,
**metadata: JSONDict,
) -> None:
"""Writes an image to the zarr store according to ome-zarr specification
Expand All @@ -179,9 +165,9 @@ def write_image(
fmt: Format
The format of the ome_zarr data which should be used.
Defaults to the most current.
axes: str or list of str
the names of the axes. e.g. "tczyx". Not needed for v0.1 or v0.2
or for v0.3 if 2D or 5D. Otherwise this must be provided
axes: str or list of str or list of dict
List of axes dicts, or names. Not needed for v0.1 or v0.2
or if 2D. Otherwise this must be provided
"""

if image.ndim > 5:
Expand All @@ -195,7 +181,7 @@ def write_image(
axes = None

# check axes before trying to scale
_validate_axes_names(image.ndim, axes, fmt)
validate_axes(image.ndim, axes, fmt)

if chunks is not None:
chunks = _retuple(chunks, image.shape)
Expand Down
Loading