Replace Face Detection Library (#320)
oguzhanbsolak authored Jun 26, 2024
1 parent 00a5f68 commit e66051e
Showing 2 changed files with 20 additions and 21 deletions.
39 changes: 19 additions & 20 deletions datasets/vggface2.py
@@ -23,8 +23,8 @@
 from torchvision import transforms

 import cv2
-import face_detection
 import kornia.geometry.transform as GT
+from batch_face import RetinaFace
 from PIL import Image
 from skimage import transform as trans
 from tqdm import tqdm
@@ -38,7 +38,7 @@ class VGGFace2(Dataset):
     VGGFace2 Dataset
     """
     def __init__(self, root_dir, d_type, mode, transform=None,
-                 teacher_transform=None, img_size=(112, 112)):
+                 teacher_transform=None, img_size=(112, 112), args=None):

         if d_type not in ('test', 'train'):
             raise ValueError("d_type can only be set to 'test' or 'train'")
@@ -47,6 +47,7 @@ def __init__(self, root_dir, d_type, mode, transform=None,
             raise ValueError("mode can only be set to 'detection', 'identification',"
                              "or 'identification_dr'")

+        self.device = args.device
         self.root_dir = root_dir
         self.d_type = d_type
         self.transform = transform
@@ -99,8 +100,11 @@ def __extract_gt(self):
         """
         Extracts the ground truth from the dataset
         """
-        detector = face_detection.build_detector("RetinaNetResNet50", confidence_threshold=.5,
-                                                  nms_iou_threshold=.4)
+        if self.device == 'cuda':
+            detector = RetinaFace(gpu_id=torch.cuda.current_device(), network="resnet50")
+        else:
+            detector = RetinaFace(gpu_id=-1, network="resnet50")
+
         img_paths = list(glob.glob(os.path.join(self.d_path + '/**/', '*.jpg'), recursive=True))
         nf_number = 0
         words_count = 0
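As a quick sanity check outside the dataset class, the new construction path can be exercised on its own. The sketch below mirrors the branch added in this hunk; the `build_detector` helper name is illustrative (not part of the commit), and it assumes, as the diff implies, that `gpu_id=-1` selects CPU inference in `batch_face` while `network="resnet50"` picks the ResNet-50 backbone in place of the removed `RetinaNetResNet50` detector.

```python
import torch
from batch_face import RetinaFace


def build_detector(device: str) -> RetinaFace:
    """Illustrative helper mirroring the commit's CUDA/CPU branch."""
    if device == 'cuda':
        # Run RetinaFace on the currently selected CUDA device.
        return RetinaFace(gpu_id=torch.cuda.current_device(), network="resnet50")
    # gpu_id=-1 appears to request CPU inference.
    return RetinaFace(gpu_id=-1, network="resnet50")


detector = build_detector('cuda' if torch.cuda.is_available() else 'cpu')
```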
@@ -111,22 +115,17 @@ def __extract_gt(self):
             boxes = []
             image = cv2.imread(jpg)

-            img_max = max(image.shape[0], image.shape[1])
-            if img_max > 1320:
-                continue
-            bboxes, lndmrks = detector.batched_detect_with_landmarks(np.expand_dims(image, 0))
-            bboxes = bboxes[0]
-            lndmrks = lndmrks[0]
+            faces = detector(image)

-            if (bboxes.shape[0] == 0) or (lndmrks.shape[0] == 0):
+            if len(faces) == 0:
                 nf_number += 1
                 continue

-            for box in bboxes:
+            for face in faces:
+                box = face[0]
                 box = np.clip(box[:4], 0, None)
                 boxes.append(box)

-            lndmrks = lndmrks[0]
+            lndmrks = faces[0][1]

             dir_name = os.path.dirname(jpg)
             lbl = os.path.relpath(dir_name, self.d_path)
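For context, a minimal end-to-end sketch of the new per-image detection path. It assumes, as the indexing `face[0]` and `faces[0][1]` above suggests, that the detector returns one `(box, landmarks, score)` tuple per detected face; the image path is a placeholder.

```python
import cv2
import numpy as np
from batch_face import RetinaFace

detector = RetinaFace(gpu_id=-1, network="resnet50")  # CPU; pass a CUDA device id for GPU

image = cv2.imread("sample.jpg")  # placeholder path
faces = detector(image)           # one entry per detected face

if len(faces) == 0:
    print("no face found")
else:
    boxes = [np.clip(face[0][:4], 0, None) for face in faces]  # clamp negative coordinates
    lndmrks = faces[0][1]                                       # landmarks of the first detection
    print(len(boxes), np.asarray(lndmrks).shape)
```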
@@ -343,7 +342,7 @@ def VGGFace2_FaceID_get_datasets(data, load_train=True, load_test=True, img_size

         train_dataset = VGGFace2(root_dir=data_dir, d_type='train', mode='identification',
                                  transform=train_transform, teacher_transform=teacher_transform,
-                                 img_size=img_size)
+                                 img_size=img_size, args=args)

         print(f'Train dataset length: {len(train_dataset)}\n')
     else:
@@ -355,7 +354,7 @@ def VGGFace2_FaceID_get_datasets(data, load_train=True, load_test=True, img_size

         test_dataset = VGGFace2(root_dir=data_dir, d_type='test', mode='identification',
                                 transform=test_transform, teacher_transform=teacher_transform,
-                                img_size=img_size)
+                                img_size=img_size, args=args)

         print(f'Test dataset length: {len(test_dataset)}\n')
     else:
@@ -378,7 +377,7 @@ def VGGFace2_FaceID_dr_get_datasets(data, load_train=True, load_test=True, img_s
     if load_train:

         train_dataset = VGGFace2(root_dir=data_dir, d_type='train', mode='identification_dr',
-                                 transform=train_transform, img_size=img_size)
+                                 transform=train_transform, img_size=img_size, args=args)

         print(f'Train dataset length: {len(train_dataset)}\n')
     else:
@@ -389,7 +388,7 @@ def VGGFace2_FaceID_dr_get_datasets(data, load_train=True, load_test=True, img_s
                                              ai8x.normalize(args=args)])

         test_dataset = VGGFace2(root_dir=data_dir, d_type='test', mode='identification_dr',
-                                transform=test_transform, img_size=img_size)
+                                transform=test_transform, img_size=img_size, args=args)

         print(f'Test dataset length: {len(test_dataset)}\n')
     else:
@@ -409,7 +408,7 @@ def VGGFace2_Facedet_get_datasets(data, load_train=True, load_test=True, img_siz
                                               ai8x.normalize(args=args)])

         train_dataset = VGGFace2(root_dir=data_dir, d_type='train', mode='detection',
-                                 transform=train_transform, img_size=img_size)
+                                 transform=train_transform, img_size=img_size, args=args)

         print(f'Train dataset length: {len(train_dataset)}\n')
     else:
@@ -419,7 +418,7 @@ def VGGFace2_Facedet_get_datasets(data, load_train=True, load_test=True, img_siz
         test_transform = transforms.Compose([ai8x.normalize(args=args)])

         test_dataset = VGGFace2(root_dir=data_dir, d_type='test', mode='detection',
-                                transform=test_transform, img_size=img_size)
+                                transform=test_transform, img_size=img_size, args=args)

         print(f'Test dataset length: {len(test_dataset)}\n')
     else:
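Because `VGGFace2.__init__` now reads `args.device` (and every `*_get_datasets` wrapper forwards `args=args`), code that constructs the dataset directly needs an `args` object carrying a `device` attribute. A minimal sketch, with a hypothetical argparse namespace standing in for the training script's real arguments; the import path and data directory are placeholders:

```python
import argparse

import torch

from datasets.vggface2 import VGGFace2  # assumes the repository root is on sys.path

# Hypothetical stand-in for the training script's parsed arguments.
args = argparse.Namespace(device='cuda' if torch.cuda.is_available() else 'cpu')

# The dataset uses args.device to decide whether RetinaFace runs on GPU or CPU;
# it expects the VGGFace2 data under root_dir.
train_dataset = VGGFace2(root_dir='data/VGGFace2', d_type='train', mode='detection',
                         img_size=(112, 112), args=args)
```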
2 changes: 1 addition & 1 deletion requirements.txt
@@ -6,7 +6,7 @@ Pillow>=7
 PyYAML>=5.1.1
 albumentations>=1.3.0
 faiss-cpu==1.7.4
-face-detection==0.2.2
+batch-face>=1.4.0
 h5py>=3.7.0
 kornia==0.6.8
 librosa>=0.7.2