forked from foamliu/InsightFace-PyTorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_gen.py
148 lines (117 loc) · 4.06 KB
/
data_gen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import os
import pickle
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from config import IMG_DIR
from config import pickle_file
from PyQt5.QtGui import *
from PyQt5.QtCore import *
from PyQt5.QtWidgets import *
import sys
import re
# Preprocessing pipelines, looked up by split name.
#   'train': random horizontal flip and colour jitter (augmentation), followed
#            by tensor conversion and per-channel normalization.
#   'val':   tensor conversion and per-channel normalization only.
# NOTE(review): the mean/std values look like the usual ImageNet channel
# statistics -- confirm before changing them.
data_transforms = dict(
    train=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.125, contrast=0.125, saturation=0.125),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    val=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
)
class ArcFaceDataset(Dataset):
    """Dataset backed by the pickled sample index stored in ``pickle_file``.

    Every entry of the unpickled sequence is read as a mapping with an
    ``'img'`` key (file name relative to ``IMG_DIR``) and a ``'label'`` key.
    """

    def __init__(self, split):
        """Load the sample index and pick the transform pipeline for ``split``.

        Args:
            split: Split name such as ``'train'`` or ``'val'``. A name with no
                entry in ``data_transforms`` falls back to the ``'train'``
                pipeline, which is what every split used to receive.
        """
        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file; only point pickle_file at data produced by a trusted step.
        with open(pickle_file, 'rb') as file:
            data = pickle.load(file)

        self.split = split
        self.samples = data
        # Honour the requested split so that validation data is not randomly
        # flipped / colour-jittered.
        self.transformer = data_transforms.get(split, data_transforms['train'])

    def __getitem__(self, i):
        """Return ``(transformed_image, label)`` for the sample at index ``i``."""
        sample = self.samples[i]
        label = sample['label']
        filename = os.path.join(IMG_DIR, sample['img'])

        # Force three channels: the Normalize step is configured with three
        # mean/std values and fails on grayscale or RGBA input.  The ``with``
        # block releases the file handle once the pixels are decoded.
        with Image.open(filename) as raw:
            img = raw.convert('RGB')

        img = self.transformer(img)
        return img, label

    def __len__(self):
        """Return the number of samples in the index."""
        return len(self.samples)
import os
import pickle
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from config import IMG_DIR
from config import pickle_file
# Preprocessing pipelines, looked up by split name.
#   'train': random horizontal flip and colour jitter (augmentation), followed
#            by tensor conversion and per-channel normalization.
#   'val':   tensor conversion and per-channel normalization only.
# NOTE(review): the mean/std values look like the usual ImageNet channel
# statistics -- confirm before changing them.
data_transforms = dict(
    train=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.125, contrast=0.125, saturation=0.125),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    val=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
)
# Encoding used when byte strings / QStrings are decoded on Python 2.
DEFAULT_ENCODING = 'utf-8'


def ustr(x):
    """Return ``x`` as unicode text on Python 2; return it unchanged on Python 3.

    On Python 2 a byte ``str`` is decoded with ``DEFAULT_ENCODING`` and a PyQt4
    ``QString`` is converted through its UTF-8 bytes (undecodable bytes are
    ignored).  Any other value is passed through untouched.
    """
    # Python 3: nothing to convert.
    if sys.version_info >= (3, 0, 0):
        return x

    # Everything below only runs on Python 2, where PyQt4's QString and the
    # ``unicode`` builtin exist.
    from PyQt4.QtCore import QString

    value_type = type(x)
    if value_type == str:
        return x.decode(DEFAULT_ENCODING)
    if value_type == QString:
        # https://blog.csdn.net/friendan/article/details/51088476
        # https://blog.csdn.net/xxm524/article/details/74937308
        return unicode(x.toUtf8(), DEFAULT_ENCODING, 'ignore')
    return x
def natural_sort(list, key=lambda s: s):
    """Sort ``list`` in place into natural (human) alphanumeric order.

    Runs of ASCII digits are compared by numeric value instead of character by
    character, so ``"img2"`` sorts before ``"img10"``.

    Args:
        list: Mutable sequence to sort in place.
        key: Callable mapping each element to the string that is compared.

    Returns:
        None; the input sequence is reordered in place.
    """
    def alphanum_key(item):
        # The capturing group makes re.split keep the digit runs, so the
        # result alternates text chunk, digit chunk, text chunk, ...
        chunks = re.split('([0-9]+)', key(item))
        return [int(chunk) if chunk.isdigit() else chunk for chunk in chunks]

    list.sort(key=alphanum_key)
def scanAllImages(folderPath):
    """Recursively collect the image files found under ``folderPath``.

    A file counts as an image when its lower-cased name ends with one of the
    formats reported by ``QImageReader.supportedImageFormats()``.

    Args:
        folderPath: Root directory handed to ``os.walk``.

    Returns:
        List of absolute paths (passed through ``ustr``), in natural
        alphanumeric order compared case-insensitively.
    """
    # str.endswith accepts a tuple of suffixes, so build it once up front.
    suffixes = tuple(
        '.%s' % fmt.data().decode("ascii").lower()
        for fmt in QImageReader.supportedImageFormats()
    )

    images = [
        ustr(os.path.abspath(os.path.join(root, name)))
        for root, _dirs, names in os.walk(folderPath)
        for name in names
        if name.lower().endswith(suffixes)
    ]

    natural_sort(images, key=lambda x: x.lower())
    return images
class ArcFaceDataset(Dataset):
    """Dataset built by scanning ``IMG_DIR`` recursively for image files.

    The label of an image is the name of the directory that directly contains
    it, parsed as an integer.
    """

    def __init__(self, split):
        """Scan ``IMG_DIR`` and pick the transform pipeline for ``split``.

        Args:
            split: Split name such as ``'train'`` or ``'val'``. A name with no
                entry in ``data_transforms`` falls back to the ``'train'``
                pipeline, which is what every split used to receive.
        """
        self.split = split
        # Honour the requested split so that validation data is not randomly
        # flipped / colour-jittered.
        self.transformer = data_transforms.get(split, data_transforms['train'])
        self.images_path = scanAllImages(IMG_DIR)

    def __getitem__(self, i):
        """Return ``(transformed_image, label)`` for the image at index ``i``.

        Raises:
            ValueError: If the image's parent directory name is not an integer.
        """
        path = self.images_path[i]
        # The class id is encoded in the name of the containing directory.
        label = int(os.path.basename(os.path.dirname(path)))

        # The scan accepts every format Qt can read, so grayscale / RGBA /
        # palette files may show up; force three channels to match the
        # three-value Normalize step.  The ``with`` block releases the file
        # handle once the pixels are decoded.
        with Image.open(path) as raw:
            img = raw.convert('RGB')

        img = self.transformer(img)
        return img, label

    def __len__(self):
        """Return the number of images found under ``IMG_DIR``."""
        return len(self.images_path)