Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose Visualizer class #426

Merged
merged 9 commits into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ API Reference
Data models <geti_sdk.data_models>
Import Export module <geti_sdk.import_export>
Deployment <geti_sdk.deployment>
Prediction Visualization <geti_sdk.prediction_visualization>
HTTP session <geti_sdk.http_session>
REST converters <geti_sdk.rest_converters>
REST clients <geti_sdk.rest_clients>
Expand Down
3 changes: 2 additions & 1 deletion geti_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@
"""

from .geti import Geti
from .prediction_visualization.visualizer import Visualizer

__version__ = "2.1.0"

__all__ = ["Geti"]
__all__ = ["Geti", "Visualizer"]
12 changes: 6 additions & 6 deletions geti_sdk/benchmarking/benchmarker.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@
Video,
)
from geti_sdk.deployment import Deployment
from geti_sdk.prediction_visualization.visualizer import Visualizer
from geti_sdk.rest_clients import ImageClient, ModelClient, TrainingClient, VideoClient
from geti_sdk.rest_clients.prediction_client import PredictionClient
from geti_sdk.utils.plot_helpers import (
concat_prediction_results,
pad_image_and_put_caption,
show_image_with_annotation_scene,
)

from .utils import get_system_info, load_benchmark_media, suppress_log_output
Expand Down Expand Up @@ -859,6 +859,8 @@ def compare_predictions(
with open(throughput_benchmark_results, "r") as results_file:
throughput_benchmark_results = list(csv.DictReader(results_file))

visualizer = Visualizer()

# Performe inferece
with logging_redirect_tqdm(tqdm_class=tqdm):
results: List[List[np.ndarray]] = []
Expand Down Expand Up @@ -890,9 +892,7 @@ def compare_predictions(
f"failed. Inference failed with error: `{e}`"
)
if success:
image_with_prediction = show_image_with_annotation_scene(
image, prediction, show_results=False
)
image_with_prediction = visualizer.draw(image, prediction)
image_with_prediction = cv2.cvtColor(
image_with_prediction, cv2.COLOR_BGR2RGB
)
Expand Down Expand Up @@ -953,8 +953,8 @@ def compare_predictions(
if include_online_prediction_for_active_model:
logging.info("Predicting on the platform using the active model")
online_prediction_result = self._predict_using_active_model(image)
image_with_prediction = show_image_with_annotation_scene(
image, online_prediction_result["prediction"], show_results=False
image_with_prediction = visualizer.draw(
image, online_prediction_result["prediction"]
)
image_with_prediction = cv2.cvtColor(
image_with_prediction, cv2.COLOR_BGR2RGB
Expand Down
6 changes: 4 additions & 2 deletions geti_sdk/deployment/resources/OVMS_README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,10 @@ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
predictions = deployment.infer(image=image)

# Show inference result
from geti_sdk.utils import show_image_with_annotation_scene
show_image_with_annotation_scene(image=image, annotation_scene=predictions);
from geti_sdk import Visualizer
visualizer = Visualizer()
result_image = visualizer.draw(image=image, annotation_scene=predictions)
visualizer.show_window(result_image)
```

The example uses a sample image, please make sure to replace it with your own.
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@

from geti_sdk.data_models import Prediction
from geti_sdk.deployment.inference_hook_interfaces import PostInferenceAction
from geti_sdk.prediction_visualization.visualizer import Visualizer
from geti_sdk.rest_converters import PredictionRESTConverter
from geti_sdk.utils import show_image_with_annotation_scene


class FileSystemDataCollection(PostInferenceAction):
Expand Down Expand Up @@ -81,6 +81,7 @@ def __init__(
self.save_predictions = save_predictions
self.save_scores = save_scores
self.save_overlays = save_overlays
self.visualizer = Visualizer()

self._repr_info_ = (
f"target_folder=`{target_folder}`, "
Expand Down Expand Up @@ -147,12 +148,8 @@ def __call__(

if self.save_overlays:
overlay_path = os.path.join(self.overlays_path, filename + ".jpg")
show_image_with_annotation_scene(
image=image,
annotation_scene=prediction,
filepath=overlay_path,
show_results=False,
)
result = self.visualizer.draw(image, prediction)
self.visualizer.save_image(result, overlay_path)
except Exception as e:
logging.exception(e, stack_info=True, exc_info=True)

Expand Down
44 changes: 42 additions & 2 deletions geti_sdk/prediction_visualization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,48 @@
Introduction
------------

The `prediction_visualization` package provides classes for visualizing models predictions.
Currently, the user interfaces to this package are available in the :py:mod:`~geti_sdk.utils.plot_helpers` module.
The `prediction_visualization` package provides classes for visualizing models predictions and media annotations.
Aditionally, shortend interface to this package is available through the :py:mod:`~geti_sdk.utils.plot_helpers` module.

The main :py:class:`~geti_sdk.prediction_visualization.visualizer.Visualizer` class is a flexible utility class for working
with Geti-SDK Prediction and Annotation object. You can initialize the Visualizer with the desired settings and then use it to draw
the annotations on the input image.

.. code-block:: python

from geti_sdk import Visualizer

visualizer = Visualizer(
show_labels=True,
show_confidence=True,
show_count=False,
)

# Obtain a prediction from the Intel Geti platfor server or a local deployment.
...

# Visualize the prediction on the input image.
result = visualizer.draw(
numpy_image,
prediction,
fill_shapes=True,
confidence_threshold=0.4,
)
visualizer.show_in_notebook(result)

In case the Prediction was generated with a model that supports explainable AI functionality, the Visualizer can also display
the explanation for the prediction.

.. code-block:: python
image_with_saliency_map = visualizer.explain_label(
numpy_image,
prediction,
label_name="Cat",
opacity=0.5,
show_predictions=True,
)
visualizer.save_image(image_with_saliency_map, "./explained_prediction.jpg")
visualizer.show_window(image_with_saliency_map) # When called in a script

Module contents
---------------
Expand Down
70 changes: 56 additions & 14 deletions geti_sdk/prediction_visualization/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@
"""The package provides the Visualizer class for models predictions visualization."""


from typing import Optional
from os import PathLike
from typing import List, Optional, Union

import cv2
import numpy as np
from IPython.display import display
from PIL import Image

from geti_sdk.data_models.annotation_scene import AnnotationScene
from geti_sdk.data_models.containers.media_list import MediaList
from geti_sdk.data_models.media import VideoFrame
from geti_sdk.data_models.predictions import Prediction
from geti_sdk.prediction_visualization.shape_drawer import ShapeDrawer

Expand All @@ -44,8 +49,6 @@ def __init__(
show_confidence: bool = True,
show_count: bool = False,
is_one_label: bool = False,
delay: Optional[int] = None,
output: Optional[str] = None,
) -> None:
"""
Initialize the Visualizer.
Expand All @@ -55,19 +58,12 @@ def __init__(
:param show_confidence: Show confidence on the output image
:param show_count: Show count of the shapes on the output image
:param is_one_label: Show only one label on the output image
:param delay: Delay time for the output image
:param output: Path to save the output image
"""
self.window_name = "Window" if window_name is None else window_name
self.shape_drawer = ShapeDrawer(
show_count, is_one_label, show_labels, show_confidence
)

self.delay = delay
if delay is None:
self.delay = 1
self.output = output

def draw(
self,
image: np.ndarray,
Expand All @@ -90,7 +86,7 @@ def draw(
if confidence_threshold is not None:
annotation = annotation.filter_by_confidence(confidence_threshold)
result = self.shape_drawer.draw(
image, annotation, labels=[], fill_shapes=fill_shapes
image.copy(), annotation, labels=[], fill_shapes=fill_shapes
)
return result

Expand Down Expand Up @@ -140,7 +136,53 @@ def explain_label(
result = self.draw(result, filtered_prediction, fill_shapes=False)
return result

def show(self, image: np.ndarray) -> None:
@staticmethod
def save_image(image: np.ndarray, output_path: PathLike) -> None:
"""
Save the image to the output path.

:param image: Image in RGB format to be saved
:param output_path: Path to save the image
"""
bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
cv2.imwrite(output_path, bgr_image)

@staticmethod
def save_video(
video_frames: MediaList[VideoFrame],
annotation_scenes: List[Union[AnnotationScene, Prediction]],
output_path: PathLike,
fps: float = 1,
) -> None:
"""
Save the video to the output path.

:param video_frames: List of video frames
:param annotation_scenes: List of annotation scenes to be drawn on the video frames
:param output_path: Path to save the image
"""
out_writer = cv2.VideoWriter(
filename=f"{output_path}",
fourcc=cv2.VideoWriter_fourcc("M", "J", "P", "G"),
fps=fps,
frameSize=(
video_frames[0].media_information.width,
video_frames[0].media_information.height,
),
)
for frame, annotation in zip(video_frames, annotation_scenes):
out_writer.write(frame)

@staticmethod
def show_in_notebook(image: np.ndarray) -> None:
"""
Show the image in the Jupyter notebook.

:param image: Image to be shown in RGB format
"""
display(Image.fromarray(image))

def show_window(self, image: np.ndarray) -> None:
"""
Show result image.

Expand All @@ -149,6 +191,6 @@ def show(self, image: np.ndarray) -> None:
image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
cv2.imshow(self.window_name, image_bgr)

def is_quit(self) -> bool:
def is_quit(self, delay: int = 1) -> bool:
"""Check user wish to quit."""
return ord("q") == cv2.waitKey(self.delay)
return ord("q") == cv2.waitKey(delay)
28 changes: 12 additions & 16 deletions notebooks/003_upload_and_predict_image.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -207,15 +207,19 @@
"metadata": {},
"outputs": [],
"source": [
"from geti_sdk.utils import show_image_with_annotation_scene\n",
"import cv2\n",
"\n",
"from geti_sdk import Visualizer\n",
"\n",
"# To visualise the image, we have to retrieve the pixel data from the platform using the `image.get_data` method. The actual pixel data is\n",
"# downloaded and cached only on the first call to this method\n",
"image.get_data(geti.session)\n",
"numpy_image = image.numpy\n",
"\n",
"show_image_with_annotation_scene(\n",
" image, prediction, show_in_notebook=True, channel_order=\"bgr\"\n",
");"
"visualizer = Visualizer()\n",
"image_rgb = cv2.cvtColor(numpy_image, cv2.COLOR_BGR2RGB)\n",
"result = visualizer.draw(image_rgb, prediction)\n",
"visualizer.show_in_notebook(result)"
]
},
{
Expand All @@ -240,18 +244,10 @@
" visualise_output=False,\n",
" delete_after_prediction=False,\n",
")\n",
"show_image_with_annotation_scene(\n",
" quick_image, quick_prediction, show_in_notebook=True, channel_order=\"bgr\"\n",
");"
"quick_image_rgb = cv2.cvtColor(quick_image.numpy, cv2.COLOR_BGR2RGB)\n",
"quick_result = visualizer.draw(quick_image_rgb, quick_prediction)\n",
"visualizer.show_in_notebook(quick_result)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "51090376-c85e-4af3-9ff8-b030934fd095",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -270,7 +266,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
"version": "3.10.12"
}
},
"nbformat": 4,
Expand Down
21 changes: 11 additions & 10 deletions notebooks/005_modify_image.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,8 @@
"metadata": {},
"outputs": [],
"source": [
"from geti_sdk import Visualizer\n",
"from geti_sdk.rest_clients import AnnotationClient\n",
"from geti_sdk.utils import show_image_with_annotation_scene\n",
"\n",
"annotation_client = AnnotationClient(\n",
" session=geti.session, workspace_id=geti.workspace_id, project=project\n",
Expand All @@ -231,9 +231,11 @@
"\n",
"# Inspect the annotation\n",
"print(annotation.overview)\n",
"show_image_with_annotation_scene(\n",
" image, annotation, show_in_notebook=True, channel_order=\"bgr\"\n",
");"
"\n",
"visualizer = Visualizer()\n",
"image_rgb = cv2.cvtColor(image.numpy, cv2.COLOR_BGR2RGB)\n",
"result = visualizer.draw(image_rgb, annotation)\n",
"visualizer.show_in_notebook(result)"
]
},
{
Expand Down Expand Up @@ -276,11 +278,10 @@
")\n",
"\n",
"# Inspect the annotation\n",
"show_image_with_annotation_scene(\n",
" grayscale_image.get_data(geti.session),\n",
" grayscale_annotation,\n",
" show_in_notebook=True,\n",
");"
"result = visualizer.draw(\n",
" grayscale_image.get_data(geti.session).numpy, grayscale_annotation\n",
")\n",
"visualizer.show_in_notebook(result)"
]
},
{
Expand Down Expand Up @@ -342,7 +343,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
"version": "3.10.12"
}
},
"nbformat": 4,
Expand Down
Loading
Loading