Add Classification And Detection Scene #259

Merged
5 changes: 5 additions & 0 deletions src/python/model_api/models/result/detection.py
@@ -111,6 +111,11 @@ def label_names(self, value):

@property
def saliency_map(self):
"""Saliency map for XAI.

Returns:
            np.ndarray: Saliency map with shape (B, N_CLASSES, H, W).
"""
return self._saliency_map

@saliency_map.setter
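Note (illustrative, not part of the diff): the documented layout matches how `DetectionScene` indexes the map later in this PR (`result.saliency_map[0][label_index]`). A minimal sketch with a hypothetical map holding one batch element and three classes:

```python
import numpy as np

# Hypothetical saliency map laid out as (B, N_CLASSES, H, W), per the docstring above.
saliency_map = np.zeros((1, 3, 480, 640), dtype=np.uint8)

# First (and only) batch element, then the channel for class index 1,
# mirroring how DetectionScene reads result.saliency_map.
class_map = saliency_map[0][1]
assert class_map.shape == (480, 640)
```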
4 changes: 4 additions & 0 deletions src/python/model_api/visualizer/layout/hstack.py
@@ -9,6 +9,8 @@

import PIL

from model_api.visualizer.primitive import Overlay

from .layout import Layout

if TYPE_CHECKING:
@@ -31,6 +33,8 @@ def _compute_on_primitive(self, primitive: Type[Primitive], image: PIL.Image, sc
images = []
for _primitive in scene.get_primitives(primitive):
image_ = _primitive.compute(image.copy())
if isinstance(_primitive, Overlay):
image_ = Overlay.overlay_labels(image=image_, labels=_primitive.label)
images.append(image_)
return self._stitch(*images)
return None
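Note (illustrative, not part of the diff): the new branch only affects `Overlay` primitives; each one is blended onto a copy of the base image and then stamped with its own label before stitching. A rough sketch of that per-primitive step, using two hypothetical overlays built from blank arrays:

```python
import numpy as np
import PIL.Image

from model_api.visualizer.primitive import Overlay

base = PIL.Image.new("RGB", (640, 480))

# Two hypothetical per-class overlays, each carrying the label it represents.
overlays = [
    Overlay(np.zeros((480, 640, 3), dtype=np.uint8), label="Person"),
    Overlay(np.zeros((480, 640, 3), dtype=np.uint8), label="Car"),
]

# Mirrors the loop body in HStack._compute_on_primitive: blend, then label.
images = []
for overlay in overlays:
    image_ = overlay.compute(base.copy())
    image_ = Overlay.overlay_labels(image=image_, labels=overlay.label)
    images.append(image_)
```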
31 changes: 30 additions & 1 deletion src/python/model_api/visualizer/primitive/overlay.py
@@ -5,8 +5,11 @@

from __future__ import annotations

from typing import Union

import numpy as np
import PIL
from PIL import ImageDraw, ImageFont

from .primitive import Primitive

@@ -18,11 +21,18 @@ class Overlay(Primitive):

Args:
image (PIL.Image | np.ndarray): Image to be overlaid.
label (str | None): Optional label name to overlay.
opacity (float): Opacity of the overlay.
"""

def __init__(self, image: PIL.Image | np.ndarray, opacity: float = 0.4) -> None:
def __init__(
self,
image: PIL.Image | np.ndarray,
opacity: float = 0.4,
label: Union[str, None] = None,
) -> None:
self.image = self._to_pil(image)
self.label = label
self.opacity = opacity

def _to_pil(self, image: PIL.Image | np.ndarray) -> PIL.Image:
@@ -33,3 +43,22 @@ def _to_pil(self, image: PIL.Image | np.ndarray) -> PIL.Image:
def compute(self, image: PIL.Image) -> PIL.Image:
image_ = self.image.resize(image.size)
return PIL.Image.blend(image, image_, self.opacity)

@classmethod
def overlay_labels(cls, image: PIL.Image, labels: Union[list[str], str, None] = None) -> PIL.Image:
"""Draw labels at the bottom center of the image.

        Useful for annotating an overlay with the label(s) it represents.
"""
if labels is not None:
labels = [labels] if isinstance(labels, str) else labels
font = ImageFont.load_default(size=18)
buffer_y = 5
dummy_image = PIL.Image.new("RGB", (1, 1))
            draw = ImageDraw.Draw(dummy_image)
textbox = draw.textbbox((0, 0), ", ".join(labels), font=font)
image_ = PIL.Image.new("RGB", (textbox[2] - textbox[0], textbox[3] + buffer_y - textbox[1]), "white")
            draw = ImageDraw.Draw(image_)
draw.text((0, 0), ", ".join(labels), font=font, fill="black")
image.paste(image_, (image.width // 2 - image_.width // 2, image.height - image_.height - buffer_y))
return image
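A quick usage sketch of the new classmethod on its own (illustrative, assuming a plain 640x480 canvas): the joined labels are drawn on a white strip pasted at the bottom center, and the same image object is returned.

```python
import PIL.Image

from model_api.visualizer.primitive import Overlay

canvas = PIL.Image.new("RGB", (640, 480), "gray")

# Labels are joined with ", " and pasted on a white strip at the bottom center.
labelled = Overlay.overlay_labels(canvas, labels=["Cat", "Dog"])
assert labelled is canvas  # the input image is annotated in place and returned
```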
28 changes: 24 additions & 4 deletions src/python/model_api/visualizer/scene/classification.py
@@ -5,11 +5,12 @@

from typing import Union

import cv2
from PIL import Image

from model_api.models.result import ClassificationResult
from model_api.visualizer.layout import Flatten, Layout
from model_api.visualizer.primitive import Overlay
from model_api.visualizer.primitive import Label, Overlay

from .scene import Scene

@@ -18,9 +19,28 @@ class ClassificationScene(Scene):
"""Classification Scene."""

def __init__(self, image: Image, result: ClassificationResult, layout: Union[Layout, None] = None) -> None:
self.image = image
self.result = result
super().__init__(
base=image,
label=self._get_labels(result),
overlay=self._get_overlays(result),
layout=layout,
)

def _get_labels(self, result: ClassificationResult) -> list[Label]:
labels = []
if result.top_labels is not None and len(result.top_labels) > 0:
for label in result.top_labels:
if label.name is not None:
labels.append(Label(label=label.name, score=label.confidence))
return labels

def _get_overlays(self, result: ClassificationResult) -> list[Overlay]:
overlays = []
if result.saliency_map is not None and result.saliency_map.size > 0:
saliency_map = cv2.cvtColor(result.saliency_map, cv2.COLOR_BGR2RGB)
overlays.append(Overlay(saliency_map))
return overlays

@property
def default_layout(self) -> Layout:
return Flatten(Overlay)
return Flatten(Overlay, Label)
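Note (illustrative, not part of the diff): the label primitives handed to the scene come straight from `result.top_labels`; a small sketch of that mapping with hypothetical (name, confidence) pairs:

```python
from model_api.visualizer.primitive import Label

# Hypothetical predictions in the shape ClassificationScene reads from
# result.top_labels: one Label primitive per (name, confidence) pair.
top_labels = [("cat", 0.95), ("dog", 0.90)]
label_primitives = [Label(label=name, score=confidence) for name, confidence in top_labels]
```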
34 changes: 31 additions & 3 deletions src/python/model_api/visualizer/scene/detection.py
@@ -5,10 +5,12 @@

from typing import Union

import cv2
from PIL import Image

from model_api.models.result import DetectionResult
from model_api.visualizer.layout import Layout
from model_api.visualizer.layout import Flatten, HStack, Layout
from model_api.visualizer.primitive import BoundingBox, Label, Overlay

from .scene import Scene

@@ -17,5 +19,31 @@ class DetectionScene(Scene):
"""Detection Scene."""

def __init__(self, image: Image, result: DetectionResult, layout: Union[Layout, None] = None) -> None:
self.image = image
self.result = result
super().__init__(
base=image,
bounding_box=self._get_bounding_boxes(result),
overlay=self._get_overlays(result),
layout=layout,
)

def _get_overlays(self, result: DetectionResult) -> list[Overlay]:
overlays = []
        # Add overlays only for the labels that were predicted
label_index_mapping = dict(zip(result.labels, result.label_names))
for label_index, label_name in label_index_mapping.items():
            # Index 0 assumes a batch size of one
saliency_map = cv2.applyColorMap(result.saliency_map[0][label_index], cv2.COLORMAP_JET)
overlays.append(Overlay(saliency_map, label=label_name.title()))
return overlays

def _get_bounding_boxes(self, result: DetectionResult) -> list[BoundingBox]:
bounding_boxes = []
for score, label_name, bbox in zip(result.scores, result.label_names, result.bboxes):
x1, y1, x2, y2 = bbox
label = f"{label_name} ({score:.2f})"
bounding_boxes.append(BoundingBox(x1=x1, y1=y1, x2=x2, y2=y2, label=label))
return bounding_boxes

@property
def default_layout(self) -> Layout:
return HStack(Flatten(BoundingBox, Label), Overlay)
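Note (illustrative, not part of the diff): a rough sketch of how `_get_overlays` maps predicted class indices to coloured saliency channels, using dummy arrays shaped like the ones in the test below:

```python
import cv2
import numpy as np

# Dummy detection fields: two predicted classes and a (B, N_CLASSES, H, W) map.
labels = np.array([0, 1])
label_names = ["person", "car"]
saliency_map = (np.ones((1, 2, 6, 8)) * 255).astype(np.uint8)

# Only predicted classes get an overlay; each class channel of the first batch
# element is colour-mapped, mirroring DetectionScene._get_overlays.
label_index_mapping = dict(zip(labels, label_names))
for label_index, label_name in label_index_mapping.items():
    colored = cv2.applyColorMap(saliency_map[0][label_index], cv2.COLORMAP_JET)
    assert colored.shape == (6, 8, 3)  # BGR heatmap for this class
```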
33 changes: 32 additions & 1 deletion tests/python/unit/visualizer/test_scene.py
@@ -8,7 +8,8 @@
import numpy as np
from PIL import Image

from model_api.models.result import AnomalyResult
from model_api.models.result import AnomalyResult, ClassificationResult, DetectionResult
from model_api.models.result.classification import Label
from model_api.visualizer import Visualizer


@@ -32,3 +33,33 @@ def test_anomaly_scene(mock_image: Image, tmpdir: Path):
visualizer = Visualizer()
visualizer.save(mock_image, anomaly_result, tmpdir / "anomaly_scene.jpg")
assert Path(tmpdir / "anomaly_scene.jpg").exists()


def test_classification_scene(mock_image: Image, tmpdir: Path):
"""Test if the classification scene is created."""
classification_result = ClassificationResult(
top_labels=[
Label(name="cat", confidence=0.95),
Label(name="dog", confidence=0.90),
],
saliency_map=np.ones(mock_image.size, dtype=np.uint8),
)
visualizer = Visualizer()
visualizer.save(
mock_image, classification_result, tmpdir / "classification_scene.jpg"
)
assert Path(tmpdir / "classification_scene.jpg").exists()


def test_detection_scene(mock_image: Image, tmpdir: Path):
"""Test if the detection scene is created."""
detection_result = DetectionResult(
bboxes=np.array([[0, 0, 128, 128], [32, 32, 96, 96]]),
labels=np.array([0, 1]),
label_names=["person", "car"],
scores=np.array([0.85, 0.75]),
saliency_map=(np.ones((1, 2, 6, 8)) * 255).astype(np.uint8),
)
visualizer = Visualizer()
visualizer.save(mock_image, detection_result, tmpdir / "detection_scene.jpg")
assert Path(tmpdir / "detection_scene.jpg").exists()