Skip to content

Commit

Permalink
Merge pull request #8 from shamspias/feat/rt-detr
Browse files Browse the repository at this point in the history
Feat/rt detr
  • Loading branch information
shamspias authored Sep 10, 2024
2 parents 496df74 + ba9ad5e commit 811cd31
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 29 deletions.
40 changes: 32 additions & 8 deletions app/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,23 @@ def __init__(self, config, video_path, frame_rate, output_dir, model_path, class
self.video_path = video_path # Ensure this is a string representing the path to the video file.
self.frame_rate = frame_rate
self.output_dir = output_dir
self.vision_model = self.get_given_model(model_path, model_types)

self.class_config_path = class_config_path
self.output_format = output_format
self.transformations = transformations

self.supported_classes_names = self.load_classes_names(self.class_config_path)
self.supported_classes_ids = self.load_classes_ids(self.class_config_path)
self.supported_classes_map = self.load_classes_category_map(self.class_config_path)

self.vision_model = self.get_given_model(model_path, model_types)

self.image_processor = ImageProcessor(output_size=self.transformations.get('size', (640, 640)))

# Only initialize SahiUtils if SAHI is enabled
if sahi_config:
self.sahi_utils = SahiUtils(self.config.debug, os.path.join('models', model_path), **sahi_config)
self.sahi_utils = SahiUtils(self.config.debug, self.supported_classes_map,
self.vision_model, **sahi_config)
else:
self.sahi_utils = None

Expand All @@ -41,11 +47,14 @@ def __init__(self, config, video_path, frame_rate, output_dir, model_path, class
def get_given_model(self, model_path, types):
try:
if types == "RTDETR":
return RTDETR(os.path.join('models', model_path))
model = RTDETR(os.path.join('models', model_path))
return model
elif types == "YOLO":
return YOLO(os.path.join('models', model_path))
model = YOLO(os.path.join('models', model_path))
return model
elif types == "NAS":
return NAS(os.path.join('models', model_path))
model = NAS(os.path.join('models', model_path))
return model
except Exception as e:
raise ValueError(f"Model architecture and Model not Matching: {str(e)}")

Expand All @@ -61,14 +70,27 @@ def load_classes_names(self, config_path):

def load_classes_ids(self, config_path):
"""
Load classes from a YAML configuration file.
Load classes id from a YAML configuration file.
"""
if not os.path.exists(config_path):
raise FileNotFoundError(f"Configuration file not found at {config_path}")
with open(config_path, 'r') as file:
class_data = yaml.safe_load(file)
return [cls['id'] for cls in class_data['classes']]

def load_classes_category_map(self, config_path):
"""
Load a mapping of class names to class ids from a YAML configuration file.
Returns a dictionary where the key is the class name (string) and the value is the class id (string).
"""
if not os.path.exists(config_path):
raise FileNotFoundError(f"Configuration file not found at {config_path}")
with open(config_path, 'r') as file:
class_data = yaml.safe_load(file)

# Create a dictionary with 'name' as key and 'id' as value, both converted to string
return {str(cls['id']): str(cls['name']) for cls in class_data['classes']}

def extract_frames(self, model_confidence):
cap = cv2.VideoCapture(self.video_path)
if not cap.isOpened():
Expand Down Expand Up @@ -98,10 +120,12 @@ def extract_frames(self, model_confidence):
results = self.sahi_utils.perform_sliced_inference(transformed_image)
else:
if self.config.debug:
results = self.vision_model.predict(transformed_image, conf=model_confidence, verbose=False)
results = self.vision_model.predict(transformed_image, conf=model_confidence, verbose=False,
classes=self.supported_classes_ids)
# will add image show later time
else:
results = self.vision_model.predict(transformed_image, conf=model_confidence, verbose=False)
results = self.vision_model.predict(transformed_image, conf=model_confidence, verbose=False,
classes=self.supported_classes_ids)

# print(results)

Expand Down
71 changes: 52 additions & 19 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,15 @@ def continue_ui(self):
self.model_types = st.selectbox("Choose Model Types:", ("YOLO", "RTDETR", "NAS"))
self.sahi_enabled = st.sidebar.checkbox("Enable SAHI", value=self.config.sahi_enabled)
if self.sahi_enabled:
self.config.sahi_model_type = st.sidebar.selectbox("Model Architecture:", ["yolov8", "yolov9", "yolov10"])
self.config.sahi_model_type = st.sidebar.selectbox("Model Architecture:", ["yolov8",
"rtdetr",
"yolonas",
"torchvision",
"huggingface",
"detectron2",
"mmdet",
]
)
self.config.sahi_device = st.sidebar.selectbox("Device:", ["cpu"])
self.config.sahi_slice_size = st.sidebar.slider("SAHI slice size:", 128, 512, (256, 256))
self.config.sahi_overlap_ratio = st.sidebar.slider("SAHI overlap ratio:", 0.1, 0.5, (0.2, 0.2))
Expand Down Expand Up @@ -108,30 +116,55 @@ def process_cloud_storage_video(self):
# Proceed to run the extraction process
self.run_extraction(video_path, unique_filename)

# def run_extraction(self, video_path, unique_filename):
# class_config_path = os.path.join(self.config.object_class_directory, self.class_config_selection)
# specific_output_dir = os.path.join(self.config.output_directory, unique_filename)
# os.makedirs(specific_output_dir, exist_ok=True)
# output_format_instance = self.format_options[self.format_selection](output_dir=specific_output_dir,
# sahi_enabled=self.sahi_enabled)
# try:
# extractor = VideoFrameExtractor(self.config, video_path, self.frame_rate, specific_output_dir,
# self.model_selection, class_config_path, output_format_instance,
# self.transformations, self.model_types, self.sahi_config)
# extractor.extract_frames(self.model_confidence)
# if self.format_selection == "CVAT":
# output_format_instance.zip_and_cleanup()
# if self.storage_option == 'Object Storage':
# self.upload_outputs(specific_output_dir)
#
# # Clean up: Remove the temporary video file after processing
# if os.path.exists(video_path):
# os.remove(video_path)
# print(f"Deleted temporary video file: {video_path}")
#
# st.success('Extraction Completed!')
# except Exception as e:
# st.error(f"An error occurred during frame extraction: {str(e)}")

def run_extraction(self, video_path, unique_filename):
class_config_path = os.path.join(self.config.object_class_directory, self.class_config_selection)
specific_output_dir = os.path.join(self.config.output_directory, unique_filename)
os.makedirs(specific_output_dir, exist_ok=True)
output_format_instance = self.format_options[self.format_selection](output_dir=specific_output_dir,
sahi_enabled=self.sahi_enabled)
try:
extractor = VideoFrameExtractor(self.config, video_path, self.frame_rate, specific_output_dir,
self.model_selection, class_config_path, output_format_instance,
self.transformations, self.model_types, self.sahi_config)
extractor.extract_frames(self.model_confidence)
if self.format_selection == "CVAT":
output_format_instance.zip_and_cleanup()
if self.storage_option == 'Object Storage':
self.upload_outputs(specific_output_dir)

# Clean up: Remove the temporary video file after processing
if os.path.exists(video_path):
os.remove(video_path)
print(f"Deleted temporary video file: {video_path}")

st.success('Extraction Completed!')
except Exception as e:
st.error(f"An error occurred during frame extraction: {str(e)}")

extractor = VideoFrameExtractor(self.config, video_path, self.frame_rate, specific_output_dir,
self.model_selection, class_config_path, output_format_instance,
self.transformations, self.model_types, self.sahi_config)
extractor.extract_frames(self.model_confidence)

if self.format_selection == "CVAT":
output_format_instance.zip_and_cleanup()

if self.storage_option == 'Object Storage':
self.upload_outputs(specific_output_dir)

# Clean up: Remove the temporary video file after processing
if os.path.exists(video_path):
os.remove(video_path)
print(f"Deleted temporary video file: {video_path}")

st.success('Extraction Completed!')

def upload_outputs(self, directory):
"""
Expand Down
14 changes: 12 additions & 2 deletions utils/sahi_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,15 @@


class SahiUtils:
def __init__(self, debug, model_path, model_type='yolov8', device='cpu', slice_size=(256, 256),
def __init__(self, debug,
supported_classes_map,
model_path,
model_type='yolov8',
device='cpu',
slice_size=(256, 256),
overlap_ratio=(0.2, 0.2)):
self.debug = debug
self.supported_classes_map = supported_classes_map
self.device = device # Can be 'cpu' or 'cuda:0' for GPU
self.model_type = model_type
self.model = self.load_model(model_path)
Expand All @@ -23,11 +29,14 @@ def __init__(self, debug, model_path, model_type='yolov8', device='cpu', slice_s

def load_model(self, model_path):
"""Loads a detection model based on the specified type and path."""
# print(self.supported_classes_map)
detection_model = AutoDetectionModel.from_pretrained(
model_type=self.model_type,
model_path=model_path,
model=model_path,
confidence_threshold=0.1,
device=self.device,
# category_mapping=self.supported_classes_map,
# category_remapping=self.supported_classes_map,
)
return detection_model

Expand Down Expand Up @@ -58,6 +67,7 @@ def perform_sliced_inference(self, image):
slice_width=self.slice_size[1],
overlap_height_ratio=self.overlap_ratio[0],
overlap_width_ratio=self.overlap_ratio[1],
postprocess_class_agnostic=True,
verbose=False
)
if self.debug:
Expand Down

0 comments on commit 811cd31

Please sign in to comment.