-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4 from shamspias/feat/sahi
Feat/sahi
- Loading branch information
Showing
8 changed files
with
153 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,28 +1,74 @@ | ||
from typing import Optional | ||
|
||
|
||
class BaseFormat: | ||
def __init__(self, output_dir): | ||
""" | ||
Base class for handling annotation formats. This class provides basic functionalities | ||
like saving annotations and ensuring directory structure, which can be extended by subclasses. | ||
Attributes: | ||
output_dir (str): The directory where the output will be stored. | ||
sahi_enabled (bool): Flag to enable or disable SAHI (Sliced Inference). | ||
sahi_utils (Optional[object]): SAHI utility object for performing sliced inference. | ||
""" | ||
|
||
def __init__(self, output_dir: str, sahi_enabled: bool = False, sahi_utils: Optional[object] = None): | ||
""" | ||
Initializes the BaseFormat class with output directory and optional SAHI settings. | ||
Args: | ||
output_dir (str): Path to the directory where annotations will be saved. | ||
sahi_enabled (bool): Boolean flag to enable SAHI (Sliced Inference). Defaults to False. | ||
sahi_utils (Optional[object]): Instance of SAHI utility class to be used for sliced inference. Defaults to None. | ||
""" | ||
self.output_dir = output_dir | ||
self.image_dir = None | ||
self.label_dir = None | ||
self.sahi_enabled = sahi_enabled | ||
self.sahi_utils = sahi_utils | ||
|
||
def ensure_directories(self): | ||
""" | ||
Ensures that necessary directories are created. | ||
Ensures that the necessary directories for saving annotations exist. | ||
Must be implemented by subclasses. | ||
Raises: | ||
NotImplementedError: If the method is not implemented in the subclass. | ||
""" | ||
raise NotImplementedError("Subclasses should implement this method.") | ||
|
||
def annotate_frame(self, frame, frame_path, frame_filename, model_conf, supported_classes): | ||
def save_annotations(self, frame, frame_path: str, frame_filename: str, results: list, supported_classes: list): | ||
""" | ||
Annotates the frame using the model output and saves the annotation in a format-specific manner. | ||
Saves the annotations for a given frame. If SAHI is enabled, performs sliced inference before saving. | ||
Args: | ||
frame (ndarray): The image frame for which annotations are being saved. | ||
frame_path (str): The path where the frame is located. | ||
frame_filename (str): The name of the frame file. | ||
results (list): A list of results from the detection model or sliced inference. | ||
supported_classes (list): List of supported class labels for the annotations. | ||
Parameters: | ||
frame (np.array): The frame to be annotated. | ||
frame_path (str): Path where the frame image is saved. | ||
frame_filename (str): Filename of the frame image. | ||
model_conf (float): Model confidence threshold for annotations. | ||
supported_classes (list): List of supported class names. | ||
Raises: | ||
NotImplementedError: If `_save_annotations` is not implemented in the subclass. | ||
""" | ||
raise NotImplementedError("Subclasses should implement this method.") | ||
if self.sahi_enabled and self.sahi_utils: | ||
if hasattr(self.sahi_utils, 'perform_sliced_inference'): | ||
results = self.sahi_utils.perform_sliced_inference(frame) | ||
else: | ||
raise AttributeError("sahi_utils object does not have 'perform_sliced_inference' method.") | ||
self._save_annotations(frame, frame_path, frame_filename, results, supported_classes) | ||
|
||
def save_annotations(self, frame, frame_path, frame_filename, results, supported_classes): | ||
""" Method to save annotations; implemented in subclasses. """ | ||
def _save_annotations(self, frame, frame_path: str, frame_filename: str, results: list, supported_classes: list): | ||
""" | ||
Abstract method for saving annotations. To be implemented by subclasses to define | ||
the logic for saving the annotations. | ||
Args: | ||
frame (ndarray): The image frame for which annotations are being saved. | ||
frame_path (str): The path where the frame is located. | ||
frame_filename (str): The name of the frame file. | ||
results (list): A list of results from the detection model or sliced inference. | ||
supported_classes (list): List of supported class labels for the annotations. | ||
Raises: | ||
NotImplementedError: If the method is not implemented in the subclass. | ||
""" | ||
raise NotImplementedError("Subclasses should implement this method.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,4 +6,5 @@ ultralytics | |
pillow | ||
PyYAML | ||
opencv-python | ||
boto3 | ||
boto3 | ||
sahi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from sahi.predict import get_sliced_prediction | ||
from sahi import AutoDetectionModel | ||
|
||
|
||
class SahiUtils: | ||
def __init__(self, model_path, model_type='yolov8', device='cpu', slice_size=(256, 256), overlap_ratio=(0.2, 0.2)): | ||
self.device = device # CPU or 'cuda:0' | ||
self.model_type = model_type | ||
self.model = self.load_model(model_path) | ||
self.slice_size = slice_size | ||
self.overlap_ratio = overlap_ratio | ||
|
||
def load_model(self, model_path): | ||
detection_model = AutoDetectionModel.from_pretrained( | ||
model_type=self.model_type, | ||
model_path=model_path, | ||
confidence_threshold=0.1, | ||
device=self.device, | ||
) | ||
return detection_model | ||
|
||
def perform_sliced_inference(self, image): | ||
# Perform sliced inference using the loaded model and SAHI | ||
results = get_sliced_prediction( | ||
image, | ||
self.model, # this should be a sahi model | ||
slice_height=self.slice_size[0], | ||
slice_width=self.slice_size[1], | ||
overlap_height_ratio=self.overlap_ratio[0], | ||
overlap_width_ratio=self.overlap_ratio[1] | ||
) | ||
return results |