-
Notifications
You must be signed in to change notification settings - Fork 0
/
ocr.py
184 lines (148 loc) · 4.2 KB
/
ocr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import logging
import os
import pickle
import cv2
import numpy as np
import requests
from PIL import Image
from skimage import feature as ft
from sklearn.externals import joblib
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
class OCR:
    """
    Character classifier: an RBF-kernel SVC fed with standardized features.

    The scaler is fitted during fit() and re-applied to every input passed to
    predict() / predict_proba(), so callers always hand in raw feature vectors.
    """

    def __init__(self):
        self.svm = SVC(kernel='rbf', random_state=0, gamma='auto', C=1.0, probability=True)
        self.sc = StandardScaler()
        # Accuracy on the held-out split of the last fit(); None before any fit.
        self.score_ = None

    def fit(self, X, y, test_size=0.3, sample_weight=None):
        """
        Train on a split of (X, y) and record accuracy on the held-out part.

        :param X: feature vectors, one entry per sample
        :param y: labels aligned with X
        :param test_size: forwarded to train_test_split
        :param sample_weight: optional per-sample weights aligned with X
        """
        if sample_weight is None:
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=test_size, random_state=0)
            weight_train = None
        else:
            # The weights describe all of X, but the SVM only sees the training
            # subset: split them with the same shuffle so they stay aligned
            # (passing the full-length array fails with a length mismatch).
            X_train, X_test, y_train, y_test, weight_train, _ = train_test_split(
                X, y, sample_weight, test_size=test_size, random_state=0)

        # Fit the scaler on the training subset only, then apply it to both.
        self.sc.fit(X_train)
        X_train_std = self.sc.transform(X_train)
        X_test_std = self.sc.transform(X_test)

        self.svm.fit(X_train_std, y_train, sample_weight=weight_train)
        y_pred = self.svm.predict(X_test_std)
        self.score_ = accuracy_score(y_test, y_pred)

    def _transform(self, X):
        """Standardize X with the scaler fitted in fit()."""
        return self.sc.transform(X)

    def predict(self, X):
        """Return the predicted label for each raw feature vector in X."""
        return self.svm.predict(self._transform(X))

    def predict_proba(self, X):
        """Return the SVM's class probability estimates for each vector in X."""
        return self.svm.predict_proba(self._transform(X))

    def score(self):
        """Return the held-out accuracy recorded by the last fit(), or None."""
        return self.score_
def del_blur(img):
    """
    Suppress the interference in a captcha image and binarize it.

    :param img: OpenCV image array, e.g. the result of
        cv2.imread('temp.png', 0) (flag 0 loads the file as grayscale)
    :return: the denoised, binarized image
    """
    # Bilateral filter first: smooths noise while keeping edges.
    smoothed = cv2.bilateralFilter(img, 9, 75, 75)
    # Otsu's method chooses the threshold itself, so the 0 passed as the
    # threshold argument is only a placeholder.
    binary = cv2.threshold(smoothed, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
    # Second bilateral pass over the binarized image.
    return cv2.bilateralFilter(binary, 9, 75, 75)
def split_img(img):
    """
    Cut the captcha image into its four character tiles.

    :param img: 2-D image array indexed as [row, column]
    :return: tuple of four slices, ordered left to right; each is 14 rows by
        11 columns when the image is at least 17 rows by 44 columns
    """
    top, height, width = 3, 14, 11
    rows = slice(top, top + height)
    # Left edge of every character. Tiles are 11 wide but start only 10 apart,
    # so neighbouring tiles share one column.
    lefts = (3, 13, 23, 33)
    return tuple(img[rows, left:left + width] for left in lefts)
def hog(img):
    """
    Extract the HOG feature vector of an image.

    :param img: path of an image file, or a numpy.ndarray holding the image
    :return: the HOG descriptor computed by skimage.feature.hog
    :raises ValueError: if img is neither a str nor a numpy.ndarray
    """
    # Reject unsupported inputs up front.
    if not isinstance(img, (str, np.ndarray)):
        raise ValueError('the input img is not fill the bill! \n '
                         'it must be an image file path or numpy.ndarray')

    # Both accepted input kinds are normalised to a PIL image first.
    picture = Image.open(img) if isinstance(img, str) else Image.fromarray(img)
    return ft.hog(
        picture,
        block_norm='L2-Hys',
        pixels_per_cell=(2, 2),
        cells_per_block=(2, 2),
    )
def load_data(relaod=False):
    """
    Load the training features and labels.

    The pickle caches (x_train.pickle / y_train.pickle in the working
    directory) are used when both exist and can be read. Otherwise every image
    under data/<label>/ is converted with hog() and the caches are rewritten.

    :param relaod: when truthy, ignore the caches and rebuild from data/
        (the spelling is kept so existing keyword callers keep working)
    :return: (X_train, y_train) - list of feature vectors and the list of
        labels, index-aligned
    """
    x_path = 'x_train.pickle'
    y_path = 'y_train.pickle'

    if not relaod and os.path.exists(x_path) and os.path.exists(y_path):
        # NOTE: unpickling can execute arbitrary code, so these cache files
        # must only ever be ones this function wrote itself.
        try:
            with open(x_path, 'rb') as f:
                cached_x = pickle.load(f)
            with open(y_path, 'rb') as f:
                cached_y = pickle.load(f)
        except (EOFError, pickle.UnpicklingError, AttributeError,
                ImportError, IndexError, OSError) as err:
            # A truncated or corrupt cache must not yield empty or mismatched
            # training data; fall through and rebuild it from the images.
            logging.warning('training cache unreadable (%s); rebuilding', err)
        else:
            logging.debug('load by pickle')
            return cached_x, cached_y

    X_train = []
    y_train = []
    data_dir = os.path.abspath('data')
    for label in os.listdir(data_dir):
        label_path = os.path.join(data_dir, label)
        if not os.path.isdir(label_path):
            # Stray files next to the label folders are not training classes.
            continue
        for item_name in os.listdir(label_path):
            X_train.append(hog(os.path.join(label_path, item_name)))
            y_train.append(label)

    with open(x_path, 'wb') as f:
        pickle.dump(X_train, f)
    with open(y_path, 'wb') as f:
        pickle.dump(y_train, f)
    logging.debug('load by file')
    return X_train, y_train
def load_ocr(reload=False):
    """
    Return a trained OCR model, reusing the saved one when there is one.

    :param reload: when truthy, ignore any saved model and train a new one
    :return: the OCR instance
    """
    model_file = 'wust.ocr'

    if reload or not os.path.exists(model_file):
        # No saved model to reuse: train on the training data and persist it.
        model = OCR()
        features, labels = load_data()
        model.fit(features, labels)
        joblib.dump(model, model_file)
        return model

    return joblib.load(model_file)
def main():
    """
    Download one captcha, recognise it and save it under its predicted text.

    The raw downloaded bytes are written to verifycode/<prediction>.png.

    :raises requests.RequestException: if the download fails, times out or
        returns an HTTP error status
    :raises ValueError: if the response body cannot be decoded as an image
    """
    ocr = load_ocr()
    url = 'http://jwxt.wust.edu.cn/whkjdx/verifycode.servlet?0.12337475696465894'
    # A timeout keeps a stalled server from hanging the script forever.
    response = requests.get(url, timeout=10)
    # Fail early on an HTTP error instead of feeding an error page to OpenCV.
    response.raise_for_status()
    con = response.content

    # Decode straight from memory as grayscale; this avoids the fixed-name
    # temp file, which was left behind on failure and could be clobbered.
    img = cv2.imdecode(np.frombuffer(con, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise ValueError('the downloaded content is not a decodable image')

    img_arr = del_blur(img)
    hog_arr = [hog(x) for x in split_img(img_arr)]
    pred = ocr.predict(hog_arr)
    name = ''.join(pred) + '.png'

    # The output folder may not exist on a fresh checkout.
    os.makedirs('verifycode', exist_ok=True)
    with open(os.path.join('verifycode', name), 'wb') as f:
        f.write(con)


if __name__ == '__main__':
    main()