Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
alexanderswerdlow committed Jun 9, 2024
1 parent 511d72c commit be794c7
Show file tree
Hide file tree
Showing 9 changed files with 111 additions and 45 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ venv/
__pycache__
*.egg-info
docs/_build/
docs/generated/
docs/generated/
dist/
22 changes: 12 additions & 10 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "image_utils"
version = "0.0.1"
name = "image_utilities"
version = "0.0.2.dev1"
authors = [{ name="Alexander Swerdlow", email="aswerdlow1@gmail.com" }]
description = "A utility library for common image and video operations."
readme = "README.md"
Expand All @@ -16,24 +16,26 @@ classifiers = [
]

dependencies = [
"torch",
"torchvision",
"opencv-python>=4.5",
"numpy>=1.22",
"pillow>=9.0.0",
"einops>=0.6.0",
"numpy>=1.17",
"pillow>=8.0.0",
"einops>=0.3.0",
"jaxtyping>=0.2.19"
]

[project.optional-dependencies]
TORCH = ["torch", "torchvision"]
VIDEO = ["av>=10.0.0", "imageio[ffmpeg]>=2.23.0"]
PROC = ["joblib>=1.2.0"]
DEV = ["pytest", "pytest-cov", "lovely-tensors>=0.1.14", "lovely-numpy>=0.2.8"]
ALL = ["image_utils[VIDEO,PROC,DEV]"]
ALL = ["image_utils[TORCH,VIDEO,PROC,DEV]"]

[project.urls]
Homepage = "https://github.com/alexanderswerdlow/image_utils"

[tool.black]
line-length = 150
target-version = ['py310']
target-version = ['py310']

[tool.hatch.build.targets.wheel]
packages = ["src/image_utils", "src/im"]

Empty file added src/__init__.py
Empty file.
1 change: 1 addition & 0 deletions src/im/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from image_utils import *
4 changes: 4 additions & 0 deletions src/image_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from .im import *
from .standalone_image_utils import *
from .file_utils import *

def disable():
from lovely_numpy import lovely, set_config
set_config(repr=None)
4 changes: 2 additions & 2 deletions src/image_utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ def delete_create_folder(path: Path):
path.mkdir(parents=True, exist_ok=True)


def get_rand_hex():
def get_rand_hex() -> str:
return "".join(random.choices(string.ascii_uppercase + string.digits, k=5))


def get_date_time_str():
def get_date_time_str() -> str:
return datetime.now().strftime("%Y_%m_%d-%H_%M")


Expand Down
80 changes: 63 additions & 17 deletions src/image_utils/im.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,20 @@
from io import BytesIO
from math import ceil
from pathlib import Path
from typing import Callable, Iterable, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Callable, Iterable, Optional, Tuple, Type, Union

import numpy as np
import torch
from einops import rearrange, repeat
from jaxtyping import Bool, Float, Integer
from numpy import ndarray
from PIL import Image
from torch import Tensor

if importlib.util.find_spec("torch") is not None:
import torch
from torch import Tensor
else:
class Tensor:
pass

if importlib.util.find_spec("image_utils") is not None:
from image_utils.file_utils import get_date_time_str, load_cached_from_url
Expand All @@ -28,11 +33,12 @@
if importlib.util.find_spec("imageio") is not None:
from imageio import v3 as iio

colorize_weights = {}
ImArr = Union[ndarray, Tensor] # The actual array itself
ImArrType = Type[Union[ndarray, Tensor]] # The object itself is just a type
ImDtype = Union[torch.dtype, np.dtype]
if TYPE_CHECKING:
ImArr = Union[ndarray, Tensor] # The actual array itself
ImArrType = Type[Union[ndarray, Tensor]] # The object itself is just a type
ImDtype = Union[torch.dtype, np.dtype]

colorize_weights = {}
enable_warnings = os.getenv("IMAGE_UTILS_DISABLE_WARNINGS") is None

class callable_staticmethod(staticmethod):
Expand Down Expand Up @@ -130,7 +136,7 @@ def __init__(self, arr: Union["Im", Tensor, Image.Image, ndarray], channel_range
self.arr: ImArr = arr
if isinstance(self.arr, ndarray):
self.arr_type = ndarray
self.device = torch.device("cpu")
self.device = "cpu"
elif isinstance(self.arr, Tensor):
self.device = self.arr.device
self.arr_type = Tensor
Expand Down Expand Up @@ -253,7 +259,7 @@ def _handle_order_transform(self, im, desired_order: ChannelOrder, desired_range
if self.channel_range != desired_range:
assert is_ndarray(im) or is_tensor(im)
if self.channel_range == ChannelRange.FLOAT and desired_range == ChannelRange.UINT8:
if self.channels == 1:
if self.channels == 1 and im.max() > im.min():
im = (im - im.min()) / (im.max() - im.min())
im = im * 255
if self.channels == 1:
Expand Down Expand Up @@ -385,7 +391,7 @@ def scale_to_height(self, new_height: int, **kwargs) -> Im:
return self.resize(new_height, wsize, **kwargs)

@callable_staticmethod
def _save_data(filepath: Path = Path(get_date_time_str()), filetype: str = "png"):
def _save_data(filepath: Path = Path(get_date_time_str()), filetype: str = "png") -> Path:
filepath = Path(filepath)
if filepath.suffix == "":
filepath = filepath.with_suffix(f".{filetype}")
Expand All @@ -404,7 +410,14 @@ def grid(self, **kwargs) -> Im:
img = utils.make_grid(self.arr, **kwargs) # type: ignore
return Im(img)

def save(self, filepath: Optional[Path] = None, filetype: str = "png", optimize: bool = False, quality: Optional[float] = None, **kwargs):
def save(
self,
filepath: Optional[Path] = None,
filetype: str = "png",
optimize: bool = False,
quality: Optional[float] = None,
**kwargs
) -> Path:
if filepath is None:
filepath = Path(get_date_time_str())

Expand All @@ -421,6 +434,8 @@ def save(self, filepath: Optional[Path] = None, filetype: str = "png", optimize:

img.save(filepath, **flags)

return filepath.resolve()

@_convert_to_datatype(desired_datatype=ndarray, desired_order=ChannelOrder.HWC, desired_range=ChannelRange.UINT8)
def write_text(
self,
Expand Down Expand Up @@ -546,13 +561,21 @@ def concat_horizontal(*args, **kwargs) -> Im:
"""Concatenates images horizontally (i.e. left to right)"""
return concat_variable(concat_horizontal_, *args, **kwargs)

def save_video(self, filepath: Optional[Path] = None, fps: int = 4, format="mp4"):
def save_video(self, filepath: Optional[Path] = None, fps: int = 4, format="mp4", use_pyav: bool = False):
if filepath is None:
filepath = Path(get_date_time_str())
filepath = Im._save_data(filepath, format)
byte_stream = self.encode_video(fps, format)
with open(filepath, "wb") as f:
f.write(byte_stream.getvalue())

filepath: Path = Im._save_data(filepath, format)

if use_pyav:
from image_utils.video_utils import write_video
self = self._convert(desired_datatype=ndarray, desired_order=ChannelOrder.HWC, desired_range=ChannelRange.UINT8)
assert isinstance(self.arr, ndarray)
write_video(self.arr, filepath, fps=fps)
else:
byte_stream = self.encode_video(fps, format)
with open(filepath, "wb") as f:
f.write(byte_stream.getvalue())

@_convert_to_datatype(desired_datatype=ndarray, desired_order=ChannelOrder.HWC, desired_range=ChannelRange.UINT8)
def encode_video(self, fps: int, format="mp4") -> BytesIO:
Expand Down Expand Up @@ -621,6 +644,29 @@ def pca(self, **kwargs) -> Im:
output: Tensor = rearrange(output, "(b h w) c -> b h w c", b=b, h=h, w=w)
return Im(output)

def show(self):
import subprocess

method = None
if subprocess.run(['which', 'imgcat'], capture_output=True).returncode == 0:
method = 'iterm2-imgcat'
elif subprocess.run(['which', 'xdg-open'], capture_output=True).returncode == 0:
method = 'xdg-open'

if method is not None:
with tempfile.TemporaryDirectory() as temp_dir:
filename = self.save(Path(temp_dir))
if method == 'iterm2-imgcat':
print('\n' * 4)
print('\033[4F')
subprocess.check_call(['imgcat', filename])
print('\033[4B')
else:
subprocess.check_call(['xdg-open', filename])
else:
filename = self.save()
print(f'Failed to view image.Image saved to {filename}')

@_convert_to_datatype(desired_datatype=Tensor, desired_order=ChannelOrder.HWC, desired_range=ChannelRange.UINT8)
def bool_to_rgb(self) -> Im:
return self
Expand Down Expand Up @@ -788,4 +834,4 @@ def concat_vertical_(im1: Im, im2: Im, spacing: int = 0, **kwargs) -> Im:
new_im2_arr[..., spacing:, :, :] = im2_arr
im2_arr = new_im2_arr

return Im(concat_along_dim(im1_arr, im2_arr, dim=-3))
return Im(concat_along_dim(im1_arr, im2_arr, dim=-3))
29 changes: 17 additions & 12 deletions src/image_utils/library_ops.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
try:
import torch
import numpy as np
def enable():
try:
import importlib
if importlib.util.find_spec("numpy") is not None:
import numpy as np
np.set_printoptions(suppress=True, precision=3, threshold=10, edgeitems=2, linewidth=120)
from lovely_numpy import lovely, set_config
set_config(repr=lovely)

if importlib.util.find_spec("torch") is not None:
import torch
torch.set_printoptions(sci_mode=False, precision=3, threshold=10, edgeitems=2, linewidth=120)
import lovely_tensors as lt
lt.monkey_patch()
except ImportError as e:
print("lovely_tensors is not installed. Run `pip install lovely-tensors` if you wish to use it.")

torch.set_printoptions(sci_mode=False, precision=3, threshold=10, edgeitems=2, linewidth=120)
np.set_printoptions(suppress=True, precision=3, threshold=10, edgeitems=2, linewidth=120)
import lovely_tensors as lt
from lovely_numpy import lovely, set_config

lt.monkey_patch()
set_config(repr=lovely)
except ImportError:
print("lovely_tensors is not installed. Run `pip install lovely-tensors` if you wish to use it.")
enable()
13 changes: 10 additions & 3 deletions src/image_utils/standalone_image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,17 @@
from typing import Optional, Union

import numpy as np
import torch
import torchvision.transforms.functional as T
from PIL import Image
from torch import Tensor

import importlib

if importlib.util.find_spec("torch") is not None:
import torch
from torch import Tensor

if importlib.util.find_spec("torchvision") is not None:
import torchvision.transforms.functional as T



def torch_to_numpy(arr: Tensor):
Expand Down

0 comments on commit be794c7

Please sign in to comment.