From 3d3ad72dcff055d6b853b499e6fc953374397af0 Mon Sep 17 00:00:00 2001 From: Nicola Soranzo Date: Thu, 29 Feb 2024 17:30:28 +0000 Subject: [PATCH] Add type annotations --- lib/galaxy/tool_util/verify/__init__.py | 51 +++++++++++++++---------- lib/galaxy/util/checkers.py | 6 +-- lib/galaxy/util/image_util.py | 24 ++++++------ 3 files changed, 45 insertions(+), 36 deletions(-) diff --git a/lib/galaxy/tool_util/verify/__init__.py b/lib/galaxy/tool_util/verify/__init__.py index ea3ade7d25c7..7adfa3321904 100644 --- a/lib/galaxy/tool_util/verify/__init__.py +++ b/lib/galaxy/tool_util/verify/__init__.py @@ -15,7 +15,9 @@ Any, Callable, Dict, + List, Optional, + TYPE_CHECKING, ) import numpy @@ -37,6 +39,9 @@ from .asserts import verify_assertions from .test_data import TestDataResolver +if TYPE_CHECKING: + import numpy.typing + log = logging.getLogger(__name__) DEFAULT_TEST_DATA_RESOLVER = TestDataResolver() @@ -442,21 +447,23 @@ def files_contains(file1, file2, attributes=None): raise AssertionError(f"Failed to find '{contains}' in history data. (lines_diff={lines_diff}).") -def _multiobject_intersection_over_union(mask1, mask2, repeat_reverse=True): - iou_list = list() +def _multiobject_intersection_over_union( + mask1: "numpy.typing.NDArray", mask2: "numpy.typing.NDArray", repeat_reverse: bool = True +) -> List[numpy.floating]: + iou_list = [] for label1 in numpy.unique(mask1): cc1 = mask1 == label1 - cc1_iou_list = list() + cc1_iou_list = [] for label2 in numpy.unique(mask2[cc1]): cc2 = mask2 == label2 cc1_iou_list.append(intersection_over_union(cc1, cc2)) iou_list.append(max(cc1_iou_list)) if repeat_reverse: - iou_list += _multiobject_intersection_over_union(mask2, mask1, repeat_reverse=False) + iou_list.extend(_multiobject_intersection_over_union(mask2, mask1, repeat_reverse=False)) return iou_list -def intersection_over_union(mask1, mask2): +def intersection_over_union(mask1: "numpy.typing.NDArray", mask2: "numpy.typing.NDArray") -> numpy.floating: assert mask1.dtype == mask2.dtype assert mask1.ndim == mask2.ndim == 2 assert mask1.shape == mask2.shape @@ -466,15 +473,17 @@ def intersection_over_union(mask1, mask2): return min(_multiobject_intersection_over_union(mask1, mask2)) -def get_image_metric(attributes): +def get_image_metric( + attributes: Dict[str, Any] +) -> Callable[["numpy.typing.NDArray", "numpy.typing.NDArray"], numpy.floating]: metric_name = attributes.get("metric", DEFAULT_METRIC) - attributes = attributes or {} metrics = { - "mae": lambda im1, im2: numpy.abs(im1 - im2).mean(), - "mse": lambda im1, im2: numpy.square((im1 - im2).astype(float)).mean(), - "rms": lambda im1, im2: math.sqrt(numpy.square((im1 - im2).astype(float)).mean()), - "fro": lambda im1, im2: numpy.linalg.norm((im1 - im2).reshape(1, -1), "fro"), - "iou": lambda im1, im2: 1 - intersection_over_union(im1, im2), + "mae": lambda arr1, arr2: numpy.abs(arr1 - arr2).mean(), + # Convert to float before squaring to prevent overflows + "mse": lambda arr1, arr2: numpy.square((arr1 - arr2).astype(float)).mean(), + "rms": lambda arr1, arr2: math.sqrt(numpy.square((arr1 - arr2).astype(float)).mean()), + "fro": lambda arr1, arr2: numpy.linalg.norm((arr1 - arr2).reshape(1, -1), "fro"), + "iou": lambda arr1, arr2: 1 - intersection_over_union(arr1, arr2), } try: return metrics[metric_name] @@ -482,20 +491,22 @@ def get_image_metric(attributes): raise ValueError(f'No such metric: "{metric_name}"') -def files_image_diff(file1, file2, attributes=None): +def files_image_diff(file1: str, file2: str, attributes: Optional[Dict[str, Any]] = None) -> None: """Check the pixel data of 2 image files for differences.""" attributes = attributes or {} - im1 = numpy.array(Image.open(file1)) - im2 = numpy.array(Image.open(file2)) + with Image.open(file1) as im1: + arr1 = numpy.array(im1) + with Image.open(file2) as im2: + arr2 = numpy.array(im2) - if im1.dtype != im2.dtype: - raise AssertionError(f"Image data types did not match ({im1.dtype}, {im2.dtype}).") + if arr1.dtype != arr2.dtype: + raise AssertionError(f"Image data types did not match ({arr1.dtype}, {arr2.dtype}).") - if im1.shape != im2.shape: - raise AssertionError(f"Image dimensions did not match ({im1.shape}, {im2.shape}).") + if arr1.shape != arr2.shape: + raise AssertionError(f"Image dimensions did not match ({arr1.shape}, {arr2.shape}).") - distance = get_image_metric(attributes)(im1, im2) + distance = get_image_metric(attributes)(arr1, arr2) distance_eps = attributes.get("eps", DEFAULT_EPS) if distance > distance_eps: raise AssertionError(f"Image difference {distance} exceeds eps={distance_eps}.") diff --git a/lib/galaxy/util/checkers.py b/lib/galaxy/util/checkers.py index de2e149aa6e4..8e3862bcfe89 100644 --- a/lib/galaxy/util/checkers.py +++ b/lib/galaxy/util/checkers.py @@ -187,11 +187,9 @@ def iter_zip(file_path: str): yield (z.open(f), f) -def check_image(file_path: str): +def check_image(file_path: str) -> bool: """Simple wrapper around image_type to yield a True/False verdict""" - if image_type(file_path): - return True - return False + return bool(image_type(file_path)) COMPRESSION_CHECK_FUNCTIONS: Dict[str, CompressionChecker] = { diff --git a/lib/galaxy/util/image_util.py b/lib/galaxy/util/image_util.py index d24a75f10725..1b9bf7d99bb5 100644 --- a/lib/galaxy/util/image_util.py +++ b/lib/galaxy/util/image_util.py @@ -2,25 +2,25 @@ import imghdr import logging +from typing import ( + List, + Optional, +) try: - import Image as PIL + from PIL import Image except ImportError: - try: - from PIL import Image as PIL - except ImportError: - PIL = None + PIL = None log = logging.getLogger(__name__) -def image_type(filename): +def image_type(filename: str) -> Optional[str]: fmt = None - if PIL is not None: + if Image is not None: try: - im = PIL.open(filename) - fmt = im.format - im.close() + with Image.open(filename) as im: + fmt = im.format except Exception: # We continue to try with imghdr, so this is a rare case of an # exception we expect to happen frequently, so we're not logging @@ -30,10 +30,10 @@ def image_type(filename): if fmt: return fmt.upper() else: - return False + return None -def check_image_type(filename, types): +def check_image_type(filename: str, types: List[str]) -> bool: fmt = image_type(filename) if fmt in types: return True