Skip to content

Commit

Permalink
Added type hints to ImageFilter
Browse files Browse the repository at this point in the history
  • Loading branch information
radarhere committed Jun 21, 2024
1 parent 4b258be commit 324e548
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 39 deletions.
4 changes: 2 additions & 2 deletions Tests/test_color_lut.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,10 +354,10 @@ def test_overflow(self) -> None:
class TestColorLut3DFilter:
def test_wrong_args(self) -> None:
with pytest.raises(ValueError, match="should be either an integer"):
ImageFilter.Color3DLUT("small", [1])
ImageFilter.Color3DLUT("small", [1]) # type: ignore[arg-type]

with pytest.raises(ValueError, match="should be either an integer"):
ImageFilter.Color3DLUT((11, 11), [1])
ImageFilter.Color3DLUT((11, 11), [1]) # type: ignore[arg-type]

with pytest.raises(ValueError, match=r"in \[2, 65\] range"):
ImageFilter.Color3DLUT((11, 11, 1), [1])
Expand Down
2 changes: 1 addition & 1 deletion Tests/test_image_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def test_builtinfilter_p() -> None:
builtin_filter = ImageFilter.BuiltinFilter()

with pytest.raises(ValueError):
builtin_filter.filter(hopper("P"))
builtin_filter.filter(hopper("P").im)


def test_kernel_not_enough_coefficients() -> None:
Expand Down
18 changes: 7 additions & 11 deletions Tests/test_numpy.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,25 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

import pytest

from PIL import Image
from PIL import Image, _typing

from .helper import assert_deep_equal, assert_image, hopper, skip_unless_feature

if TYPE_CHECKING:
import numpy
import numpy.typing
import numpy.typing as npt
else:
numpy = pytest.importorskip("numpy", reason="NumPy not installed")

TEST_IMAGE_SIZE = (10, 10)


def test_numpy_to_image() -> None:
def to_image(
dtype: numpy.typing.DTypeLike, bands: int = 1, boolean: int = 0
) -> Image.Image:
def to_image(dtype: npt.DTypeLike, bands: int = 1, boolean: int = 0) -> Image.Image:
if bands == 1:
if boolean:
data = [0, 255] * 50
Expand Down Expand Up @@ -106,9 +104,7 @@ def test_1d_array() -> None:
assert_image(Image.fromarray(a), "L", (1, 5))


def _test_img_equals_nparray(
img: Image.Image, np_img: numpy.typing.NDArray[Any]
) -> None:
def _test_img_equals_nparray(img: Image.Image, np_img: _typing.NumpyArray) -> None:
assert len(np_img.shape) >= 2
np_size = np_img.shape[1], np_img.shape[0]
assert img.size == np_size
Expand Down Expand Up @@ -166,7 +162,7 @@ def test_save_tiff_uint16() -> None:
("HSV", numpy.uint8),
),
)
def test_to_array(mode: str, dtype: numpy.typing.DTypeLike) -> None:
def test_to_array(mode: str, dtype: npt.DTypeLike) -> None:
img = hopper(mode)

# Resize to non-square
Expand Down Expand Up @@ -216,7 +212,7 @@ def test_putdata() -> None:
numpy.float64,
),
)
def test_roundtrip_eye(dtype: numpy.typing.DTypeLike) -> None:
def test_roundtrip_eye(dtype: npt.DTypeLike) -> None:
arr = numpy.eye(10, dtype=dtype)
numpy.testing.assert_array_equal(arr, numpy.array(Image.fromarray(arr)))

Expand Down
4 changes: 4 additions & 0 deletions docs/reference/internal_modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ Internal Modules
Provides a convenient way to import type hints that are not available
on some Python versions.

.. py:class:: NumpyArray
Typing alias.

.. py:class:: StrOrBytesPath
Typing alias.
Expand Down
72 changes: 49 additions & 23 deletions src/PIL/ImageFilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@
import abc
import functools
from types import ModuleType
from typing import Any, Sequence
from typing import TYPE_CHECKING, Any, Callable, Sequence, cast

Check warning on line 22 in src/PIL/ImageFilter.py

View check run for this annotation

Codecov / codecov/patch

src/PIL/ImageFilter.py#L22

Added line #L22 was not covered by tests

if TYPE_CHECKING:
from . import _imaging
from ._typing import NumpyArray


class Filter:
@abc.abstractmethod
def filter(self, image):
def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:
pass


Expand All @@ -33,7 +37,9 @@ class MultibandFilter(Filter):


class BuiltinFilter(MultibandFilter):
def filter(self, image):
filterargs: tuple[Any, ...]

Check warning on line 40 in src/PIL/ImageFilter.py

View check run for this annotation

Codecov / codecov/patch

src/PIL/ImageFilter.py#L40

Added line #L40 was not covered by tests

def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:

Check warning on line 42 in src/PIL/ImageFilter.py

View check run for this annotation

Codecov / codecov/patch

src/PIL/ImageFilter.py#L42

Added line #L42 was not covered by tests
if image.mode == "P":
msg = "cannot filter palette images"
raise ValueError(msg)
Expand Down Expand Up @@ -91,7 +97,7 @@ def __init__(self, size: int, rank: int) -> None:
self.size = size
self.rank = rank

def filter(self, image):
def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:

Check warning on line 100 in src/PIL/ImageFilter.py

View check run for this annotation

Codecov / codecov/patch

src/PIL/ImageFilter.py#L100

Added line #L100 was not covered by tests
if image.mode == "P":
msg = "cannot filter palette images"
raise ValueError(msg)
Expand Down Expand Up @@ -158,7 +164,7 @@ class ModeFilter(Filter):
def __init__(self, size: int = 3) -> None:
self.size = size

def filter(self, image):
def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:

Check warning on line 167 in src/PIL/ImageFilter.py

View check run for this annotation

Codecov / codecov/patch

src/PIL/ImageFilter.py#L167

Added line #L167 was not covered by tests
return image.modefilter(self.size)


Expand All @@ -176,9 +182,9 @@ class GaussianBlur(MultibandFilter):
def __init__(self, radius: float | Sequence[float] = 2) -> None:
self.radius = radius

def filter(self, image):
def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:

Check warning on line 185 in src/PIL/ImageFilter.py

View check run for this annotation

Codecov / codecov/patch

src/PIL/ImageFilter.py#L185

Added line #L185 was not covered by tests
xy = self.radius
if not isinstance(xy, (tuple, list)):
if isinstance(xy, (int, float)):

Check warning on line 187 in src/PIL/ImageFilter.py

View check run for this annotation

Codecov / codecov/patch

src/PIL/ImageFilter.py#L187

Added line #L187 was not covered by tests
xy = (xy, xy)
if xy == (0, 0):
return image.copy()
Expand Down Expand Up @@ -208,9 +214,9 @@ def __init__(self, radius: float | Sequence[float]) -> None:
raise ValueError(msg)
self.radius = radius

def filter(self, image):
def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:

Check warning on line 217 in src/PIL/ImageFilter.py

View check run for this annotation

Codecov / codecov/patch

src/PIL/ImageFilter.py#L217

Added line #L217 was not covered by tests
xy = self.radius
if not isinstance(xy, (tuple, list)):
if isinstance(xy, (int, float)):

Check warning on line 219 in src/PIL/ImageFilter.py

View check run for this annotation

Codecov / codecov/patch

src/PIL/ImageFilter.py#L219

Added line #L219 was not covered by tests
xy = (xy, xy)
if xy == (0, 0):
return image.copy()
Expand Down Expand Up @@ -241,7 +247,7 @@ def __init__(
self.percent = percent
self.threshold = threshold

def filter(self, image):
def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:

Check warning on line 250 in src/PIL/ImageFilter.py

View check run for this annotation

Codecov / codecov/patch

src/PIL/ImageFilter.py#L250

Added line #L250 was not covered by tests
return image.unsharp_mask(self.radius, self.percent, self.threshold)


Expand Down Expand Up @@ -387,8 +393,13 @@ class Color3DLUT(MultibandFilter):
name = "Color 3D LUT"

def __init__(
self, size, table, channels: int = 3, target_mode: str | None = None, **kwargs
):
self,
size: int | tuple[int, int, int],
table: Sequence[float] | Sequence[Sequence[int]] | NumpyArray,
channels: int = 3,
target_mode: str | None = None,
**kwargs: bool,
) -> None:
if channels not in (3, 4):
msg = "Only 3 or 4 output channels are supported"
raise ValueError(msg)
Expand All @@ -410,15 +421,16 @@ def __init__(
pass

if numpy and isinstance(table, numpy.ndarray):
numpy_table: NumpyArray = table

Check warning on line 424 in src/PIL/ImageFilter.py

View check run for this annotation

Codecov / codecov/patch

src/PIL/ImageFilter.py#L424

Added line #L424 was not covered by tests
if copy_table:
table = table.copy()
numpy_table = numpy_table.copy()

Check warning on line 426 in src/PIL/ImageFilter.py

View check run for this annotation

Codecov / codecov/patch

src/PIL/ImageFilter.py#L426

Added line #L426 was not covered by tests

if table.shape in [
if numpy_table.shape in [

Check warning on line 428 in src/PIL/ImageFilter.py

View check run for this annotation

Codecov / codecov/patch

src/PIL/ImageFilter.py#L428

Added line #L428 was not covered by tests
(items * channels,),
(items, channels),
(size[2], size[1], size[0], channels),
]:
table = table.reshape(items * channels)
table = numpy_table.reshape(items * channels)

Check warning on line 433 in src/PIL/ImageFilter.py

View check run for this annotation

Codecov / codecov/patch

src/PIL/ImageFilter.py#L433

Added line #L433 was not covered by tests
else:
wrong_size = True

Expand All @@ -428,15 +440,17 @@ def __init__(

# Convert to a flat list
if table and isinstance(table[0], (list, tuple)):
table, raw_table = [], table
raw_table = cast(Sequence[Sequence[int]], table)
flat_table: list[int] = []

Check warning on line 444 in src/PIL/ImageFilter.py

View check run for this annotation

Codecov / codecov/patch

src/PIL/ImageFilter.py#L443-L444

Added lines #L443 - L444 were not covered by tests
for pixel in raw_table:
if len(pixel) != channels:
msg = (
"The elements of the table should "
f"have a length of {channels}."
)
raise ValueError(msg)
table.extend(pixel)
flat_table.extend(pixel)
table = flat_table

Check warning on line 453 in src/PIL/ImageFilter.py

View check run for this annotation

Codecov / codecov/patch

src/PIL/ImageFilter.py#L452-L453

Added lines #L452 - L453 were not covered by tests

if wrong_size or len(table) != items * channels:
msg = (
Expand All @@ -449,23 +463,29 @@ def __init__(
self.table = table

@staticmethod
def _check_size(size: Any) -> list[int]:
def _check_size(size: Any) -> tuple[int, int, int]:

Check warning on line 466 in src/PIL/ImageFilter.py

View check run for this annotation

Codecov / codecov/patch

src/PIL/ImageFilter.py#L466

Added line #L466 was not covered by tests
try:
_, _, _ = size
except ValueError as e:
msg = "Size should be either an integer or a tuple of three integers."
raise ValueError(msg) from e
except TypeError:
size = (size, size, size)
size = [int(x) for x in size]
size = tuple(int(x) for x in size)

Check warning on line 474 in src/PIL/ImageFilter.py

View check run for this annotation

Codecov / codecov/patch

src/PIL/ImageFilter.py#L474

Added line #L474 was not covered by tests
for size_1d in size:
if not 2 <= size_1d <= 65:
msg = "Size should be in [2, 65] range."
raise ValueError(msg)
return size

@classmethod
def generate(cls, size, callback, channels=3, target_mode=None):
def generate(

Check warning on line 482 in src/PIL/ImageFilter.py

View check run for this annotation

Codecov / codecov/patch

src/PIL/ImageFilter.py#L482

Added line #L482 was not covered by tests
cls,
size: int | tuple[int, int, int],
callback: Callable[[float, float, float], tuple[float, ...]],
channels: int = 3,
target_mode: str | None = None,
) -> Color3DLUT:
"""Generates new LUT using provided callback.
:param size: Size of the table. Passed to the constructor.
Expand All @@ -482,7 +502,7 @@ def generate(cls, size, callback, channels=3, target_mode=None):
msg = "Only 3 or 4 output channels are supported"
raise ValueError(msg)

table = [0] * (size_1d * size_2d * size_3d * channels)
table: list[float] = [0] * (size_1d * size_2d * size_3d * channels)

Check warning on line 505 in src/PIL/ImageFilter.py

View check run for this annotation

Codecov / codecov/patch

src/PIL/ImageFilter.py#L505

Added line #L505 was not covered by tests
idx_out = 0
for b in range(size_3d):
for g in range(size_2d):
Expand All @@ -500,7 +520,13 @@ def generate(cls, size, callback, channels=3, target_mode=None):
_copy_table=False,
)

def transform(self, callback, with_normals=False, channels=None, target_mode=None):
def transform(

Check warning on line 523 in src/PIL/ImageFilter.py

View check run for this annotation

Codecov / codecov/patch

src/PIL/ImageFilter.py#L523

Added line #L523 was not covered by tests
self,
callback: Callable[..., tuple[float, ...]],
with_normals: bool = False,
channels: int | None = None,
target_mode: str | None = None,
) -> Color3DLUT:
"""Transforms the table values using provided callback and returns
a new LUT with altered values.
Expand Down Expand Up @@ -564,7 +590,7 @@ def __repr__(self) -> str:
r.append(f"target_mode={self.mode}")
return "<{}>".format(" ".join(r))

def filter(self, image):
def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:

Check warning on line 593 in src/PIL/ImageFilter.py

View check run for this annotation

Codecov / codecov/patch

src/PIL/ImageFilter.py#L593

Added line #L593 was not covered by tests
from . import Image

return image.color_lut_3d(
Expand Down
10 changes: 8 additions & 2 deletions src/PIL/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,21 @@

import os
import sys
from typing import Protocol, Sequence, TypeVar, Union
from typing import Any, Protocol, Sequence, TypeVar, Union

Check warning on line 5 in src/PIL/_typing.py

View check run for this annotation

Codecov / codecov/patch

src/PIL/_typing.py#L5

Added line #L5 was not covered by tests

try:
import numpy.typing as npt

Check warning on line 8 in src/PIL/_typing.py

View check run for this annotation

Codecov / codecov/patch

src/PIL/_typing.py#L7-L8

Added lines #L7 - L8 were not covered by tests

NumpyArray = npt.NDArray[Any]

Check warning on line 10 in src/PIL/_typing.py

View check run for this annotation

Codecov / codecov/patch

src/PIL/_typing.py#L10

Added line #L10 was not covered by tests
except ImportError:
pass

if sys.version_info >= (3, 10):
from typing import TypeGuard
else:
try:
from typing_extensions import TypeGuard
except ImportError:
from typing import Any

class TypeGuard: # type: ignore[no-redef]
def __class_getitem__(cls, item: Any) -> type[bool]:
Expand Down

0 comments on commit 324e548

Please sign in to comment.