Skip to content

Commit

Permalink
Merge pull request #5 from shamspias/feat/sahi
Browse files Browse the repository at this point in the history
Feat/sahi -> Implement write annotation and different save
  • Loading branch information
shamspias authored Sep 8, 2024
2 parents 81fee4c + 068dc5a commit 8277bd3
Show file tree
Hide file tree
Showing 7 changed files with 175 additions and 130 deletions.
29 changes: 16 additions & 13 deletions app/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@ def __init__(self, config, video_path, frame_rate, output_dir, model_path, class
self.supported_classes = self.load_classes(self.class_config_path)
self.image_processor = ImageProcessor(output_size=self.transformations.get('size', (640, 640)))

self.sahi_utils = SahiUtils(os.path.join('models', model_path), **sahi_config) if sahi_config else None
self.output_format.sahi_enabled = bool(sahi_config)
self.output_format.sahi_utils = self.sahi_utils
# Only initialize SahiUtils if SAHI is enabled
if sahi_config:
self.sahi_utils = SahiUtils(os.path.join('models', model_path), **sahi_config)
else:
self.sahi_utils = None

# Debugging output to ensure path handling
if not os.path.exists(self.video_path):
Expand All @@ -46,9 +48,6 @@ def load_classes(self, config_path):
return [cls['name'] for cls in class_data['classes']]

def extract_frames(self, model_confidence):
"""
Extract and process frames from the video, and save them using the specified output format.
"""
cap = cv2.VideoCapture(self.video_path)
if not cap.isOpened():
raise ValueError(f"Failed to open video stream for {self.video_path}")
Expand All @@ -66,18 +65,22 @@ def extract_frames(self, model_confidence):
transformed_images = self.apply_transformations(frame)

for key, transformed_image in transformed_images.items():
if transformed_image.ndim == 2: # Check if the image is grayscale
# Convert back to RGB format for consistency
transformed_image = cv2.cvtColor(transformed_image,
cv2.COLOR_GRAY2BGR)
if transformed_image.ndim == 2: # Grayscale to RGB for consistency
transformed_image = cv2.cvtColor(transformed_image, cv2.COLOR_GRAY2BGR)

frame_filename = f"{self._get_video_basename()}_image{frame_count}_{key}.jpg"
frame_path = os.path.join(self.output_dir, 'images', frame_filename)

# Save images locally or to configured storage
cv2.imwrite(frame_path, transformed_image)
results = self.yolo_model.predict(transformed_image, conf=model_confidence)
self.output_format.save_annotations(transformed_image, frame_path, frame_filename, results,
if self.sahi_utils:
results = self.sahi_utils.perform_sliced_inference(transformed_image)
else:
results = self.yolo_model.predict(transformed_image, conf=model_confidence, verbose=False)

# print(results)

self.output_format.save_annotations(transformed_image, frame_path, frame_filename,
results,
self.supported_classes)

frame_count += 1
Expand Down
5 changes: 3 additions & 2 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def continue_ui(self):
self.config.sahi_model_type = st.sidebar.selectbox("Model Architecture:", ["yolov8", "yolov9", "yolov10"])
self.config.sahi_device = st.sidebar.selectbox("Device:", ["cpu"])
self.config.sahi_slice_size = st.sidebar.slider("SAHI slice size:", 128, 512, (256, 256))
self.config.sahi_overlap_ratio = st.sidebar.slider("SAHI overlap ratio:", 0.1, 0.5, 0.2)
self.config.sahi_overlap_ratio = st.sidebar.slider("SAHI overlap ratio:", 0.1, 0.5, (0.2, 0.2))
self.sahi_config = {
'model_type': self.config.sahi_model_type,
'slice_size': self.config.sahi_slice_size,
Expand Down Expand Up @@ -111,7 +111,8 @@ def run_extraction(self, video_path, unique_filename):
class_config_path = os.path.join(self.config.object_class_directory, self.class_config_selection)
specific_output_dir = os.path.join(self.config.output_directory, unique_filename)
os.makedirs(specific_output_dir, exist_ok=True)
output_format_instance = self.format_options[self.format_selection](specific_output_dir)
output_format_instance = self.format_options[self.format_selection](output_dir=specific_output_dir,
sahi_enabled=self.sahi_enabled)
try:
extractor = VideoFrameExtractor(self.config, video_path, self.frame_rate, specific_output_dir,
self.model_selection, class_config_path, output_format_instance,
Expand Down
77 changes: 53 additions & 24 deletions formats/base_format.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from typing import Optional
from typing import Optional, List, Dict


class BaseFormat:
"""
Base class for handling annotation formats. This class provides basic functionalities
like saving annotations and ensuring directory structure, which can be extended by subclasses.
Base class for handling annotation formats. Provides foundational functionalities
like saving annotations and ensuring directory structure, designed for extension by subclasses.
Attributes:
output_dir (str): The directory where the output will be stored.
output_dir (str): Directory where output will be stored.
sahi_enabled (bool): Flag to enable or disable SAHI (Sliced Inference).
sahi_utils (Optional[object]): SAHI utility object for performing sliced inference.
"""
Expand All @@ -25,38 +25,67 @@ def __init__(self, output_dir: str, sahi_enabled: bool = False, sahi_utils: Opti
self.sahi_enabled = sahi_enabled
self.sahi_utils = sahi_utils

def write_annotations(self, frame_filename: str, annotations: List[str]):
"""
Writes annotations to a file based on the frame filename.
Args:
frame_filename (str): The filename of the frame to which annotations relate.
annotations (List[str]): Annotations to be written to the file.
"""
raise NotImplementedError("Subclasses should implement this method.")

def ensure_directories(self):
"""
Ensures that the necessary directories for saving annotations exist.
Ensures that necessary directories for saving annotations exist.
Must be implemented by subclasses.
Raises:
NotImplementedError: If the method is not implemented in the subclass.
"""
raise NotImplementedError("Subclasses should implement this method.")

def save_annotations(self, frame, frame_path: str, frame_filename: str, results: list, supported_classes: list):
def process_results(self, frame, results: Dict, img_dimensions) -> List[str]:
"""
Saves the annotations for a given frame. If SAHI is enabled, performs sliced inference before saving.
Generate formatted strings from detection results suitable for annotations.
Args:
frame (ndarray): The image frame for which annotations are being saved.
frame_path (str): The path where the frame is located.
frame_filename (str): The name of the frame file.
results (list): A list of results from the detection model or sliced inference.
supported_classes (list): List of supported class labels for the annotations.
frame: The image frame being processed.
results: Detection results containing bounding boxes and class IDs.
img_dimensions: Dimensions of the image for normalizing coordinates.
Raises:
NotImplementedError: If `_save_annotations` is not implemented in the subclass.
Returns:
List of annotation strings formatted according to specific requirements.
"""
if self.sahi_enabled and self.sahi_utils:
if hasattr(self.sahi_utils, 'perform_sliced_inference'):
results = self.sahi_utils.perform_sliced_inference(frame)
else:
raise AttributeError("sahi_utils object does not have 'perform_sliced_inference' method.")
self._save_annotations(frame, frame_path, frame_filename, results, supported_classes)

def _save_annotations(self, frame, frame_path: str, frame_filename: str, results: list, supported_classes: list):
annotations = []
img_height, img_width = img_dimensions

# Check if SAHI is enabled to adapt processing of results accordingly
if self.sahi_enabled:
for box in results['boxes']: # Assuming SAHI results are formatted similarly
class_id = int(box['cls'][0])
xmin, ymin, xmax, ymax = box['xyxy'][0]
x_center = ((xmin + xmax) / 2) / img_width
y_center = ((ymin + ymax) / 2) / img_height
width = (xmax - xmin) / img_width
height = (ymax - ymin) / img_height
annotations.append(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")
else:
for result in results:
if hasattr(result, 'boxes') and result.boxes is not None:
for box in result.boxes:
class_id = int(box.cls[0])
xmin, ymin, xmax, ymax = box.xyxy[0]
x_center = ((xmin + xmax) / 2) / img_width
y_center = ((ymin + ymax) / 2) / img_height
width = (xmax - xmin) / img_width
height = (ymax - ymin) / img_height
annotations.append(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")

return annotations

def save_annotations(self, frame, frame_path: str, frame_filename: str, results: Dict,
supported_classes: List[str]):
"""
Abstract method for saving annotations. To be implemented by subclasses to define
the logic for saving the annotations.
Expand All @@ -65,8 +94,8 @@ def _save_annotations(self, frame, frame_path: str, frame_filename: str, results
frame (ndarray): The image frame for which annotations are being saved.
frame_path (str): The path where the frame is located.
frame_filename (str): The name of the frame file.
results (list): A list of results from the detection model or sliced inference.
supported_classes (list): List of supported class labels for the annotations.
results (Dict): A dictionary of results from the detection model or sliced inference.
supported_classes (List[str]): List of supported class labels for the annotations.
Raises:
NotImplementedError: If the method is not implemented in the subclass.
Expand Down
121 changes: 59 additions & 62 deletions formats/cvat_format.py
Original file line number Diff line number Diff line change
@@ -1,97 +1,94 @@
import os
import cv2
import zipfile
from typing import List
from formats.base_format import BaseFormat


class CVATFormat(BaseFormat):
"""
Class to handle the CVAT format for image annotations.
Attributes:
output_dir (str): Base directory for all output.
Handles the CVAT format for image annotations. This class manages the creation of necessary directories,
the writing of annotations into CVAT-compatible text files, and the organization of image data.
"""

def __init__(self, output_dir):
super().__init__(output_dir)
def __init__(self, output_dir: str, sahi_enabled: bool = False):
super().__init__(output_dir, sahi_enabled)
self.data_dir = os.path.join(output_dir, 'data')
self.image_dir = os.path.join(self.data_dir, 'obj_train_data')
os.makedirs(self.image_dir, exist_ok=True)

def save_annotations(self, frame, frame_path, frame_filename, results, supported_classes):
def save_annotations(self, frame, frame_path: str, frame_filename: str, results, supported_classes: List[str]):
"""
Saves annotations and images in CVAT-compatible format directly in obj_train_data.
Saves annotations and frames in a format compatible with CVAT.
"""
img_dimensions = frame.shape[:2]
annotations = self.process_results(frame, results, img_dimensions)
frame_filename_png = frame_filename.replace('.jpg', '.png')
image_path = os.path.join(self.image_dir, frame_filename_png)
cv2.imwrite(image_path, frame) # Save the frame image
cv2.imwrite(image_path, frame)
self.write_annotations(frame_filename_png, annotations)
self.create_metadata_files(supported_classes)

annotation_filename = frame_filename_png.replace('.png', '.txt')
def write_annotations(self, frame_filename: str, annotations: List[str]):
"""
Writes annotations to a text file associated with each frame image.
"""
annotation_filename = frame_filename.replace('.png', '.txt')
annotation_path = os.path.join(self.image_dir, annotation_filename)
try:
with open(annotation_path, 'w') as file:
for annotation in annotations:
file.write(annotation + "\n")
except IOError as e:
print(f"Error writing annotation file {annotation_path}: {str(e)}")

with open(annotation_path, 'w') as file:
for result in results:
if hasattr(result, 'boxes') and result.boxes is not None:
for box in result.boxes:
if box.xyxy.dim() == 2 and box.xyxy.shape[0] == 1:
class_id = int(box.cls[0])
xmin, ymin, xmax, ymax = box.xyxy[0].tolist()
x_center = ((xmin + xmax) / 2) / frame.shape[1]
y_center = ((ymin + ymax) / 2) / frame.shape[0]
width = (xmax - xmin) / frame.shape[1]
height = (ymax - ymin) / frame.shape[0]
file.write(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n")

# After saving all annotations, update metadata files
self.create_metadata_files(supported_classes)

def create_metadata_files(self, supported_classes):
def create_metadata_files(self, supported_classes: List[str]):
"""
Creates necessary metadata files for CVAT training setup.
Creates necessary metadata files for a CVAT training setup, including class names and training configurations.
"""
obj_names_path = os.path.join(self.data_dir, 'obj.names')
obj_data_path = os.path.join(self.data_dir, 'obj.data')
train_txt_path = os.path.join(self.data_dir, 'train.txt')

# Create obj.names file
with open(obj_names_path, 'w') as f:
for cls in supported_classes:
f.write(f"{cls}\n")
try:
with open(obj_names_path, 'w') as f:
for cls in supported_classes:
f.write(f"{cls}\n")

# Create obj.data file
with open(obj_data_path, 'w') as f:
f.write("classes = {}\n".format(len(supported_classes)))
f.write("train = data/train.txt\n")
f.write("names = data/obj.names\n")
f.write("backup = backup/\n")
with open(obj_data_path, 'w') as f:
f.write("classes = {}\n".format(len(supported_classes)))
f.write("train = data/train.txt\n")
f.write("names = data/obj.names\n")
f.write("backup = backup/\n")

# Create train.txt file listing all image files
with open(train_txt_path, 'w') as f:
for image_file in os.listdir(self.image_dir):
if image_file.endswith('.png'):
f.write(f"data/obj_train_data/{image_file}\n")

def ensure_directories(self):
"""Ensures all directories are created and ready for use."""
super().ensure_directories() # Ensures base directories are created
with open(train_txt_path, 'w') as f:
for image_file in os.listdir(self.image_dir):
if image_file.endswith('.png'):
f.write(f"data/obj_train_data/{image_file}\n")
except IOError as e:
print(f"Error writing metadata files: {str(e)}")

def zip_and_cleanup(self):
# Create a zip file and add all the data in the data directory to it.
"""
Zips the processed data for transfer or storage and cleans up the directory structure.
"""
zip_path = os.path.join(self.output_dir, 'cvat_data.zip')
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
try:
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
for root, dirs, files in os.walk(self.data_dir):
for file in files:
file_path = os.path.join(root, file)
zipf.write(file_path, os.path.relpath(file_path, self.data_dir))
for dir in dirs:
dir_path = os.path.join(root, dir)
zipf.write(dir_path, os.path.relpath(dir_path, self.data_dir))

# Cleanup
for root, dirs, files in os.walk(self.data_dir, topdown=False):
for file in files:
file_path = os.path.join(root, file)
zipf.write(file_path, os.path.relpath(file_path, self.data_dir))
os.remove(os.path.join(root, file))
for dir in dirs:
dir_path = os.path.join(root, dir)
zipf.write(dir_path, os.path.relpath(dir_path, self.data_dir))

# Clean up the directory by removing all files first, then empty directories.
for root, dirs, files in os.walk(self.data_dir, topdown=False):
for file in files:
os.remove(os.path.join(root, file))
for dir in dirs:
os.rmdir(os.path.join(root, dir))

# Finally, remove the base data directory now that it should be empty.
os.rmdir(self.data_dir)
os.rmdir(os.path.join(root, dir))
os.rmdir(self.data_dir)
except Exception as e:
print(f"Error during zip or cleanup: {str(e)}")
Loading

0 comments on commit 8277bd3

Please sign in to comment.