-
-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from kadirnar/base_config
The base config of the torchyolo library has been improved.
- Loading branch information
Showing
19 changed files
with
228 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
[tool.black] | ||
line-length = 120 | ||
|
||
[tool.isort] | ||
line_length = 120 | ||
profile = "black" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
black . --config pyproject.toml | ||
isort . |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
__version__ = '0.1.0' | ||
__version__ = "0.1.0" |
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from typing import Any, Optional | ||
|
||
import numpy as np | ||
import torch | ||
|
||
|
||
class YoloDetectionModel:
    """
    Base class for detection model wrappers.

    Stores the common configuration (weights/config paths, device, thresholds,
    input size) and defines the two hooks subclasses must implement:
    ``load_model`` and ``predict``.
    """

    def __init__(
        self,
        model_path: Optional[str] = None,
        config_path: Optional[str] = None,
        model: Optional[Any] = None,
        device: Optional[str] = None,
        confidence_threshold: float = 0.3,
        iou_threshold: float = 0.5,
        image_size: Optional[int] = None,
        load_at_init: bool = True,
    ):
        """
        Init object detection/instance segmentation model.
        Args:
            model_path: str
                Path for the model weight file
            config_path: str
                Path for the model config file
            model: Any
                Already constructed model object, stored as-is on ``self.model``
            device: str
                Torch device, "cpu" or "cuda". When None, "cuda" is selected if
                ``torch.cuda.is_available()`` and "cpu" otherwise
            iou_threshold: float
                IoU threshold stored for use by subclasses
            confidence_threshold: float
                Score threshold stored for use by subclasses
            image_size: int
                Inference input size.
            load_at_init: bool
                If True, automatically loads the model at initialization
        """
        self.model_path = model_path
        self.config_path = config_path
        self.model = model
        self.device = device
        self.iou_threshold = iou_threshold
        self.confidence_threshold = confidence_threshold
        self.image_size = image_size

        # Resolve the default device BEFORE loading the model: load_model() is
        # documented to use self.device, so it must never observe None here.
        if self.device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # automatically load model if load_at_init is True
        if load_at_init:
            self.load_model()

    def load_model(self):
        """
        This function should be implemented in a way that detection model
        should be initialized and set to self.model.
        (self.model_path, self.config_path, and self.device should be utilized)

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def predict(self, image: np.ndarray):
        """
        This function should be implemented in a way that inference is run on
        the given image using self.model.
        Args:
            image: np.ndarray
                Input image to run the detection model on

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import os | ||
|
||
import yaml | ||
from easydict import EasyDict as edict | ||
|
||
|
||
class YamlParser(edict):
    """
    This is yaml parser based on EasyDict.

    Configuration values can come from an initial dict, from a YAML file, or
    from both (values read from the file override those in the dict).
    """

    def __init__(self, cfg_dict=None, config_file=None):
        """
        Args:
            cfg_dict: optional mapping of initial config values
            config_file: optional path of a YAML file merged on top of cfg_dict
        Raises:
            FileNotFoundError: if config_file is given but is not an existing file
        """
        # Work on a copy so the caller's dict is not modified by the merge below.
        cfg_dict = {} if cfg_dict is None else dict(cfg_dict)

        if config_file is not None:
            # Explicit check instead of `assert`, which is removed under `python -O`.
            if not os.path.isfile(config_file):
                raise FileNotFoundError(f"Config file not found: {config_file}")
            with open(config_file, "r") as fo:
                # NOTE(review): FullLoader is kept to preserve behaviour; prefer
                # yaml.safe_load if config files can come from untrusted sources.
                yaml_ = yaml.load(fo.read(), Loader=yaml.FullLoader)
            # An empty YAML document parses to None; treat it as "no values".
            cfg_dict.update(yaml_ or {})

        super(YamlParser, self).__init__(cfg_dict)

    def merge_from_file(self, config_file):
        """
        Read a YAML file and merge its top-level keys into this config.
        Args:
            config_file: path of the YAML file to read
        """
        with open(config_file, "r") as fo:
            # NOTE(review): see the FullLoader remark in __init__.
            yaml_ = yaml.load(fo.read(), Loader=yaml.FullLoader)
        # An empty YAML document parses to None; merge nothing in that case.
        self.update(yaml_ or {})

    def merge_from_dict(self, config_dict):
        """
        Merge the keys of a dict into this config.
        Args:
            config_dict: mapping whose entries are merged via ``update``
        """
        self.update(config_dict)
|
||
|
||
def get_config(config_file: str = None) -> YamlParser:
    """
    This function is used to load config.
    Args:
        config_file: config file path. When None, an empty config is returned.
    Returns:
        config: YamlParser holding the values read from config_file
    """
    # The constructor already reads config_file; merging the same file a second
    # time only re-parsed it (and failed on the default config_file=None).
    return YamlParser(config_file=config_file)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
def create_dir(_dir) -> str:
    """
    Create directory (including missing parents) if it doesn't exist
    Args:
        _dir: str
            Path of the directory to create
    Returns:
        The same path that was passed in
    Raises:
        FileExistsError: if the path exists but is not a directory
    """
    import os

    # exist_ok=True avoids the check-then-create race of
    # `if not os.path.exists(...): os.makedirs(...)` when several
    # processes create the same output directory at the same time.
    os.makedirs(_dir, exist_ok=True)

    return _dir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import cv2 | ||
import numpy as np | ||
|
||
|
||
class Colors:
    """
    Fixed palette of 20 colours (Ultralytics color palette https://ultralytics.com/).

    Calling an instance with an index returns the matching colour as an
    (r, g, b) tuple of ints, or (b, g, r) when ``bgr=True``. Indices wrap
    around the palette length.
    """

    def __init__(self):
        hex_codes = [
            "FF3838",
            "FF9D97",
            "FF701F",
            "FFB21D",
            "CFD231",
            "48F90A",
            "92CC17",
            "3DDB86",
            "1A9334",
            "00D4BB",
            "2C99A8",
            "00C2FF",
            "344593",
            "6473FF",
            "0018EC",
            "8438FF",
            "520085",
            "CB38FF",
            "FF95C8",
            "FF37C7",
        ]
        converted = []
        for code in hex_codes:
            converted.append(self.hex2rgb("#" + code))
        self.palette = converted
        self.n = len(converted)

    def __call__(self, i, bgr=False):
        """Return palette entry ``int(i) % n``; channel order reversed if bgr."""
        rgb = self.palette[int(i) % self.n]
        if bgr:
            return rgb[::-1]
        return rgb

    @staticmethod
    def hex2rgb(h):  # rgb order (PIL)
        """Convert a '#RRGGBB' string into an (r, g, b) tuple of ints."""
        channels = []
        for start in (1, 3, 5):  # skip the leading '#', two hex digits per channel
            channels.append(int(h[start : start + 2], 16))
        return tuple(channels)
|
||
|
||
def tracker_vis(
    track_id,
    label,
    frame,
    tracker_box,
) -> np.ndarray:
    """
    Draw a tracked object's bounding box and label on a frame.
    Args:
        track_id: integer track id, used to pick the box colour
        label: text drawn above the box
        frame: image the box and label are drawn on (modified in place)
        tracker_box: sequence whose first four items are used as the two
            opposite box corners (x1, y1, x2, y2)
    Returns:
        The same frame object, after drawing
    """
    # The last two values are passed to cv2.rectangle as the second corner,
    # so they are treated as corner coordinates rather than width/height.
    x1, y1, x2, y2 = (int(tracker_box[i]) for i in range(4))
    MIN_FONT_SCALE = 0.7
    colors = Colors()  # create instance for 'from yolov5.utils.plots import colors'
    # NOTE(review): Colors returns RGB-ordered tuples; if frame is BGR
    # (as images read with OpenCV usually are) channels appear swapped -- confirm.
    color = colors(track_id % 10)
    # Palette channels are 0-255, so the brightness threshold is the mid-point
    # 127.5; comparing against 0.5 always selected black text.
    txt_color = (0, 0, 0) if np.mean(color) > 127.5 else (255, 255, 255)
    font_scale = max(MIN_FONT_SCALE, 0.3 * (x2 + y2) / 600)
    thickness = 2
    txt_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)[0]
    cv2.rectangle(frame, (x1, y1), (x2, y2), color, thickness)  # object box
    cv2.rectangle(frame, (x1, y1 - txt_size[1]), (x1 + txt_size[0], y1), color, -1)  # object label box
    cv2.putText(
        frame,
        label,
        (x1, y1 - 2),
        cv2.FONT_HERSHEY_SIMPLEX,
        font_scale,
        txt_color,
        thickness,
        cv2.LINE_AA,
    )
    # Return the annotated frame so the declared np.ndarray return type holds.
    return frame
|
||
|
||
def create_video_writer(video_path, output_path, fps=None) -> cv2.VideoWriter:
    """
    This function is used to create video writer.

    The writer targets ``output_path/<file name of video_path>``, uses the
    "mp4v" codec and the frame size reported for the input video.
    Args:
        video_path: path of the input video (read for fps and frame size)
        output_path: directory the output video is written to (created if missing)
        fps: output frame rate; when None, the input video's fps is used
    Returns:
        video writer
    """
    from pathlib import Path

    from file_utils import create_dir

    save_dir = create_dir(output_path)
    save_path = str(Path(save_dir) / Path(video_path).name)

    # Open the source once for all metadata and always release it; the capture
    # is only needed to read properties, not for the lifetime of the writer.
    cap = cv2.VideoCapture(video_path)
    try:
        if fps is None:
            fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    finally:
        cap.release()

    size = (width, height)
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    videoWriter = cv2.VideoWriter(save_path, fourcc, fps, size)
    return videoWriter