From e6a93ae29a2eb7c50a49b97195c9bbb93b6d9719 Mon Sep 17 00:00:00 2001 From: Alexander Swerdlow Date: Tue, 18 Jun 2024 09:11:40 -0400 Subject: [PATCH] Add random() --- src/image_utils/file_utils.py | 4 ++-- src/image_utils/im.py | 12 ++++++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/image_utils/file_utils.py b/src/image_utils/file_utils.py index 1163303..c1f0441 100644 --- a/src/image_utils/file_utils.py +++ b/src/image_utils/file_utils.py @@ -59,7 +59,7 @@ def get_images(path: Path, recursive: bool = False, allowed_extensions=[".png", yield Im.open(file) -def load_cached_from_url(url: str) -> BytesIO: +def load_cached_from_url(url: str, cache: bool = True) -> BytesIO: import hashlib cache_dir = Path.home() / ".cache" / "image_utils" @@ -67,7 +67,7 @@ def load_cached_from_url(url: str) -> BytesIO: filename = hashlib.md5(url.encode()).hexdigest() local_path = cache_dir / filename - if local_path.exists(): + if cache and local_path.exists(): return BytesIO(local_path.read_bytes()) else: image_bytesio = download_file_bytes(url) diff --git a/src/image_utils/im.py b/src/image_utils/im.py index 2723f21..8a94654 100644 --- a/src/image_utils/im.py +++ b/src/image_utils/im.py @@ -18,6 +18,7 @@ from jaxtyping import Bool, Float, Integer from numpy import ndarray from PIL import Image +from torch import rand if importlib.util.find_spec("torch") is not None: import torch @@ -111,10 +112,10 @@ class Im: default_normalize_mean = [0.4265, 0.4489, 0.4769] default_normalize_std = [0.2053, 0.2206, 0.2578] - def __init__(self, arr: Union["Im", Tensor, Image.Image, ndarray], channel_range: Optional[ChannelRange] = None, **kwargs): + def __init__(self, arr: Union["Im", Tensor, Image.Image, ndarray, str, Path], channel_range: Optional[ChannelRange] = None, **kwargs): # TODO: Add real URL checking here if isinstance(arr, (str, Path)) and Path(arr).exists(): - arr = Im.open(arr) # type: ignore + arr = Im.open(arr) elif isinstance(arr, str): arr = Image.open(load_cached_from_url(arr)) @@ -365,6 +366,13 @@ def open(filepath: Path, use_imageio=False) -> Im: @callable_staticmethod def new(h: int, w: int, color=(255, 255, 255)): return Im(Image.new("RGB", (w, h), color)) + + @callable_staticmethod + def random(h: int = 1080, w: int = 1920) -> Im: + try: + return Im(Image.open(load_cached_from_url(f"https://unsplash.it/{w}/{h}?random", cache=False))) + except: + return Im(Image.open(load_cached_from_url(f"https://picsum.photos/{w}/{h}?random", cache=False))) @_convert_to_datatype(desired_datatype=Tensor, desired_order=ChannelOrder.CHW, desired_range=ChannelRange.FLOAT) def resize(self, height: int, width: int, resampling_mode: str = "bilinear"):