Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

YOLOv5 + Albumentations integration #3882

Merged
merged 10 commits into from
Jul 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ pandas
# extras --------------------------------------
# Cython # for pycocotools https://github.com/cocodataset/cocoapi/issues/172
# pycocotools>=2.0 # COCO mAP
# albumentations>=1.0.0
thop # FLOPs computation
30 changes: 29 additions & 1 deletion utils/augmentations.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,43 @@
# YOLOv5 image augmentation functions

import logging
import random

import cv2
import math
import numpy as np

from utils.general import segment2box, resample_segments
from utils.general import colorstr, segment2box, resample_segments, check_version
from utils.metrics import bbox_ioa


class Albumentations:
# YOLOv5 Albumentations class (optional, only used if package is installed)
def __init__(self):
self.transform = None
try:
import albumentations as A
check_version(A.__version__, '1.0.0') # version requirement

self.transform = A.Compose([
A.Blur(p=0.1),
A.MedianBlur(p=0.1),
A.ToGray(p=0.01)],
bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))

logging.info(colorstr('albumentations: ') + ', '.join(f'{x}' for x in self.transform.transforms))
except ImportError: # package not installed, skip
pass
except Exception as e:
logging.info(colorstr('albumentations: ') + f'{e}')

def __call__(self, im, labels, p=1.0):
if self.transform and random.random() < p:
new = self.transform(image=im, bboxes=labels[:, 1:], class_labels=labels[:, 0]) # transformed
im, labels = new['image'], np.array([[c, *b] for c, b in zip(new['class_labels'], new['bboxes'])])
return im, labels


def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
# HSV color-space augmentation
if hgain or sgain or vgain:
Expand Down
40 changes: 21 additions & 19 deletions utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from torch.utils.data import Dataset
from tqdm import tqdm

from utils.augmentations import augment_hsv, copy_paste, letterbox, mixup, random_perspective
from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
from utils.general import check_requirements, check_file, check_dataset, xywh2xyxy, xywhn2xyxy, xyxy2xywhn, \
xyn2xy, segments2boxes, clean_str
from utils.torch_utils import torch_distributed_zero_first
Expand Down Expand Up @@ -372,6 +372,7 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
self.mosaic_border = [-img_size // 2, -img_size // 2]
self.stride = stride
self.path = path
self.albumentations = Albumentations() if augment else None

try:
f = [] # image files
Expand Down Expand Up @@ -539,42 +540,43 @@ def __getitem__(self, index):
if labels.size: # normalized xywh to pixel xyxy format
labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])

if self.augment:
# Augment imagespace
if not mosaic:
if self.augment:
img, labels = random_perspective(img, labels,
degrees=hyp['degrees'],
translate=hyp['translate'],
scale=hyp['scale'],
shear=hyp['shear'],
perspective=hyp['perspective'])

# Augment colorspace
augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

# Apply cutouts
# if random.random() < 0.9:
# labels = cutout(img, labels)

nL = len(labels) # number of labels
if nL:
nl = len(labels) # number of labels
if nl:
labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0]) # xyxy to xywh normalized

if self.augment:
# flip up-down
# Albumentations
img, labels = self.albumentations(img, labels)

# HSV color-space
augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

# Flip up-down
if random.random() < hyp['flipud']:
img = np.flipud(img)
if nL:
if nl:
labels[:, 2] = 1 - labels[:, 2]

# flip left-right
# Flip left-right
if random.random() < hyp['fliplr']:
img = np.fliplr(img)
if nL:
if nl:
labels[:, 1] = 1 - labels[:, 1]

labels_out = torch.zeros((nL, 6))
if nL:
# Cutouts
# if random.random() < 0.9:
# labels = cutout(img, labels)

labels_out = torch.zeros((nl, 6))
if nl:
labels_out[:, 1:] = torch.from_numpy(labels)

# Convert
Expand Down
17 changes: 10 additions & 7 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import contextlib
import glob
import logging
import math
import os
import platform
import random
Expand All @@ -17,6 +16,7 @@
from subprocess import check_output

import cv2
import math
import numpy as np
import pandas as pd
import pkg_resources as pkg
Expand Down Expand Up @@ -136,13 +136,16 @@ def check_git_status(err_msg=', for updates see https://github.com/ultralytics/y
print(f'{e}{err_msg}')


def check_python(minimum='3.6.2', required=True):
def check_python(minimum='3.6.2'):
# Check current python version vs. required python version
current = platform.python_version()
result = pkg.parse_version(current) >= pkg.parse_version(minimum)
if required:
assert result, f'Python {minimum} required by YOLOv5, but Python {current} is currently installed'
return result
check_version(platform.python_version(), minimum, name='Python ')


def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False):
# Check version vs. required version
current, minimum = (pkg.parse_version(x) for x in (current, minimum))
result = (current == minimum) if pinned else (current >= minimum)
assert result, f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'


def check_requirements(requirements='requirements.txt', exclude=()):
Expand Down