diff --git a/lib/galaxy/tool_util/parser/util.py b/lib/galaxy/tool_util/parser/util.py index 228312b1af46..46cd012809a0 100644 --- a/lib/galaxy/tool_util/parser/util.py +++ b/lib/galaxy/tool_util/parser/util.py @@ -5,6 +5,7 @@ DEFAULT_METRIC = "mae" DEFAULT_EPS = 0.01 +DEFAULT_PIN_LABELS = None def is_dict(item): diff --git a/lib/galaxy/tool_util/parser/xml.py b/lib/galaxy/tool_util/parser/xml.py index 749b3beb0fb8..ef39fd2ac2d4 100644 --- a/lib/galaxy/tool_util/parser/xml.py +++ b/lib/galaxy/tool_util/parser/xml.py @@ -21,6 +21,7 @@ DEFAULT_DELTA_FRAC, DEFAULT_EPS, DEFAULT_METRIC, + DEFAULT_PIN_LABELS, ) from galaxy.util import ( Element, @@ -793,6 +794,7 @@ def __parse_test_attributes(output_elem, attrib, parse_elements=False, parse_dis # Parameters for "image_diff" comparison attributes["metric"] = attrib.pop("metric", DEFAULT_METRIC) attributes["eps"] = float(attrib.pop("eps", DEFAULT_EPS)) + attributes["pin_labels"] = attrib.pop("pin_labels", DEFAULT_PIN_LABELS) if location and file is None: file = os.path.basename(location) # If no file specified, try to get filename from URL last component attributes["location"] = location diff --git a/lib/galaxy/tool_util/verify/__init__.py b/lib/galaxy/tool_util/verify/__init__.py index dd40e8cacf62..42b6998501de 100644 --- a/lib/galaxy/tool_util/verify/__init__.py +++ b/lib/galaxy/tool_util/verify/__init__.py @@ -42,6 +42,7 @@ DEFAULT_DELTA_FRAC, DEFAULT_EPS, DEFAULT_METRIC, + DEFAULT_PIN_LABELS, ) from galaxy.util import unicodify from galaxy.util.compression_utils import get_fileobj @@ -456,43 +457,87 @@ def files_contains(file1, file2, attributes=None): raise AssertionError(f"Failed to find '{contains}' in history data. (lines_diff={lines_diff}).") +def _singleobject_intersection_over_union( + mask1: "numpy.typing.NDArray", + mask2: "numpy.typing.NDArray", +) -> "numpy.floating": + return numpy.logical_and(mask1, mask2).sum() / numpy.logical_or(mask1, mask2).sum() + + def _multiobject_intersection_over_union( - mask1: "numpy.typing.NDArray", mask2: "numpy.typing.NDArray", repeat_reverse: bool = True + mask1: "numpy.typing.NDArray", + mask2: "numpy.typing.NDArray", + pin_labels: Optional[List[int]] = None, + repeat_reverse: bool = True, ) -> List["numpy.floating"]: iou_list = [] for label1 in numpy.unique(mask1): cc1 = mask1 == label1 - cc1_iou_list = [] - for label2 in numpy.unique(mask2[cc1]): - cc2 = mask2 == label2 - cc1_iou_list.append(intersection_over_union(cc1, cc2)) - iou_list.append(max(cc1_iou_list)) + + # If the label is in `pin_labels`, then use the same label value to find the corresponding object in the second mask. + if pin_labels is not None and label1 in pin_labels: + cc2 = mask2 == label1 + iou_list.append(_singleobject_intersection_over_union(cc1, cc2)) + + # Otherwise, use the object with the largest IoU value, excluding the pinned labels. + else: + cc1_iou_list = [] + for label2 in numpy.unique(mask2[cc1]): + if pin_labels is not None and label2 in pin_labels: + continue + cc2 = mask2 == label2 + cc1_iou_list.append(_singleobject_intersection_over_union(cc1, cc2)) + iou_list.append(max(cc1_iou_list)) + if repeat_reverse: - iou_list.extend(_multiobject_intersection_over_union(mask2, mask1, repeat_reverse=False)) + iou_list.extend(_multiobject_intersection_over_union(mask2, mask1, pin_labels, repeat_reverse=False)) + return iou_list -def intersection_over_union(mask1: "numpy.typing.NDArray", mask2: "numpy.typing.NDArray") -> "numpy.floating": +def intersection_over_union( + mask1: "numpy.typing.NDArray", mask2: "numpy.typing.NDArray", pin_labels: Optional[List[int]] = None +) -> "numpy.floating": + """Compute the intersection over union (IoU) for the objects in two masks containing lables. + + The IoU is computed for each uniquely labeled image region (object), and the overall minimum value is returned (i.e. the worst value). + To compute the IoU for each object, the corresponding object in the other mask needs to be determined. + The object correspondences are not necessarily symmetric. + + By default, the corresponding object in the other mask is determined as the one with the largest IoU value. + If the label of an object is listed in `pin_labels`, then the corresponding object in the other mask is determined as the object with the same label value. + Objects with labels listed in `pin_labels` also cannot correspond to objects with different labels. + This is particularly useful when specific image regions must always be labeled with a designated label value (e.g., the image background is often labeled with 0 or -1). + """ assert mask1.dtype == mask2.dtype assert mask1.ndim == mask2.ndim == 2 assert mask1.shape == mask2.shape - if mask1.dtype == bool: - return numpy.logical_and(mask1, mask2).sum() / numpy.logical_or(mask1, mask2).sum() + for label in pin_labels or []: + count = sum(label in mask for mask in (mask1, mask2)) + count_str = {1: "one", 2: "both"} + assert count == 2, f"Label {label} is pinned but missing in {count_str[2 - count]} of the images." + return min(_multiobject_intersection_over_union(mask1, mask2, pin_labels)) + + +def _parse_label_list(label_list_str: Optional[str]) -> List[int]: + if label_list_str is None: + return [] else: - return min(_multiobject_intersection_over_union(mask1, mask2)) + return [int(label.strip()) for label in label_list_str.split(",") if len(label_list_str) > 0] def get_image_metric( attributes: Dict[str, Any] ) -> Callable[["numpy.typing.NDArray", "numpy.typing.NDArray"], "numpy.floating"]: metric_name = attributes.get("metric", DEFAULT_METRIC) + pin_labels = _parse_label_list(attributes.get("pin_labels", DEFAULT_PIN_LABELS)) metrics = { "mae": lambda arr1, arr2: numpy.abs(arr1 - arr2).mean(), # Convert to float before squaring to prevent overflows "mse": lambda arr1, arr2: numpy.square((arr1 - arr2).astype(float)).mean(), "rms": lambda arr1, arr2: math.sqrt(numpy.square((arr1 - arr2).astype(float)).mean()), "fro": lambda arr1, arr2: numpy.linalg.norm((arr1 - arr2).reshape(1, -1), "fro"), - "iou": lambda arr1, arr2: 1 - intersection_over_union(arr1, arr2), + "iou": lambda arr1, arr2: 1 - intersection_over_union(arr1, arr2, pin_labels), } try: return metrics[metric_name] diff --git a/lib/galaxy/tool_util/xsd/galaxy.xsd b/lib/galaxy/tool_util/xsd/galaxy.xsd index 59541f8b18f4..c4bde2f47a22 100644 --- a/lib/galaxy/tool_util/xsd/galaxy.xsd +++ b/lib/galaxy/tool_util/xsd/galaxy.xsd @@ -1825,6 +1825,11 @@ If you specify a `checksum`, it will be also used to check the integrity of the If ``compare`` is set to ``image_diff``, this is the maximum allowed distance between the data set that is generated in the test and the file in ``test-data/`` that is referenced by the ``file`` attribute, with distances computed with respect to the specified ``metric``. Default value is 0.01. + + + If ``compare`` is set to ``image_diff`` and ``metric`` is set to ``iou``, by default, object correspondances are established by maximizing the pairwise intersection over the union. If, however, the label of an object is listed in ``pin_labels``, then the corresponding object is determined according to the same label value (and that object cannot be the corresponding object of any other object with a different label). + + @@ -7788,7 +7793,7 @@ favour of a ``has_size`` assertion. - If ``compare`` is set to ``image_diff``, this is the metric used to compute the distance between images for quantification of their difference. For intensity images, possible metrics are *mean absolute error* (``mae``, the default), *mean squared error* (``mse``), *root mean squared* error (``rms``), and the *Frobenius norm* (``fro``). In addition, for binary images and label maps (with multiple objects), ``iou`` can be used to compute *one minus* the *intersection over the union* (IoU). Object correspondances are established by taking the pair of objects, for which the IoU is highest, and the distance of the images is the worst value determined for any pair of corresponding objects. + If ``compare`` is set to ``image_diff``, this is the metric used to compute the distance between images for quantification of their difference. For intensity images, possible metrics are *mean absolute error* (``mae``, the default), *mean squared error* (``mse``), *root mean squared* error (``rms``), and the *Frobenius norm* (``fro``). In addition, for binary images and label maps (with multiple objects), ``iou`` can be used to compute *one minus* the *intersection over the union* (IoU). Object correspondances are established by taking the pair of objects, for which the IoU is highest (also see the ``pin_labels`` attribute), and the distance of the images is the worst value determined for any pair of corresponding objects. diff --git a/test/functional/tools/image_diff.xml b/test/functional/tools/image_diff.xml index adb17639c603..addb4ef24b44 100644 --- a/test/functional/tools/image_diff.xml +++ b/test/functional/tools/image_diff.xml @@ -32,6 +32,14 @@ + + + + + + + + diff --git a/test/unit/tool_util/test_verify.py b/test/unit/tool_util/test_verify.py index db6955e08d69..fa814401527f 100644 --- a/test/unit/tool_util/test_verify.py +++ b/test/unit/tool_util/test_verify.py @@ -79,9 +79,9 @@ def _encode_image(im, **kwargs): F9 = _encode_image( numpy.array( [ - [0, 0, 0], - [0, 1, 0], - [0, 1, 2], + [200, 200, 200], + [200, 1, 200], + [200, 1, 2], ], dtype=numpy.uint8, ), @@ -179,6 +179,14 @@ def generate_tests_image_diff(): (f6, f7, {"metric": "fro", "eps": 100 - 1e-4}, AssertionError), (f6, f9, {"metric": "iou", "eps": (1 - 1 / 8) + 1e-4}, None), (f6, f9, {"metric": "iou", "eps": (1 - 1 / 8) - 1e-4}, AssertionError), + # tests `pin_labels` with a label not present in any image + (f6, f9, {"metric": "iou", "eps": 0.999999, "pin_labels": "5"}, AssertionError), + # tests `pin_labels` with a label present in both images + (f6, f9, {"metric": "iou", "eps": 0.999999, "pin_labels": "200"}, AssertionError), + (f6, f9, {"metric": "iou", "eps": 1.0, "pin_labels": "200"}, None), + # tests `pin_labels` with a label only present in one image + (f6, f9, {"metric": "iou", "eps": 1.0, "pin_labels": "200, 1"}, AssertionError), + (f6, f9, {"metric": "iou", "eps": 1.0, "pin_labels": "200, 255"}, AssertionError), ] return tests