diff --git a/app/extractor.py b/app/extractor.py index e94febd..f1cefc1 100644 --- a/app/extractor.py +++ b/app/extractor.py @@ -18,17 +18,23 @@ def __init__(self, config, video_path, frame_rate, output_dir, model_path, class self.video_path = video_path # Ensure this is a string representing the path to the video file. self.frame_rate = frame_rate self.output_dir = output_dir - self.vision_model = self.get_given_model(model_path, model_types) + self.class_config_path = class_config_path self.output_format = output_format self.transformations = transformations + self.supported_classes_names = self.load_classes_names(self.class_config_path) self.supported_classes_ids = self.load_classes_ids(self.class_config_path) + self.supported_classes_map = self.load_classes_category_map(self.class_config_path) + + self.vision_model = self.get_given_model(model_path, model_types) + self.image_processor = ImageProcessor(output_size=self.transformations.get('size', (640, 640))) # Only initialize SahiUtils if SAHI is enabled if sahi_config: - self.sahi_utils = SahiUtils(self.config.debug, os.path.join('models', model_path), **sahi_config) + self.sahi_utils = SahiUtils(self.config.debug, self.supported_classes_map, + self.vision_model, **sahi_config) else: self.sahi_utils = None @@ -41,11 +47,14 @@ def __init__(self, config, video_path, frame_rate, output_dir, model_path, class def get_given_model(self, model_path, types): try: if types == "RTDETR": - return RTDETR(os.path.join('models', model_path)) + model = RTDETR(os.path.join('models', model_path)) + return model elif types == "YOLO": - return YOLO(os.path.join('models', model_path)) + model = YOLO(os.path.join('models', model_path)) + return model elif types == "NAS": - return NAS(os.path.join('models', model_path)) + model = NAS(os.path.join('models', model_path)) + return model except Exception as e: raise ValueError(f"Model architecture and Model not Matching: {str(e)}") @@ -61,7 +70,7 @@ def load_classes_names(self, config_path): def load_classes_ids(self, config_path): """ - Load classes from a YAML configuration file. + Load classes id from a YAML configuration file. """ if not os.path.exists(config_path): raise FileNotFoundError(f"Configuration file not found at {config_path}") @@ -69,6 +78,19 @@ def load_classes_ids(self, config_path): class_data = yaml.safe_load(file) return [cls['id'] for cls in class_data['classes']] + def load_classes_category_map(self, config_path): + """ + Load a mapping of class names to class ids from a YAML configuration file. + Returns a dictionary where the key is the class name (string) and the value is the class id (string). + """ + if not os.path.exists(config_path): + raise FileNotFoundError(f"Configuration file not found at {config_path}") + with open(config_path, 'r') as file: + class_data = yaml.safe_load(file) + + # Create a dictionary with 'name' as key and 'id' as value, both converted to string + return {str(cls['id']): str(cls['name']) for cls in class_data['classes']} + def extract_frames(self, model_confidence): cap = cv2.VideoCapture(self.video_path) if not cap.isOpened(): @@ -98,10 +120,12 @@ def extract_frames(self, model_confidence): results = self.sahi_utils.perform_sliced_inference(transformed_image) else: if self.config.debug: - results = self.vision_model.predict(transformed_image, conf=model_confidence, verbose=False) + results = self.vision_model.predict(transformed_image, conf=model_confidence, verbose=False, + classes=self.supported_classes_ids) # will add image show later time else: - results = self.vision_model.predict(transformed_image, conf=model_confidence, verbose=False) + results = self.vision_model.predict(transformed_image, conf=model_confidence, verbose=False, + classes=self.supported_classes_ids) # print(results) diff --git a/app/main.py b/app/main.py index c3311aa..d9fc3f8 100644 --- a/app/main.py +++ b/app/main.py @@ -56,7 +56,15 @@ def continue_ui(self): self.model_types = st.selectbox("Choose Model Types:", ("YOLO", "RTDETR", "NAS")) self.sahi_enabled = st.sidebar.checkbox("Enable SAHI", value=self.config.sahi_enabled) if self.sahi_enabled: - self.config.sahi_model_type = st.sidebar.selectbox("Model Architecture:", ["yolov8", "yolov9", "yolov10"]) + self.config.sahi_model_type = st.sidebar.selectbox("Model Architecture:", ["yolov8", + "rtdetr", + "yolonas", + "torchvision", + "huggingface", + "detectron2", + "mmdet", + ] + ) self.config.sahi_device = st.sidebar.selectbox("Device:", ["cpu"]) self.config.sahi_slice_size = st.sidebar.slider("SAHI slice size:", 128, 512, (256, 256)) self.config.sahi_overlap_ratio = st.sidebar.slider("SAHI overlap ratio:", 0.1, 0.5, (0.2, 0.2)) @@ -108,30 +116,55 @@ def process_cloud_storage_video(self): # Proceed to run the extraction process self.run_extraction(video_path, unique_filename) + # def run_extraction(self, video_path, unique_filename): + # class_config_path = os.path.join(self.config.object_class_directory, self.class_config_selection) + # specific_output_dir = os.path.join(self.config.output_directory, unique_filename) + # os.makedirs(specific_output_dir, exist_ok=True) + # output_format_instance = self.format_options[self.format_selection](output_dir=specific_output_dir, + # sahi_enabled=self.sahi_enabled) + # try: + # extractor = VideoFrameExtractor(self.config, video_path, self.frame_rate, specific_output_dir, + # self.model_selection, class_config_path, output_format_instance, + # self.transformations, self.model_types, self.sahi_config) + # extractor.extract_frames(self.model_confidence) + # if self.format_selection == "CVAT": + # output_format_instance.zip_and_cleanup() + # if self.storage_option == 'Object Storage': + # self.upload_outputs(specific_output_dir) + # + # # Clean up: Remove the temporary video file after processing + # if os.path.exists(video_path): + # os.remove(video_path) + # print(f"Deleted temporary video file: {video_path}") + # + # st.success('Extraction Completed!') + # except Exception as e: + # st.error(f"An error occurred during frame extraction: {str(e)}") + def run_extraction(self, video_path, unique_filename): class_config_path = os.path.join(self.config.object_class_directory, self.class_config_selection) specific_output_dir = os.path.join(self.config.output_directory, unique_filename) os.makedirs(specific_output_dir, exist_ok=True) output_format_instance = self.format_options[self.format_selection](output_dir=specific_output_dir, sahi_enabled=self.sahi_enabled) - try: - extractor = VideoFrameExtractor(self.config, video_path, self.frame_rate, specific_output_dir, - self.model_selection, class_config_path, output_format_instance, - self.transformations, self.model_types, self.sahi_config) - extractor.extract_frames(self.model_confidence) - if self.format_selection == "CVAT": - output_format_instance.zip_and_cleanup() - if self.storage_option == 'Object Storage': - self.upload_outputs(specific_output_dir) - - # Clean up: Remove the temporary video file after processing - if os.path.exists(video_path): - os.remove(video_path) - print(f"Deleted temporary video file: {video_path}") - - st.success('Extraction Completed!') - except Exception as e: - st.error(f"An error occurred during frame extraction: {str(e)}") + + extractor = VideoFrameExtractor(self.config, video_path, self.frame_rate, specific_output_dir, + self.model_selection, class_config_path, output_format_instance, + self.transformations, self.model_types, self.sahi_config) + extractor.extract_frames(self.model_confidence) + + if self.format_selection == "CVAT": + output_format_instance.zip_and_cleanup() + + if self.storage_option == 'Object Storage': + self.upload_outputs(specific_output_dir) + + # Clean up: Remove the temporary video file after processing + if os.path.exists(video_path): + os.remove(video_path) + print(f"Deleted temporary video file: {video_path}") + + st.success('Extraction Completed!') def upload_outputs(self, directory): """ diff --git a/utils/sahi_utils.py b/utils/sahi_utils.py index 675c23e..7f58873 100644 --- a/utils/sahi_utils.py +++ b/utils/sahi_utils.py @@ -11,9 +11,15 @@ class SahiUtils: - def __init__(self, debug, model_path, model_type='yolov8', device='cpu', slice_size=(256, 256), + def __init__(self, debug, + supported_classes_map, + model_path, + model_type='yolov8', + device='cpu', + slice_size=(256, 256), overlap_ratio=(0.2, 0.2)): self.debug = debug + self.supported_classes_map = supported_classes_map self.device = device # Can be 'cpu' or 'cuda:0' for GPU self.model_type = model_type self.model = self.load_model(model_path) @@ -23,11 +29,14 @@ def __init__(self, debug, model_path, model_type='yolov8', device='cpu', slice_s def load_model(self, model_path): """Loads a detection model based on the specified type and path.""" + # print(self.supported_classes_map) detection_model = AutoDetectionModel.from_pretrained( model_type=self.model_type, - model_path=model_path, + model=model_path, confidence_threshold=0.1, device=self.device, + # category_mapping=self.supported_classes_map, + # category_remapping=self.supported_classes_map, ) return detection_model @@ -58,6 +67,7 @@ def perform_sliced_inference(self, image): slice_width=self.slice_size[1], overlap_height_ratio=self.overlap_ratio[0], overlap_width_ratio=self.overlap_ratio[1], + postprocess_class_agnostic=True, verbose=False ) if self.debug: