
Commit

Merge pull request #2 from shamspias/feat/color_resize
Feat/color resize
shamspias authored Sep 4, 2024
2 parents 8373d1e + 6993465 commit 9f4ca81
Showing 4 changed files with 118 additions and 26 deletions.
70 changes: 45 additions & 25 deletions app/extractor.py
@@ -2,42 +2,32 @@
import os
from ultralytics import YOLO
import yaml
from utils.image_processor import ImageProcessor # Make sure this is the correct import path


class VideoFrameExtractor:
"""
A class to extract frames from video at specified intervals and annotate them using YOLO model predictions.
Attributes:
video_path (str): Path to the video file.
frame_rate (float): Desired frame rate to extract images.
output_dir (str): Base directory to store extracted images and annotations.
model_path (str): Path to the YOLO model for object detection.
class_config_path (str): Path to the class configuration file.
output_format (object): Format handler for saving annotations.
Extracts frames from video at specified intervals, applies selected transformations,
and annotates them using YOLO model predictions.
"""

def __init__(self, video_path, frame_rate, output_dir, model_path, class_config_path, output_format):
def __init__(self, video_path, frame_rate, output_dir, model_path, class_config_path, output_format,
transformations):
self.video_path = video_path
self.frame_rate = frame_rate
self.output_dir = output_dir
self.yolo_model = YOLO(os.path.join('models', model_path))
self.output_format = output_format
self.supported_classes = self.load_classes(class_config_path)
self.transformations = transformations
self.image_processor = ImageProcessor(output_size=self.transformations.get('size', (640, 640)))

def load_classes(self, config_path):
"""
Loads object classes from a YAML configuration file.
"""
with open(config_path, 'r') as file:
class_data = yaml.safe_load(file)
return [cls['name'] for cls in class_data['classes']]

def extract_frames(self, model_confidence):
"""
Extracts frames from the video file at the specified frame rate, annotates them using the YOLO model,
and saves using the specified format.
"""
cap = cv2.VideoCapture(self.video_path)
if not cap.isOpened():
raise FileNotFoundError(f"Unable to open video file: {self.video_path}")
@@ -52,18 +42,48 @@ def extract_frames(self, model_confidence):
break

if frame_count % frame_interval == 0:
frame_filename = f"{self._get_video_basename()}_image{frame_count}.jpg"
frame_path = os.path.join(self.output_dir, 'images', frame_filename)
cv2.imwrite(frame_path, frame)
results = self.yolo_model.predict(frame, conf=model_confidence)
self.output_format.save_annotations(frame, frame_path, frame_filename, results, self.supported_classes)
transformed_images = self.apply_transformations(frame)

for key, transformed_image in transformed_images.items():
if transformed_image.ndim == 2: # Check if the image is grayscale
transformed_image = cv2.cvtColor(transformed_image,
cv2.COLOR_GRAY2BGR)  # Convert back to a 3-channel BGR image for consistency

frame_filename = f"{self._get_video_basename()}_image{frame_count}_{key}.jpg"
frame_path = os.path.join(self.output_dir, 'images', frame_filename)
cv2.imwrite(frame_path, transformed_image)
results = self.yolo_model.predict(transformed_image, conf=model_confidence)
self.output_format.save_annotations(transformed_image, frame_path, frame_filename, results,
self.supported_classes)

frame_count += 1

cap.release()

def apply_transformations(self, frame):
"""
Applies the selected transformations to a frame and returns a dictionary
mapping transformation name to the resulting image.
"""
# Dictionary to hold transformed images
transformed_images = {}

# Apply resizing if selected
if 'resize' in self.transformations and self.transformations['resize']:
frame = self.image_processor.resize_image(frame)
transformed_images['resized'] = frame # Store resized image

# Apply grayscale transformation if selected
if 'grayscale' in self.transformations and self.transformations['grayscale']:
grayscale_image = self.image_processor.convert_to_grayscale(frame)
transformed_images['grayscale'] = grayscale_image # Store grayscale image

# Apply 90-degree rotation if selected
if 'rotate' in self.transformations and self.transformations['rotate']:
rotated_image = self.image_processor.rotate_image_90_degrees(frame)
transformed_images['rotated'] = rotated_image # Store rotated image

# If no transformations are selected, add the original image
if not transformed_images:
transformed_images['original'] = frame

return transformed_images

def _get_video_basename(self):
"""
Extracts the basename of the video file without its extension.
"""
return os.path.splitext(os.path.basename(self.video_path))[0]
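
For context, a minimal sketch of how the updated constructor might be called outside the Streamlit app. The paths, the model filename, and the RoboflowFormat constructor call are assumptions (the format handler's construction is not shown in this diff); the 'size' key is optional and defaults to (640, 640) inside the extractor.

# Usage sketch. Assumptions: RoboflowFormat(output_dir) is a valid way to build the
# format handler (its constructor is not shown in this diff), the model file exists
# under the local 'models' directory, and output_dir/images already exists.
from extractor import VideoFrameExtractor
from formats.roboflow_format import RoboflowFormat

transformations = {
    'resize': True,
    'grayscale': True,
    'rotate': False,
    'size': (640, 640),  # read via transformations.get('size', (640, 640))
}
output_format = RoboflowFormat('output/run1')  # hypothetical constructor call
extractor = VideoFrameExtractor(
    video_path='videos/sample.mp4',        # hypothetical input video
    frame_rate=1.0,
    output_dir='output/run1',
    model_path='yolov8n.pt',               # resolved against the local 'models' dir
    class_config_path='config/classes.yaml',
    output_format=output_format,
    transformations=transformations,
)
extractor.extract_frames(model_confidence=0.25)
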
14 changes: 13 additions & 1 deletion app/main.py
@@ -5,6 +5,7 @@
from extractor import VideoFrameExtractor
from formats.roboflow_format import RoboflowFormat
from formats.cvat_format import CVATFormat
from utils.image_processor import ImageProcessor # Import the ImageProcessor

# Import other formats if available

@@ -26,6 +27,17 @@
frame_rate = st.number_input("Frame rate", value=config.default_frame_rate)
model_confidence = st.number_input("Model Confidence", value=0.1)

# New fields for image transformation options
image_width = st.number_input("Image width", value=640)
image_height = st.number_input("Image height", value=640)

transformation_options = st.multiselect('Select image transformations:', ['Resize', 'Grayscale', 'Rotate 90 degrees'])
transformations = {
'resize': 'Resize' in transformation_options,
'grayscale': 'Grayscale' in transformation_options,
'rotate': 'Rotate 90 degrees' in transformation_options,
'size': (image_width, image_height)  # pass the requested output size through to the extractor
}

# Allow users to choose the output format
format_options = {'Roboflow': RoboflowFormat, 'CVAT': CVATFormat} # Add more formats to this dictionary
format_selection = st.selectbox("Choose output format:", list(format_options.keys()))
@@ -58,7 +70,7 @@
# Extract frames using the VideoFrameExtractor with the chosen format
try:
extractor = VideoFrameExtractor(video_path, frame_rate, specific_output_dir, model_selection,
class_config_path, output_format_instance)
class_config_path, output_format_instance, transformations)
extractor.extract_frames(model_confidence)

if format_selection == "CVAT": # If CVAT export then it will save as zip format
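
To make the wiring explicit, a short illustrative mapping (names below are for illustration only) from the multiselect labels to the keys that apply_transformations checks and to the suffixes used in the saved filenames (pattern taken from extractor.py above):

# Illustrative mapping; UI_TO_KEY is a hypothetical name, not part of the app.
# Each key becomes the filename suffix: <video_basename>_image<frame_count>_<key>.jpg
UI_TO_KEY = {
    'Resize': 'resized',
    'Grayscale': 'grayscale',
    'Rotate 90 degrees': 'rotated',
}
# With no transformation selected, the unmodified frame is saved with the suffix 'original'.
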
Empty file added utils/__init__.py
Empty file.
60 changes: 60 additions & 0 deletions utils/image_processor.py
@@ -0,0 +1,60 @@
import cv2


class ImageProcessor:
"""
A class for handling image transformations such as resizing, converting to grayscale, and rotating.
Attributes:
output_size (tuple): The desired output dimensions for resizing images (width, height).
"""

def __init__(self, output_size=(640, 640)):
"""
Initializes the ImageProcessor with a specified output size.
Parameters:
output_size (tuple): The width and height to which images should be resized, default is (640, 640).
"""
self.output_size = output_size

def resize_image(self, image):
"""
Resizes an image to the specified output size.
Parameters:
image (np.array): The image to resize.
Returns:
np.array: The resized image.
"""
resized_image = cv2.resize(image, self.output_size, interpolation=cv2.INTER_AREA)
assert resized_image.shape[:2] == (self.output_size[1], self.output_size[0]), \
"Resizing did not match expected dimensions."
return resized_image

def convert_to_grayscale(self, image):
"""
Converts an image to grayscale.
Parameters:
image (np.array): The image to convert.
Returns:
np.array: The grayscale image.
"""
if len(image.shape) != 3 or image.shape[2] != 3:
raise ValueError("Input image is not a 3-channel (BGR) image.")
return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

def rotate_image_90_degrees(self, image):
"""
Rotates an image by 90 degrees clockwise.
Parameters:
image (np.array): The image to rotate.
Returns:
np.array: The rotated image.
"""
return cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
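
ImageProcessor can also be used on its own; a minimal sketch, assuming a readable image at the hypothetical path sample.jpg:

import cv2
from utils.image_processor import ImageProcessor

processor = ImageProcessor(output_size=(320, 320))
frame = cv2.imread('sample.jpg')                     # hypothetical input image (BGR, 3 channels)
resized = processor.resize_image(frame)              # 320x320 BGR image
gray = processor.convert_to_grayscale(frame)         # single-channel array
rotated = processor.rotate_image_90_degrees(frame)   # rotated 90 degrees clockwise
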
