diff --git a/pyproject.toml b/pyproject.toml index 850b06b..dc20fb5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ TORCH = ["torch", "torchvision"] VIDEO = ["av>=10.0.0", "imageio[ffmpeg]>=2.23.0"] PROC = ["joblib>=1.2.0"] -DEV = ["pytest", "pytest-cov", "lovely-tensors>=0.1.14", "lovely-numpy>=0.2.8"] +DEV = ["pytest", "pytest-cov", "lovely-tensors>=0.1.14", "lovely-numpy>=0.2.8", "opencv-python"] ALL = ["image_utils[TORCH,VIDEO,PROC,DEV]"] [project.urls] diff --git a/src/image_utils/im.py b/src/image_utils/im.py index 8a94654..57ac707 100644 --- a/src/image_utils/im.py +++ b/src/image_utils/im.py @@ -11,21 +11,22 @@ from io import BytesIO from math import ceil from pathlib import Path -from typing import TYPE_CHECKING, Callable, Iterable, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Callable, Iterable, Optional, Tuple, Type, Union, Any import numpy as np from einops import rearrange, repeat from jaxtyping import Bool, Float, Integer from numpy import ndarray from PIL import Image -from torch import rand + if importlib.util.find_spec("torch") is not None: import torch - from torch import Tensor + from torch import Tensor, device # type: ignore else: - class Tensor: - pass + class device: + def __init__(self, type: str): + self.type = type if importlib.util.find_spec("image_utils") is not None: from image_utils.file_utils import get_date_time_str, load_cached_from_url @@ -35,6 +36,7 @@ class Tensor: from imageio import v3 as iio if TYPE_CHECKING: + from torch import Tensor ImArr = Union[ndarray, Tensor] # The actual array itself ImArrType = Type[Union[ndarray, Tensor]] # The object itself is just a type ImDtype = Union[torch.dtype, np.dtype] @@ -126,7 +128,7 @@ def __init__(self, arr: Union["Im", Tensor, Image.Image, ndarray, str, Path], ch setattr(self, attr, getattr(arr, attr)) return - self.device: torch.device + self.device: device self.arr_type: ImArrType # To handle things in a unified manner, we choose 
to always convert PIL Images -> NumPy internally @@ -137,7 +139,7 @@ def __init__(self, arr: Union["Im", Tensor, Image.Image, ndarray, str, Path], ch self.arr: ImArr = arr if isinstance(self.arr, ndarray): self.arr_type = ndarray - self.device = "cpu" + self.device = device("cpu") elif isinstance(self.arr, Tensor): self.device = self.arr.device self.arr_type = Tensor @@ -154,6 +156,7 @@ def __init__(self, arr: Union["Im", Tensor, Image.Image, ndarray, str, Path], ch # We normalize all arrays to (B, H, W, C) and record their original shape so # we can re-transform then when we need to output them + self.arr_transform: Callable[[ImArr], ImArr] if len(self.arr.shape) == 2: self.channel_order = ChannelOrder.HWC self.arr_transform = partial(rearrange, pattern="() h w () -> h w") @@ -165,7 +168,7 @@ def __init__(self, arr: Union["Im", Tensor, Image.Image, ndarray, str, Path], ch extra_dims = self.arr.shape[:-3] mapping = {k: v for k, v in zip(string.ascii_uppercase, extra_dims)} transform_str = f'({" ".join(sorted(list(mapping.keys())))}) a b c -> {" ".join(sorted(list(mapping.keys())))} a b c' - self.arr_transform = partial(rearrange, pattern=transform_str, **mapping) # lambda x: rearrange(x, transform_str, g) + self.arr_transform = partial(rearrange, pattern=transform_str, **mapping) # type: ignore # lambda x: rearrange(x, transform_str, g) else: raise ValueError("Must be between 3-5 dims") @@ -239,7 +242,7 @@ def wrapper(self: Im, *args, **kwargs): return custom_decorator - def _handle_order_transform(self, im, desired_order: ChannelOrder, desired_range: ChannelRange, select_batch=None): + def _handle_order_transform(self, im: ImArr, desired_order: ChannelOrder, desired_range: ChannelRange, select_batch=None) -> ImArr: if select_batch is not None: im = im[select_batch] else: @@ -278,9 +281,9 @@ def _handle_order_transform(self, im, desired_order: ChannelOrder, desired_range raise ValueError("Not supported") if desired_range == ChannelRange.UINT8: - im = 
im.astype(np.uint8) if is_ndarray(im) else im.to(torch.uint8) + im = im.astype(np.uint8) if isinstance(im, ndarray) else im.to(torch.uint8) elif desired_range == ChannelRange.FLOAT: - im = im.astype(np.float32) if is_ndarray(im) else im.to(torch.float32) + im = im.astype(np.float32) if isinstance(im, ndarray) else im.to(torch.float32) return im @@ -290,6 +293,7 @@ def get_np(self, order=ChannelOrder.HWC, range=ChannelRange.UINT8) -> ndarray: arr = torch_to_numpy(arr) # type: ignore arr = self._handle_order_transform(arr, order, range) + assert isinstance(arr, ndarray) return arr def get_torch(self, order=ChannelOrder.CHW, range=ChannelRange.FLOAT) -> Tensor: @@ -298,6 +302,7 @@ def get_torch(self, order=ChannelOrder.CHW, range=ChannelRange.FLOAT) -> Tensor: arr = torch.from_numpy(arr) arr = self._handle_order_transform(arr, order, range) + assert isinstance(arr, Tensor) if self.device is not None: arr = arr.to(self.device) return arr @@ -376,8 +381,8 @@ def random(h: int = 1080, w: int = 1920) -> Im: @_convert_to_datatype(desired_datatype=Tensor, desired_order=ChannelOrder.CHW, desired_range=ChannelRange.FLOAT) def resize(self, height: int, width: int, resampling_mode: str = "bilinear"): - assert isinstance(self.arr, Tensor) from torchvision.transforms.functional import resize, InterpolationMode + assert isinstance(self.arr, torch.Tensor) arr = resize(self.arr, [height, width], InterpolationMode(resampling_mode), antialias=True) arr = self.arr_transform(arr) return Im(arr) @@ -429,7 +434,7 @@ def save( if filepath is None: filepath = Path(get_date_time_str()) - filepath = Im._save_data(filepath, filetype) + filepath: Path = Im._save_data(filepath, filetype) if self.batch_size > 1: img = self.get_torch() @@ -538,7 +543,7 @@ def normalize(self, normalize_min_max: bool = False, **kwargs) -> Im: self.arr = (self.arr - mean) / std return self - def denormalize(self, clamp: Union[bool, tuple[float, float]] = (0, 1.0), **kwargs) -> Im: + def denormalize(self, clamp: 
tuple[float, float] = (0, 1.0), **kwargs) -> Im: self, mean, std = self.normalize_setup(**kwargs) self.arr = (self.arr * std) + mean if isinstance(self.arr, ndarray): @@ -595,10 +600,10 @@ def encode_video(self, fps: int, format="mp4") -> BytesIO: import imageio if format == "webm": writer = imageio.get_writer( - ntp.name, format="webm", codec="libvpx-vp9", pixelformat="yuv420p", output_params=["-lossless", "1"], fps=fps + ntp.name, format="webm", codec="libvpx-vp9", pixelformat="yuv420p", output_params=["-lossless", "1"], fps=fps # type: ignore ) elif format == "gif": - writer = imageio.get_writer(ntp.name, format="GIF", mode="I", duration=(1000 * 1 / fps)) + writer = imageio.get_writer(ntp.name, format="GIF", mode="I", duration=(1000 * 1 / fps)) # type: ignore elif format == "mp4": writer = imageio.get_writer(ntp.name, quality=10, pixelformat="yuv420p", codec="libx264", fps=fps) else: @@ -728,7 +733,7 @@ def get_arr_hwc(im: Im): return im._handle_order_transform(im.arr, desired_order=ChannelOrder.HWC, desired_range=im.channel_range) -def new_like(arr, shape, fill: Optional[tuple[int]] = None): +def new_like(arr, shape, fill: Optional[tuple[int]] = None) -> ImArr: if is_ndarray(arr): new_arr = np.zeros_like(arr, shape=shape) elif is_tensor(arr): @@ -748,6 +753,8 @@ def new_like(arr, shape, fill: Optional[tuple[int]] = None): def concat_along_dim(arr_1: ImArr, arr_2: ImArr, dim: int): if is_ndarray(arr_1) and is_ndarray(arr_2): + assert isinstance(arr_1, np.ndarray) + assert isinstance(arr_2, np.ndarray) return np.concatenate((arr_1, arr_2), axis=dim) elif is_tensor(arr_1) and is_tensor(arr_2): return torch.cat([arr_1, arr_2], dim=dim) # type: ignore @@ -755,7 +762,7 @@ def concat_along_dim(arr_1: ImArr, arr_2: ImArr, dim: int): raise ValueError("Must be numpy array or torch tensor") -def broadcast_arrays(im1_arr, im2_arr): +def broadcast_arrays(im1_arr, im2_arr) -> Tuple[ImArr, ImArr]: """ Takes [..., H, W, C] and [..., H, W, C] and broadcasts them to the same 
shape. diff --git a/src/image_utils/library_ops.py b/src/image_utils/library_ops.py index 720564c..0691711 100644 --- a/src/image_utils/library_ops.py +++ b/src/image_utils/library_ops.py @@ -4,8 +4,11 @@ def enable(): if importlib.util.find_spec("numpy") is not None: import numpy as np np.set_printoptions(suppress=True, precision=3, threshold=10, edgeitems=2, linewidth=120) - from lovely_numpy import lovely, set_config - set_config(repr=lovely) + try: + from lovely_numpy import lovely, set_config + set_config(repr=lovely) + except Exception: + print(f"Failed to enable lovely_numpy.") if importlib.util.find_spec("torch") is not None: import torch