Merge pull request #4 from shamspias/feat/sahi
Feat/sahi
shamspias authored Sep 5, 2024
2 parents 1f0b78b + a9aeafa commit 81fee4c
Showing 8 changed files with 153 additions and 22 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -85,4 +85,6 @@ how to make contributions.

## License

Distributed under the MIT License. See [LICENSE](LICENSE) for more information.

Powered by [Indikat](https://indikat.tech)
23 changes: 22 additions & 1 deletion app/config.py
@@ -1,5 +1,6 @@
from pydantic_settings import BaseSettings
from typing import Optional
from pydantic import field_validator
from typing import Optional, Tuple


class Config(BaseSettings):
@@ -17,6 +18,26 @@ class Config(BaseSettings):
s3_bucket_name: Optional[str] = ""
s3_region_name: Optional[str] = ""

# SAHI settings
sahi_enabled: Optional[bool] = False
sahi_model_type: Optional[str] = 'yolov8'
sahi_device: Optional[str] = 'cpu'
sahi_slice_size: Optional[Tuple[int, int]] = (256, 256)
sahi_overlap_ratio: Optional[Tuple[float, float]] = (0.2, 0.2)

# Use field_validator for Pydantic v2
@field_validator("sahi_slice_size", mode='before')
def parse_sahi_slice_size(cls, v):
if isinstance(v, str):
return tuple(map(int, v.split(',')))
return v

@field_validator("sahi_overlap_ratio", mode='before')
def parse_sahi_overlap_ratio(cls, v):
if isinstance(v, str):
return tuple(map(float, v.split(',')))
return v

class Config:
env_file = ".env"
env_file_encoding = 'utf-8'
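
As context for the two validators above, here is a small self-contained sketch (an assumption for illustration, not part of the diff; it uses a plain pydantic `BaseModel` instead of `BaseSettings` so it runs without a `.env` file) of how comma-separated strings such as `SAHI_SLICE_SIZE=256,256` from `example.env` are coerced into tuples before type validation:

```python
# Minimal sketch of the "before"-mode validators used in app/config.py.
from typing import Tuple
from pydantic import BaseModel, field_validator


class SahiSettingsSketch(BaseModel):  # hypothetical stand-in for Config(BaseSettings)
    sahi_slice_size: Tuple[int, int] = (256, 256)
    sahi_overlap_ratio: Tuple[float, float] = (0.2, 0.2)

    @field_validator("sahi_slice_size", mode="before")
    def parse_slice_size(cls, v):
        # "256,256" -> (256, 256); non-string values pass through untouched
        return tuple(map(int, v.split(","))) if isinstance(v, str) else v

    @field_validator("sahi_overlap_ratio", mode="before")
    def parse_overlap_ratio(cls, v):
        # "0.2,0.2" -> (0.2, 0.2)
        return tuple(map(float, v.split(","))) if isinstance(v, str) else v


print(SahiSettingsSketch(sahi_slice_size="256,256", sahi_overlap_ratio="0.2,0.2"))
# sahi_slice_size=(256, 256) sahi_overlap_ratio=(0.2, 0.2)
```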
10 changes: 8 additions & 2 deletions app/extractor.py
@@ -3,6 +3,7 @@
from ultralytics import YOLO
import yaml
from utils.image_processor import ImageProcessor
from utils.sahi_utils import SahiUtils


class VideoFrameExtractor:
@@ -12,7 +13,7 @@ class VideoFrameExtractor:
"""

def __init__(self, config, video_path, frame_rate, output_dir, model_path, class_config_path, output_format,
transformations):
transformations, sahi_config=None):
self.config = config
self.video_path = video_path # Ensure this is a string representing the path to the video file.
self.frame_rate = frame_rate
@@ -24,6 +25,10 @@ def __init__(self, config, video_path, frame_rate, output_dir, model_path, class
self.supported_classes = self.load_classes(self.class_config_path)
self.image_processor = ImageProcessor(output_size=self.transformations.get('size', (640, 640)))

self.sahi_utils = SahiUtils(os.path.join('models', model_path), **sahi_config) if sahi_config else None
self.output_format.sahi_enabled = bool(sahi_config)
self.output_format.sahi_utils = self.sahi_utils

# Debugging output to ensure path handling
if not os.path.exists(self.video_path):
raise FileNotFoundError(f"The specified video file was not found at {self.video_path}")
@@ -62,8 +67,9 @@ def extract_frames(self, model_confidence):

for key, transformed_image in transformed_images.items():
if transformed_image.ndim == 2: # Check if the image is grayscale
# Convert back to RGB format for consistency
transformed_image = cv2.cvtColor(transformed_image,
cv2.COLOR_GRAY2BGR) # Convert back to RGB format for consistency
cv2.COLOR_GRAY2BGR)

frame_filename = f"{self._get_video_basename()}_image{frame_count}_{key}.jpg"
frame_path = os.path.join(self.output_dir, 'images', frame_filename)
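
To make the new wiring above easier to follow, here is a hedged sketch (file names are placeholders, not from the repo) of the dict that reaches `SahiUtils`: its keys must match the keyword parameters of `SahiUtils.__init__`, and the output-format object is flagged so that `save_annotations()` later routes frames through sliced inference.

```python
# Illustrative only; "best.pt" is a placeholder weight file under models/.
sahi_config = {
    "model_type": "yolov8",
    "device": "cpu",
    "slice_size": (256, 256),
    "overlap_ratio": (0.2, 0.2),
}

# Inside VideoFrameExtractor.__init__ this expands to:
#   self.sahi_utils = SahiUtils(os.path.join("models", "best.pt"), **sahi_config)
#   self.output_format.sahi_enabled = True            # bool(sahi_config)
#   self.output_format.sahi_utils = self.sahi_utils   # used by BaseFormat.save_annotations
```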
20 changes: 18 additions & 2 deletions app/main.py
@@ -11,6 +11,7 @@

class VideoLabelApp:
def __init__(self):
self.sahi_config = None
self.config = Config()
self.storage_manager = StorageManager(self.config)
self.format_options = {'Roboflow': RoboflowFormat, 'CVAT': CVATFormat}
@@ -44,7 +45,6 @@ def continue_ui(self):
models = [file for file in os.listdir(self.config.models_directory) if file.endswith('.pt')]
self.model_selection = st.selectbox("Choose a model:", models)
self.frame_rate = st.number_input("Frame rate", value=self.config.default_frame_rate)
self.model_confidence = st.number_input("Model Confidence", value=0.1)
transformation_options = st.multiselect('Select image transformations:',
['Resize', 'Grayscale', 'Rotate 90 degrees'])
self.transformations = {
@@ -53,6 +53,22 @@
'rotate': 'Rotate 90 degrees' in transformation_options
}
self.format_selection = st.selectbox("Choose output format:", list(self.format_options.keys()))
self.sahi_enabled = st.sidebar.checkbox("Enable SAHI", value=self.config.sahi_enabled)
if self.sahi_enabled:
self.config.sahi_model_type = st.sidebar.selectbox("Model Architecture:", ["yolov8", "yolov9", "yolov10"])
self.config.sahi_device = st.sidebar.selectbox("Device:", ["cpu"])
self.config.sahi_slice_size = st.sidebar.slider("SAHI slice size:", 128, 512, (256, 256))
self.config.sahi_overlap_ratio = st.sidebar.slider("SAHI overlap ratio:", 0.1, 0.5, 0.2)
self.sahi_config = {
'model_type': self.config.sahi_model_type,
'slice_size': self.config.sahi_slice_size,
'overlap_ratio': self.config.sahi_overlap_ratio,
'device': self.config.sahi_device # Can be updated to use GPU if available
}
else:
self.sahi_config = None
self.model_confidence = st.number_input("Model Confidence", value=0.1)

if st.button('Extract Frames'):
self.process_video()

@@ -99,7 +115,7 @@ def run_extraction(self, video_path, unique_filename):
try:
extractor = VideoFrameExtractor(self.config, video_path, self.frame_rate, specific_output_dir,
self.model_selection, class_config_path, output_format_instance,
self.transformations)
self.transformations, self.sahi_config)
extractor.extract_frames(self.model_confidence)
if self.format_selection == "CVAT":
output_format_instance.zip_and_cleanup()
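
One detail about the two sidebar sliders in `continue_ui()` above, shown as a standalone sketch (a hypothetical demo file, not part of the commit): Streamlit renders a range slider and returns a `(low, high)` tuple when the default value is a tuple, and returns a single number when the default is a scalar, so the slice size and the overlap ratio come back in different shapes.

```python
# sahi_sidebar_demo.py (hypothetical) -- run with: streamlit run sahi_sidebar_demo.py
import streamlit as st

slice_size = st.sidebar.slider("SAHI slice size:", 128, 512, (256, 256))  # tuple default -> range slider
overlap = st.sidebar.slider("SAHI overlap ratio:", 0.1, 0.5, 0.2)         # scalar default -> single value

st.write(type(slice_size).__name__, slice_size)  # "tuple", (256, 256) by default
st.write(type(overlap).__name__, overlap)        # "float", 0.2 by default
```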
7 changes: 7 additions & 0 deletions example.env
@@ -16,3 +16,10 @@ S3_ACCESS_KEY=your_access_key
S3_SECRET_KEY=your_secret_key
S3_BUCKET_NAME=your_bucket_name
S3_REGION_NAME=us-east-1

# SAHI Configuration
SAHI_ENABLED=False
SAHI_MODEL_TYPE=yolov8
SAHI_DEVICE=cpu
SAHI_SLICE_SIZE=256,256
SAHI_OVERLAP_RATIO=0.2,0.2
76 changes: 61 additions & 15 deletions formats/base_format.py
@@ -1,28 +1,74 @@
from typing import Optional


class BaseFormat:
def __init__(self, output_dir):
"""
Base class for handling annotation formats. This class provides basic functionalities
like saving annotations and ensuring directory structure, which can be extended by subclasses.
Attributes:
output_dir (str): The directory where the output will be stored.
sahi_enabled (bool): Flag to enable or disable SAHI (Sliced Inference).
sahi_utils (Optional[object]): SAHI utility object for performing sliced inference.
"""

def __init__(self, output_dir: str, sahi_enabled: bool = False, sahi_utils: Optional[object] = None):
"""
Initializes the BaseFormat class with output directory and optional SAHI settings.
Args:
output_dir (str): Path to the directory where annotations will be saved.
sahi_enabled (bool): Boolean flag to enable SAHI (Sliced Inference). Defaults to False.
sahi_utils (Optional[object]): Instance of SAHI utility class to be used for sliced inference. Defaults to None.
"""
self.output_dir = output_dir
self.image_dir = None
self.label_dir = None
self.sahi_enabled = sahi_enabled
self.sahi_utils = sahi_utils

def ensure_directories(self):
"""
Ensures that necessary directories are created.
Ensures that the necessary directories for saving annotations exist.
Must be implemented by subclasses.
Raises:
NotImplementedError: If the method is not implemented in the subclass.
"""
raise NotImplementedError("Subclasses should implement this method.")

def annotate_frame(self, frame, frame_path, frame_filename, model_conf, supported_classes):
def save_annotations(self, frame, frame_path: str, frame_filename: str, results: list, supported_classes: list):
"""
Annotates the frame using the model output and saves the annotation in a format-specific manner.
Saves the annotations for a given frame. If SAHI is enabled, performs sliced inference before saving.
Args:
frame (ndarray): The image frame for which annotations are being saved.
frame_path (str): The path where the frame is located.
frame_filename (str): The name of the frame file.
results (list): A list of results from the detection model or sliced inference.
supported_classes (list): List of supported class labels for the annotations.
Parameters:
frame (np.array): The frame to be annotated.
frame_path (str): Path where the frame image is saved.
frame_filename (str): Filename of the frame image.
model_conf (float): Model confidence threshold for annotations.
supported_classes (list): List of supported class names.
Raises:
NotImplementedError: If `_save_annotations` is not implemented in the subclass.
"""
raise NotImplementedError("Subclasses should implement this method.")
if self.sahi_enabled and self.sahi_utils:
if hasattr(self.sahi_utils, 'perform_sliced_inference'):
results = self.sahi_utils.perform_sliced_inference(frame)
else:
raise AttributeError("sahi_utils object does not have 'perform_sliced_inference' method.")
self._save_annotations(frame, frame_path, frame_filename, results, supported_classes)

def save_annotations(self, frame, frame_path, frame_filename, results, supported_classes):
""" Method to save annotations; implemented in subclasses. """
def _save_annotations(self, frame, frame_path: str, frame_filename: str, results: list, supported_classes: list):
"""
Abstract method for saving annotations. To be implemented by subclasses to define
the logic for saving the annotations.
Args:
frame (ndarray): The image frame for which annotations are being saved.
frame_path (str): The path where the frame is located.
frame_filename (str): The name of the frame file.
results (list): A list of results from the detection model or sliced inference.
supported_classes (list): List of supported class labels for the annotations.
Raises:
NotImplementedError: If the method is not implemented in the subclass.
"""
raise NotImplementedError("Subclasses should implement this method.")
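
The two `NotImplementedError` hooks above make `BaseFormat` a small template: `save_annotations()` handles the SAHI branch once, and each concrete format only supplies `ensure_directories()` and `_save_annotations()`. A minimal illustrative subclass (hypothetical, not one of the formats in this repo) might look like this:

```python
import os

from formats.base_format import BaseFormat


class PlainTxtFormat(BaseFormat):
    """Toy subclass: dumps raw results to one text file per frame.
    Real formats (Roboflow/CVAT) replace this with format-specific logic."""

    def ensure_directories(self):
        self.image_dir = os.path.join(self.output_dir, "images")
        self.label_dir = os.path.join(self.output_dir, "labels")
        os.makedirs(self.image_dir, exist_ok=True)
        os.makedirs(self.label_dir, exist_ok=True)

    def _save_annotations(self, frame, frame_path, frame_filename, results, supported_classes):
        # `results` is whatever save_annotations() passed along: the detector output,
        # or a SAHI prediction when sahi_enabled is True. Here we just record its repr.
        label_path = os.path.join(self.label_dir, frame_filename.rsplit(".", 1)[0] + ".txt")
        with open(label_path, "w") as f:
            f.write(f"{frame_filename}: {results}\n")
```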
3 changes: 2 additions & 1 deletion requirements.txt
@@ -6,4 +6,5 @@ ultralytics
pillow
PyYAML
opencv-python
boto3
sahi
32 changes: 32 additions & 0 deletions utils/sahi_utils.py
@@ -0,0 +1,32 @@
from sahi.predict import get_sliced_prediction
from sahi import AutoDetectionModel


class SahiUtils:
def __init__(self, model_path, model_type='yolov8', device='cpu', slice_size=(256, 256), overlap_ratio=(0.2, 0.2)):
self.device = device  # 'cpu' or 'cuda:0'
self.model_type = model_type
self.model = self.load_model(model_path)
self.slice_size = slice_size
self.overlap_ratio = overlap_ratio

def load_model(self, model_path):
detection_model = AutoDetectionModel.from_pretrained(
model_type=self.model_type,
model_path=model_path,
confidence_threshold=0.1,
device=self.device,
)
return detection_model

def perform_sliced_inference(self, image):
# Perform sliced inference using the loaded model and SAHI
results = get_sliced_prediction(
image,
self.model, # this should be a sahi model
slice_height=self.slice_size[0],
slice_width=self.slice_size[1],
overlap_height_ratio=self.overlap_ratio[0],
overlap_width_ratio=self.overlap_ratio[1]
)
return results

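A short usage sketch for the class above (the weight and image paths are placeholders; any Ultralytics YOLOv8 `.pt` checkpoint should work):

```python
import cv2

from utils.sahi_utils import SahiUtils

sahi = SahiUtils("models/yolov8n.pt", model_type="yolov8", device="cpu",
                 slice_size=(256, 256), overlap_ratio=(0.2, 0.2))
frame = cv2.imread("frame.jpg")  # placeholder image
prediction = sahi.perform_sliced_inference(frame)

# get_sliced_prediction returns a sahi PredictionResult; the detections merged
# across slices live in object_prediction_list.
for obj in prediction.object_prediction_list:
    print(obj.category.name, round(obj.score.value, 3), obj.bbox.to_xyxy())
```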