Skip to content

Commit

Permalink
Merge pull request #1 from kadirnar/base_config
Browse files Browse the repository at this point in the history
The base config of the torchyolo library has been improved.
  • Loading branch information
kadirnar authored Jan 3, 2023
2 parents cf5c802 + c8d1564 commit bc23816
Show file tree
Hide file tree
Showing 19 changed files with 228 additions and 2 deletions.
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[tool.black]
line-length = 120

[tool.isort]
line_length = 120
profile = "black"
2 changes: 2 additions & 0 deletions script/code_format.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
black . --config pyproject.toml
isort .
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,5 @@ def get_version():
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
keywords="machine-learning, deep-learning, pytorch, vision, yolov6,yolox, object-detection, yolov7, detector, yolov5"
keywords="machine-learning, deep-learning, pytorch, vision, yolov6,yolox, object-detection, yolov7, detector, yolov5",
)
2 changes: 1 addition & 1 deletion torchyolo/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.1.0'
__version__ = "0.1.0"
Empty file.
Empty file added torchyolo/configs/yolov5.yaml
Empty file.
Empty file added torchyolo/configs/yolov6.yaml
Empty file.
Empty file added torchyolo/configs/yolov7.yaml
Empty file.
Empty file added torchyolo/configs/yolox.yaml
Empty file.
Empty file added torchyolo/models/__init__.py
Empty file.
66 changes: 66 additions & 0 deletions torchyolo/models/basemodel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from typing import Any, Optional

import numpy as np
import torch


class YoloDetectionModel:
def __init__(
self,
model_path: Optional[str] = None,
config_path: Optional[str] = None,
model: Optional[Any] = None,
device: Optional[str] = None,
confidence_threshold: float = 0.3,
iou_threshold: float = 0.5,
image_size: int = None,
load_at_init: bool = True,
):
"""
Init object detection/instance segmentation model.
Args:
model_path: str
Path for the instance segmentation model weight
config_path: str
Path for the mmdetection instance segmentation model config file
device: str
Torch device, "cpu" or "cuda"
iou_threshold: float
All predictions with IoU < iou_threshold will be discarded
confidence_threshold: float
All predictions with score < confidence_threshold will be discarded
image_size: int
Inference input size.
load_at_init: bool
If True, automatically loads the model at initalization
"""
self.model_path = model_path
self.config_path = config_path
self.model = model
self.device = device
self.iou_threshold = iou_threshold
self.confidence_threshold = confidence_threshold
self.image_size = image_size

# automatically load model if load_at_init is True
if load_at_init:
self.load_model()

if self.device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"

def load_model(self):
"""
This function should be implemented in a way that detection model
should be initialized and set to self.model.
(self.model_path, self.config_path, and self.device should be utilized)
"""
raise NotImplementedError()

def predict(self, image: np.ndarray):
"""
This function should be implemented in a way that detection model
should be initialized and set to self.model.
(self.model_path, self.config_path, and self.device should be utilized)
"""
raise NotImplementedError()
Empty file added torchyolo/models/yolov5.py
Empty file.
Empty file added torchyolo/models/yolov6.py
Empty file.
Empty file added torchyolo/models/yolov7.py
Empty file.
Empty file added torchyolo/models/yolox.py
Empty file.
Empty file added torchyolo/utils/__init__.py
Empty file.
43 changes: 43 additions & 0 deletions torchyolo/utils/config_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import os

import yaml
from easydict import EasyDict as edict


class YamlParser(edict):
"""
This is yaml parser based on EasyDict.
"""

def __init__(self, cfg_dict=None, config_file=None):
if cfg_dict is None:
cfg_dict = {}

if config_file is not None:
assert os.path.isfile(config_file)
with open(config_file, "r") as fo:
yaml_ = yaml.load(fo.read(), Loader=yaml.FullLoader)
cfg_dict.update(yaml_)

super(YamlParser, self).__init__(cfg_dict)

def merge_from_file(self, config_file):
with open(config_file, "r") as fo:
yaml_ = yaml.load(fo.read(), Loader=yaml.FullLoader)
self.update(yaml_)

def merge_from_dict(self, config_dict):
self.update(config_dict)


def get_config(config_file: str = None) -> YamlParser:
"""
This function is used to load config.
Args:
config_file: config file path
Returns:
config: config
"""
config = YamlParser(config_file=config_file)
config.merge_from_file(config_file)
return config
12 changes: 12 additions & 0 deletions torchyolo/utils/file_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
def create_dir(_dir) -> str:
"""
Create directory if it doesn't exist
Args:
_dir: str
"""
import os

if not os.path.exists(_dir):
os.makedirs(_dir)

return _dir
97 changes: 97 additions & 0 deletions torchyolo/utils/vis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import cv2
import numpy as np


class Colors:
# Ultralytics color palette https://ultralytics.com/
def __init__(self):
# hex = matplotlib.colors.TABLEAU_COLORS.values()
hexs = (
"FF3838",
"FF9D97",
"FF701F",
"FFB21D",
"CFD231",
"48F90A",
"92CC17",
"3DDB86",
"1A9334",
"00D4BB",
"2C99A8",
"00C2FF",
"344593",
"6473FF",
"0018EC",
"8438FF",
"520085",
"CB38FF",
"FF95C8",
"FF37C7",
)
self.palette = [self.hex2rgb(f"#{c}") for c in hexs]
self.n = len(self.palette)

def __call__(self, i, bgr=False):
c = self.palette[int(i) % self.n]
return (c[2], c[1], c[0]) if bgr else c

@staticmethod
def hex2rgb(h): # rgb order (PIL)
return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))


def tracker_vis(
track_id,
label,
frame,
tracker_box,
) -> np.ndarray:
x, y, w, h = int(tracker_box[0]), int(tracker_box[1]), int(tracker_box[2]), int(tracker_box[3])
MIN_FONT_SCALE = 0.7
colors = Colors() # create instance for 'from yolov5.utils.plots import colors'
color = colors(track_id % 10)
txt_color = (0, 0, 0) if np.mean(color) > 0.5 else (255, 255, 255)
font_scale = max(MIN_FONT_SCALE, 0.3 * (w + h) / 600)
thickness = 2
txt_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)[0]
cv2.rectangle(frame, (x, y), (w, h), color, thickness) # object box
cv2.rectangle(frame, (x, y - txt_size[1]), (x + txt_size[0], y), color, -1) # object label box
cv2.putText(
frame,
label,
(x, y - 2),
cv2.FONT_HERSHEY_SIMPLEX,
font_scale,
txt_color,
thickness,
cv2.LINE_AA,
)


def create_video_writer(video_path, output_path, fps=None) -> cv2.VideoWriter:
"""
This function is used to create video writer.
Args:
video_path: video path
output_path: output path
fps: fps
Returns:
video writer
"""
from pathlib import Path

from file_utils import create_dir

save_dir = create_dir(output_path)
save_path = str(Path(save_dir) / Path(video_path).name)
if fps is None:
cap = cv2.VideoCapture(video_path)
fps = cap.get(cv2.CAP_PROP_FPS)

cap = cv2.VideoCapture(video_path)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
size = (width, height)
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
videoWriter = cv2.VideoWriter(save_path, fourcc, fps, size)
return videoWriter

0 comments on commit bc23816

Please sign in to comment.