Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
alexanderswerdlow committed Jun 23, 2024
1 parent e6a93ae commit 8993958
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 21 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ dependencies = [
TORCH = ["torch", "torchvision"]
VIDEO = ["av>=10.0.0", "imageio[ffmpeg]>=2.23.0"]
PROC = ["joblib>=1.2.0"]
DEV = ["pytest", "pytest-cov", "lovely-tensors>=0.1.14", "lovely-numpy>=0.2.8"]
DEV = ["pytest", "pytest-cov", "lovely-tensors>=0.1.14", "lovely-numpy>=0.2.8", "opencv-python"]
ALL = ["image_utils[TORCH,VIDEO,PROC,DEV]"]

[project.urls]
Expand Down
43 changes: 25 additions & 18 deletions src/image_utils/im.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,22 @@
from io import BytesIO
from math import ceil
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Iterable, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Callable, Iterable, Optional, Tuple, Type, Union, Any

import numpy as np
from einops import rearrange, repeat
from jaxtyping import Bool, Float, Integer
from numpy import ndarray
from PIL import Image
from torch import rand


if importlib.util.find_spec("torch") is not None:
import torch
from torch import Tensor
from torch import Tensor, device # type: ignore
else:
class Tensor:
pass
class device:
def __init__(self, type: str):
self.type = type

if importlib.util.find_spec("image_utils") is not None:
from image_utils.file_utils import get_date_time_str, load_cached_from_url
Expand All @@ -35,6 +36,7 @@ class Tensor:
from imageio import v3 as iio

if TYPE_CHECKING:
from torch import Tensor
ImArr = Union[ndarray, Tensor] # The actual array itself
ImArrType = Type[Union[ndarray, Tensor]] # The object itself is just a type
ImDtype = Union[torch.dtype, np.dtype]
Expand Down Expand Up @@ -126,7 +128,7 @@ def __init__(self, arr: Union["Im", Tensor, Image.Image, ndarray, str, Path], ch
setattr(self, attr, getattr(arr, attr))
return

self.device: torch.device
self.device: device
self.arr_type: ImArrType

# To handle things in a unified manner, we choose to always convert PIL Images -> NumPy internally
Expand All @@ -137,7 +139,7 @@ def __init__(self, arr: Union["Im", Tensor, Image.Image, ndarray, str, Path], ch
self.arr: ImArr = arr
if isinstance(self.arr, ndarray):
self.arr_type = ndarray
self.device = "cpu"
self.device = device("cpu")
elif isinstance(self.arr, Tensor):
self.device = self.arr.device
self.arr_type = Tensor
Expand All @@ -154,6 +156,7 @@ def __init__(self, arr: Union["Im", Tensor, Image.Image, ndarray, str, Path], ch

# We normalize all arrays to (B, H, W, C) and record their original shape so
# we can re-transform them when we need to output them
self.arr_transform: Callable[[ImArr], ImArr]
if len(self.arr.shape) == 2:
self.channel_order = ChannelOrder.HWC
self.arr_transform = partial(rearrange, pattern="() h w () -> h w")
Expand All @@ -165,7 +168,7 @@ def __init__(self, arr: Union["Im", Tensor, Image.Image, ndarray, str, Path], ch
extra_dims = self.arr.shape[:-3]
mapping = {k: v for k, v in zip(string.ascii_uppercase, extra_dims)}
transform_str = f'({" ".join(sorted(list(mapping.keys())))}) a b c -> {" ".join(sorted(list(mapping.keys())))} a b c'
self.arr_transform = partial(rearrange, pattern=transform_str, **mapping) # lambda x: rearrange(x, transform_str, g)
self.arr_transform = partial(rearrange, pattern=transform_str, **mapping) # type: ignore # lambda x: rearrange(x, transform_str, g)
else:
raise ValueError("Must be between 3-5 dims")

Expand Down Expand Up @@ -239,7 +242,7 @@ def wrapper(self: Im, *args, **kwargs):

return custom_decorator

def _handle_order_transform(self, im, desired_order: ChannelOrder, desired_range: ChannelRange, select_batch=None):
def _handle_order_transform(self, im: ImArr, desired_order: ChannelOrder, desired_range: ChannelRange, select_batch=None) -> ImArr:
if select_batch is not None:
im = im[select_batch]
else:
Expand Down Expand Up @@ -278,9 +281,9 @@ def _handle_order_transform(self, im, desired_order: ChannelOrder, desired_range
raise ValueError("Not supported")

if desired_range == ChannelRange.UINT8:
im = im.astype(np.uint8) if is_ndarray(im) else im.to(torch.uint8)
im = im.astype(np.uint8) if isinstance(im, ndarray) else im.to(torch.uint8)
elif desired_range == ChannelRange.FLOAT:
im = im.astype(np.float32) if is_ndarray(im) else im.to(torch.float32)
im = im.astype(np.float32) if isinstance(im, ndarray) else im.to(torch.float32)

return im

Expand All @@ -290,6 +293,7 @@ def get_np(self, order=ChannelOrder.HWC, range=ChannelRange.UINT8) -> ndarray:
arr = torch_to_numpy(arr) # type: ignore

arr = self._handle_order_transform(arr, order, range)
assert isinstance(arr, ndarray)
return arr

def get_torch(self, order=ChannelOrder.CHW, range=ChannelRange.FLOAT) -> Tensor:
Expand All @@ -298,6 +302,7 @@ def get_torch(self, order=ChannelOrder.CHW, range=ChannelRange.FLOAT) -> Tensor:
arr = torch.from_numpy(arr)

arr = self._handle_order_transform(arr, order, range)
assert isinstance(arr, Tensor)
if self.device is not None:
arr = arr.to(self.device)
return arr
Expand Down Expand Up @@ -376,8 +381,8 @@ def random(h: int = 1080, w: int = 1920) -> Im:

@_convert_to_datatype(desired_datatype=Tensor, desired_order=ChannelOrder.CHW, desired_range=ChannelRange.FLOAT)
def resize(self, height: int, width: int, resampling_mode: str = "bilinear"):
assert isinstance(self.arr, Tensor)
from torchvision.transforms.functional import resize, InterpolationMode
assert isinstance(self.arr, torch.Tensor)
arr = resize(self.arr, [height, width], InterpolationMode(resampling_mode), antialias=True)
arr = self.arr_transform(arr)
return Im(arr)
Expand Down Expand Up @@ -429,7 +434,7 @@ def save(
if filepath is None:
filepath = Path(get_date_time_str())

filepath = Im._save_data(filepath, filetype)
filepath: Path = Im._save_data(filepath, filetype)

if self.batch_size > 1:
img = self.get_torch()
Expand Down Expand Up @@ -538,7 +543,7 @@ def normalize(self, normalize_min_max: bool = False, **kwargs) -> Im:
self.arr = (self.arr - mean) / std
return self

def denormalize(self, clamp: Union[bool, tuple[float, float]] = (0, 1.0), **kwargs) -> Im:
def denormalize(self, clamp: tuple[float, float] = (0, 1.0), **kwargs) -> Im:
self, mean, std = self.normalize_setup(**kwargs)
self.arr = (self.arr * std) + mean
if isinstance(self.arr, ndarray):
Expand Down Expand Up @@ -595,10 +600,10 @@ def encode_video(self, fps: int, format="mp4") -> BytesIO:
import imageio
if format == "webm":
writer = imageio.get_writer(
ntp.name, format="webm", codec="libvpx-vp9", pixelformat="yuv420p", output_params=["-lossless", "1"], fps=fps
ntp.name, format="webm", codec="libvpx-vp9", pixelformat="yuv420p", output_params=["-lossless", "1"], fps=fps # type: ignore
)
elif format == "gif":
writer = imageio.get_writer(ntp.name, format="GIF", mode="I", duration=(1000 * 1 / fps))
writer = imageio.get_writer(ntp.name, format="GIF", mode="I", duration=(1000 * 1 / fps)) # type: ignore
elif format == "mp4":
writer = imageio.get_writer(ntp.name, quality=10, pixelformat="yuv420p", codec="libx264", fps=fps)
else:
Expand Down Expand Up @@ -728,7 +733,7 @@ def get_arr_hwc(im: Im):
return im._handle_order_transform(im.arr, desired_order=ChannelOrder.HWC, desired_range=im.channel_range)


def new_like(arr, shape, fill: Optional[tuple[int]] = None):
def new_like(arr, shape, fill: Optional[tuple[int]] = None) -> ImArr:
if is_ndarray(arr):
new_arr = np.zeros_like(arr, shape=shape)
elif is_tensor(arr):
Expand All @@ -748,14 +753,16 @@ def new_like(arr, shape, fill: Optional[tuple[int]] = None):

def concat_along_dim(arr_1: ImArr, arr_2: ImArr, dim: int):
    """Join two image arrays of the same backend along axis ``dim``.

    Both inputs must be NumPy arrays, or both must be torch tensors, as
    reported by ``is_ndarray`` / ``is_tensor``. NumPy inputs are joined with
    ``np.concatenate`` and tensor inputs with ``torch.cat``.

    Raises:
        ValueError: if the inputs are neither both NumPy arrays nor both
            torch tensors (this includes mixing one of each).
    """
    both_numpy = is_ndarray(arr_1) and is_ndarray(arr_2)
    if not both_numpy:
        if is_tensor(arr_1) and is_tensor(arr_2):
            return torch.cat([arr_1, arr_2], dim=dim)  # type: ignore
        raise ValueError("Must be numpy array or torch tensor")

    # The isinstance asserts narrow the ImArr union for static type checkers.
    assert isinstance(arr_1, np.ndarray)
    assert isinstance(arr_2, np.ndarray)
    return np.concatenate((arr_1, arr_2), axis=dim)


def broadcast_arrays(im1_arr, im2_arr):
def broadcast_arrays(im1_arr, im2_arr) -> Tuple[ImArr, ImArr]:
"""
Takes [..., H, W, C] and [..., H, W, C] and broadcasts them to the same shape.
Expand Down
7 changes: 5 additions & 2 deletions src/image_utils/library_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@ def enable():
if importlib.util.find_spec("numpy") is not None:
import numpy as np
np.set_printoptions(suppress=True, precision=3, threshold=10, edgeitems=2, linewidth=120)
from lovely_numpy import lovely, set_config
set_config(repr=lovely)
try:
from lovely_numpy import lovely, set_config
set_config(repr=lovely)
except:
print(f"Failed to enable lovely_numpy.")

if importlib.util.find_spec("torch") is not None:
import torch
Expand Down

0 comments on commit 8993958

Please sign in to comment.