Skip to content

Commit

Permalink
Add random()
Browse files Browse the repository at this point in the history
  • Loading branch information
alexanderswerdlow committed Jun 18, 2024
1 parent be794c7 commit e6a93ae
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/image_utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,15 @@ def get_images(path: Path, recursive: bool = False, allowed_extensions=[".png",
yield Im.open(file)


def load_cached_from_url(url: str) -> BytesIO:
def load_cached_from_url(url: str, cache: bool = True) -> BytesIO:
import hashlib

cache_dir = Path.home() / ".cache" / "image_utils"
cache_dir.mkdir(parents=True, exist_ok=True)
filename = hashlib.md5(url.encode()).hexdigest()
local_path = cache_dir / filename

if local_path.exists():
if cache and local_path.exists():
return BytesIO(local_path.read_bytes())
else:
image_bytesio = download_file_bytes(url)
Expand Down
12 changes: 10 additions & 2 deletions src/image_utils/im.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from jaxtyping import Bool, Float, Integer
from numpy import ndarray
from PIL import Image
from torch import rand

if importlib.util.find_spec("torch") is not None:
import torch
Expand Down Expand Up @@ -111,10 +112,10 @@ class Im:
default_normalize_mean = [0.4265, 0.4489, 0.4769]
default_normalize_std = [0.2053, 0.2206, 0.2578]

def __init__(self, arr: Union["Im", Tensor, Image.Image, ndarray], channel_range: Optional[ChannelRange] = None, **kwargs):
def __init__(self, arr: Union["Im", Tensor, Image.Image, ndarray, str, Path], channel_range: Optional[ChannelRange] = None, **kwargs):
# TODO: Add real URL checking here
if isinstance(arr, (str, Path)) and Path(arr).exists():
arr = Im.open(arr) # type: ignore
arr = Im.open(arr)
elif isinstance(arr, str):
arr = Image.open(load_cached_from_url(arr))

Expand Down Expand Up @@ -365,6 +366,13 @@ def open(filepath: Path, use_imageio=False) -> Im:
@callable_staticmethod
def new(h: int, w: int, color=(255, 255, 255)):
return Im(Image.new("RGB", (w, h), color))

@callable_staticmethod
def random(h: int = 1080, w: int = 1920) -> Im:
try:
return Im(Image.open(load_cached_from_url(f"https://unsplash.it/{w}/{h}?random", cache=False)))
except:
return Im(Image.open(load_cached_from_url(f"https://picsum.photos/{w}/{h}?random", cache=False)))

@_convert_to_datatype(desired_datatype=Tensor, desired_order=ChannelOrder.CHW, desired_range=ChannelRange.FLOAT)
def resize(self, height: int, width: int, resampling_mode: str = "bilinear"):
Expand Down

0 comments on commit e6a93ae

Please sign in to comment.