Skip to content

Commit

Permalink
feat: added support for auto-detecting detection model types from hug…
Browse files Browse the repository at this point in the history
…gingface and loading models from a directory. If a model's name does not encode its model type (e.g. yolov5), the type must be given explicitly with the --model-type option, e.g. --model-type yolov5
  • Loading branch information
danellecline committed Jan 11, 2025
1 parent 248b789 commit 3ea7612
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 104 deletions.
18 changes: 9 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ Detection
Detection can be done with a fine-grained saliency-based detection model, and/or one of the following models run with the SAHI algorithm.
Both detection algorithms are run by default and combined to produce the final detections.

| Model | Description |
|-------------------------------|--------------------------------------------------------------------|
| yolov8s | YOLOv8s model from Ultralytics |
| hustvl/yolos-small | YOLOS model a Vision Transformer (ViT) |
| hustvl/yolos-tiny | YOLOS model a Vision Transformer (ViT) |
| MBARI/megamidwater (default) | MBARI midwater YOLOv5x for general detection in midwater images |
| MBARI/uav-yolov5 | MBARI UAV YOLOv5x for general detection in UAV images |
| MBARI/yolov5x6-uavs-oneclass | MBARI UAV YOLOv5x for general detection in UAV images single class |
| FathomNet/MBARI-315k-yolov5 | MBARI YOLOv5x for general detection in benthic images |
| Model | Description |
|----------------------------------|--------------------------------------------------------------------|
| yolov8s | YOLOv8s model from Ultralytics |
| hustvl/yolos-small | YOLOS model a Vision Transformer (ViT) |
| hustvl/yolos-tiny | YOLOS model a Vision Transformer (ViT) |
| MBARI-org/megamidwater (default) | MBARI midwater YOLOv5x for general detection in midwater images |
| MBARI-org/uav-yolov5 | MBARI UAV YOLOv5x for general detection in UAV images |
| MBARI-org/yolov5x6-uavs-oneclass | MBARI UAV YOLOv5x for general detection in UAV images single class |
| FathomNet/MBARI-315k-yolov5 | MBARI YOLOv5x for general detection in benthic images |


To skip saliency detection, use the --skip-saliency option.
Expand Down
101 changes: 6 additions & 95 deletions sdcat/detect/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
from sdcat import common_args
from sdcat.config import config as cfg
from sdcat.config.config import default_config_ini
from sdcat.detect.model_util import create_model
from sdcat.detect.sahi_detector import run_sahi_detect_bulk, run_sahi_detect
from sdcat.detect.saliency_detector import run_saliency_detect, run_saliency_detect_bulk
from sdcat.logger import exception, info, warn, create_logger_file

default_model = 'MBARI/megamidwater'
default_model = 'MBARI-org/megamidwater'


@click.command('detect',
Expand All @@ -38,16 +39,17 @@
@click.option('--conf', default=0.1, help='Confidence threshold.')
@click.option('--scale-percent', default=80, help='Scaling factor to rescale the images before processing.')
@click.option('--model', default=default_model, help=f'Model to use. Defaults to {default_model}')
@click.option('--model-type', help=f'Type of model, e.g. yolov5, yolov8. Defaults to auto-detect.')
@click.option('--slice-size-width', help='Slice width size, leave blank for auto slicing')
@click.option('--slice-size-height', help='Slice height size, leave blank for auto slicing')
@click.option('--postprocess-match-metric', default='IOS', help='Postprocess match metric for NMS. postprocess_match_metric IOU for intersection over union, IOS for intersection over smaller area.')
@click.option('--overlap-width-ratio', default=0.4, help='Overlap width ratio for NMS')
@click.option('--overlap-height-ratio', default=0.4, help='Overlap height ratio for NMS')
@click.option('--clahe', is_flag=True, help='Run the CLAHE algorithm to contrast enhance before detection useful images with non-uniform lighting')

def run_detect(show: bool, image_dir: str, save_dir: str, model: str,
def run_detect(show: bool, image_dir: str, save_dir: str, model: str, model_type:str,
slice_size_width: int, slice_size_height: int, scale_percent: int,
postprocess_match_metric: str, overlap_width_ratio: float, overlap_height_ratio: float,
postprocess_match_metric: str, overlap_width_ratio: float, overlap_height_ratio: float,
device: str, conf: float, skip_sahi: bool, skip_saliency: bool, spec_remove: bool,
config_ini: str, clahe: bool, start_image: str, end_image: str):
config = cfg.Config(config_ini)
Expand All @@ -67,98 +69,7 @@ def run_detect(show: bool, image_dir: str, save_dir: str, model: str,
create_logger_file('detect')

if not skip_sahi:
from sahi import AutoDetectionModel
if model == 'yolov8s':
detection_model = AutoDetectionModel.from_pretrained(
model_type='yolov8',
model_path='ultralyticsplus/yolov8s',
confidence_threshold=conf,
device=device,
)
elif model == 'yolov8x':
detection_model = AutoDetectionModel.from_pretrained(
model_type='yolov8',
model_path='yolov8x.pt',
confidence_threshold=conf,
device=device,
)
elif model == 'hustvl/yolos-small':
model_path = 'hustvl/yolos-small'
detection_model = AutoDetectionModel.from_pretrained(
model_type='huggingface',
model_path=model_path,
config_path=model_path,
confidence_threshold=conf,
device=device,
)
elif model == 'hustvl/yolos-tiny':
model_path = 'hustvl/yolos-tiny'
detection_model = AutoDetectionModel.from_pretrained(
model_type='huggingface',
model_path=model_path,
config_path=model_path,
confidence_threshold=conf,
device=device,
)
elif model == 'MBARI/megamidwater':
# Download model path
from huggingface_hub import hf_hub_download
model_path = hf_hub_download(repo_id="MBARI-org/megamidwater", filename="best.pt")
detection_model = AutoDetectionModel.from_pretrained(
model_type='yolov5',
model_path=model_path,
config_path=model_path,
confidence_threshold=conf,
device=device,
)
elif model == 'MBARI/uav-yolov5-30k':
# Download model path
from huggingface_hub import hf_hub_download
model_path = hf_hub_download(repo_id="MBARI-org/yolov5x6-uav-30k", filename="yolov5x6-uav-30k.pt")
detection_model = AutoDetectionModel.from_pretrained(
model_type='yolov5',
model_path=model_path,
config_path=model_path,
confidence_threshold=conf,
device=device,
)
elif model == 'MBARI/uav-yolov5-18k':
# Download model path
from huggingface_hub import hf_hub_download
model_path = hf_hub_download(repo_id="MBARI-org/yolov5-uav-18k", filename="yolov5x6-uav-18k.pt")
detection_model = AutoDetectionModel.from_pretrained(
model_type='yolov5',
model_path=model_path,
config_path=model_path,
confidence_threshold=conf,
device=device,
)
elif model == 'MBARI/yolov5x6-uavs-oneclass':
# Download model path
from huggingface_hub import hf_hub_download
model_path = hf_hub_download(repo_id="MBARI-org/yolov5x6-uavs-oneclass", filename="best_uavs_oneclass.pt")

detection_model = AutoDetectionModel.from_pretrained(
model_type='yolov5',
model_path=model_path,
config_path=model_path,
confidence_threshold=conf,
device=device,
)
elif model == 'FathomNet/MBARI-315k-yolov5':
# Download model path
from huggingface_hub import hf_hub_download
model_path = hf_hub_download(repo_id="FathomNet/MBARI-315k-yolov5", filename="mbari_315k_yolov5.pt")
detection_model = AutoDetectionModel.from_pretrained(
model_type='yolov5',
model_path=model_path,
config_path=model_path,
confidence_threshold=conf,
device=device,
)
else:
exception(f'Unknown model: {model}')
return
# Pass by keyword: create_model's positional order is (model, conf, device, model_type),
# so the previous positional call swapped the confidence threshold and the device.
detection_model = create_model(model, conf=conf, device=device, model_type=model_type)

images_path = Path(image_dir)
base_path = Path(save_dir) / model
Expand Down
112 changes: 112 additions & 0 deletions sdcat/detect/model_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import os
from pathlib import Path

from sahi import AutoDetectionModel
from huggingface_hub import hf_hub_download

from sdcat.logger import err


def _infer_model_type(model_path):
    """Return the first known model type whose name occurs in ``model_path``, else None."""
    for known_type in ("yolov5", "yolov8", "huggingface"):
        if known_type in model_path:
            return known_type
    return None


def _resolve_local_model(model, model_type):
    """Resolve an existing local path to a ``(model_type, weights_path)`` pair.

    Args:
        model (str): An existing path; either a directory containing at least one
            ``.pt`` file, or a weights file itself.
        model_type (str): Explicit model type, or None to infer it from the path.
    Returns:
        tuple: ``(model_type, weights_path)`` where ``weights_path`` is a ``str``.
    Raises:
        ValueError: If no weights file can be found or the type cannot be determined.
    """
    if os.path.isdir(model):
        path_kind = "directory"
        # os.listdir returns entries in arbitrary order; sort so that the same
        # weights file is picked on every run when several .pt files are present.
        weight_files = sorted(f for f in os.listdir(model) if f.endswith(".pt"))
        if not weight_files:
            message = f"No .pt file found in directory: {model}"
            err(message)
            raise ValueError(message)
        weights_path = str(Path(model) / weight_files[0])
    elif os.path.isfile(model):
        # A weights file given directly is used as-is.
        path_kind = "file"
        weights_path = str(model)
    else:
        raise ValueError(f"Model path is not a directory: {model}")

    if model_type is None:
        model_type = _infer_model_type(str(model))
    if model_type is None:
        message = (f"Could not determine model type from {path_kind} name: {model}. "
                   "Try the --model-type option, e.g., --model-type yolov5")
        err(message)
        raise ValueError(message)

    return model_type, weights_path


def create_model(model, conf, device, model_type=None):
    """
    Utility to determine the model type, model path, and create a detection model using SAHI.

    Note the positional order: ``conf`` comes before ``device``. Prefer keyword
    arguments at call sites to avoid swapping them.

    Args:
        model (str): The name of the model to use. Either a predefined name, a local
            directory containing a .pt file, or a local weights file.
        conf (float): Confidence threshold for the model.
        device (str): The device to run the model on ('cpu', 'cuda', etc.).
        model_type (str): The type of model to use (e.g., 'yolov5', 'huggingface').
            If None, it is inferred from the local path. It is ignored for predefined
            model names, whose type is fixed by the built-in mapping.
    Returns:
        detection_model: The object returned by AutoDetectionModel.from_pretrained.
    Raises:
        ValueError: If the model name is unknown, a local directory has no .pt file,
            or the model type of a local model cannot be determined.
    """

    # A path that exists on disk takes precedence over the predefined names
    if os.path.exists(model):
        local_type, weights_path = _resolve_local_model(model, model_type)
        return AutoDetectionModel.from_pretrained(
            model_type=local_type,
            model_path=weights_path,
            confidence_threshold=conf,
            device=device,
        )

    # Predefined model mapping. Downloads are wrapped in lambdas so that only the
    # requested model is fetched, and only when it is actually selected.
    model_map = {
        'yolov8s': {
            'model_type': 'yolov8',
            'model_path': 'ultralyticsplus/yolov8s'
        },
        'yolov8x': {
            'model_type': 'yolov8',
            'model_path': 'yolov8x.pt'
        },
        'hustvl/yolos-small': {
            'model_type': 'huggingface',
            'model_path': 'hustvl/yolos-small',
            'config_path': 'hustvl/yolos-small'
        },
        'hustvl/yolos-tiny': {
            'model_type': 'huggingface',
            'model_path': 'hustvl/yolos-tiny',
            'config_path': 'hustvl/yolos-tiny'
        },
        'MBARI-org/megamidwater': {
            'model_type': 'yolov5',
            'model_path': lambda: hf_hub_download("MBARI-org/megamidwater", "best.pt")
        },
        'MBARI-org/uav-yolov5-30k': {
            'model_type': 'yolov5',
            'model_path': lambda: hf_hub_download("MBARI-org/yolov5x6-uav-30k", "yolov5x6-uav-30k.pt")
        },
        'MBARI-org/uav-yolov5-18k': {
            'model_type': 'yolov5',
            'model_path': lambda: hf_hub_download("MBARI-org/yolov5-uav-18k", "yolov5x6-uav-18k.pt")
        },
        'MBARI-org/yolov5x6-uavs-oneclass': {
            'model_type': 'yolov5',
            'model_path': lambda: hf_hub_download("MBARI-org/yolov5x6-uavs-oneclass", "best_uavs_oneclass.pt")
        },
        'FathomNet/MBARI-315k-yolov5': {
            'model_type': 'yolov5',
            'model_path': lambda: hf_hub_download("FathomNet/MBARI-315k-yolov5", "mbari_315k_yolov5.pt")
        }
    }

    if model not in model_map:
        raise ValueError(f"Unknown model: {model}")

    model_info = model_map[model]
    model_path = model_info['model_path']
    if callable(model_path):  # If the path is a function (e.g., requires download)
        model_path = model_path()

    return AutoDetectionModel.from_pretrained(
        model_type=model_info['model_type'],
        model_path=model_path,
        config_path=model_info.get('config_path'),
        confidence_threshold=conf,
        device=device,
    )

0 comments on commit 3ea7612

Please sign in to comment.