forked from buyizhiyou/NRVQA
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbrisque.py
125 lines (99 loc) · 3.63 KB
/
brisque.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import math
import scipy.special
import numpy as np
import cv2
import scipy as sp
# Lookup table used for moment-matching estimation of (asymmetric) generalized
# Gaussian shape parameters: candidate shapes on a 0.001 grid over [0.2, 10).
gamma_range = np.arange(0.2, 10, 0.001)
b = scipy.special.gamma(1.0/gamma_range)
c = scipy.special.gamma(3.0/gamma_range)
a = np.square(scipy.special.gamma(2.0/gamma_range))
# prec_gammas[k] = Gamma(2/g)**2 / (Gamma(1/g) * Gamma(3/g)) with g = gamma_range[k]
prec_gammas = a/(b*c)
def aggd_features(imdata):
    """Fit an asymmetric generalized Gaussian distribution (AGGD) to *imdata*.

    The shape parameter is chosen by moment matching: the normalized ratio
    ``mean(|x|)**2 / mean(x**2)`` is compared against the module-level lookup
    table ``prec_gammas`` and the closest entry of ``gamma_range`` is taken.

    Args:
        imdata: array-like of samples of any shape; it is flattened.
            The caller's array is not modified.

    Returns:
        Tuple ``(alpha, N, bl, br, left_mean_sqrt, right_mean_sqrt)``:
        shape parameter, mean parameter, left and right scale parameters, and
        the root-mean-square of the negative and of the non-negative samples
        (each RMS is 0 when that side has no samples).
    """
    # Flatten without assigning to .shape, which reshaped the caller's array in
    # place and raises for non-contiguous inputs.
    imdata = np.asarray(imdata).ravel()
    imdata2 = imdata * imdata
    left_data = imdata2[imdata < 0]
    right_data = imdata2[imdata >= 0]

    left_mean_sqrt = np.sqrt(np.mean(left_data)) if left_data.size > 0 else 0
    right_mean_sqrt = np.sqrt(np.mean(right_data)) if right_data.size > 0 else 0

    if right_mean_sqrt != 0:
        gamma_hat = left_mean_sqrt / right_mean_sqrt
    else:
        gamma_hat = np.inf

    # Solve r-hat (reuse the mean of squares instead of recomputing it).
    imdata2_mean = np.mean(imdata2)
    if imdata2_mean != 0:
        r_hat = np.mean(np.abs(imdata)) ** 2 / imdata2_mean
    else:
        r_hat = np.inf

    if np.isinf(gamma_hat):
        # The correction factor (g^3+1)(g+1)/(g^2+1)^2 tends to 1 as g -> inf.
        # Evaluating it directly gives inf/inf = nan, which made argmin below
        # silently return the first table entry.
        rhat_norm = r_hat
    else:
        rhat_norm = r_hat * (((math.pow(gamma_hat, 3) + 1) *
                              (gamma_hat + 1)) / math.pow(math.pow(gamma_hat, 2) + 1, 2))

    # Pick the tabulated shape whose ratio is closest to rhat_norm.
    pos = np.argmin((prec_gammas - rhat_norm) ** 2)
    alpha = gamma_range[pos]

    gam1 = scipy.special.gamma(1.0 / alpha)
    gam2 = scipy.special.gamma(2.0 / alpha)
    gam3 = scipy.special.gamma(3.0 / alpha)

    aggdratio = np.sqrt(gam1) / np.sqrt(gam3)
    bl = aggdratio * left_mean_sqrt
    br = aggdratio * right_mean_sqrt

    # Mean parameter of the fitted AGGD.
    N = (br - bl) * (gam2 / gam1)

    return (alpha, N, bl, br, left_mean_sqrt, right_mean_sqrt)
def ggd_features(imdata):
    """Estimate generalized Gaussian (GGD) parameters of *imdata* by moment matching.

    Returns:
        Tuple ``(shape, sigma_sq)``. *sigma_sq* is ``np.var(imdata)``. *shape*
        is the entry of ``gamma_range`` whose tabulated ratio
        ``1 / prec_gammas`` is closest (absolute difference) to
        ``sigma_sq / mean(|imdata|)**2``.
    """
    # NOTE(review): np.var subtracts the sample mean. The published BRISQUE
    # reference uses the second moment about zero (mean(x**2)). The two agree
    # for zero-mean data such as MSCN coefficients -- confirm intent before
    # changing.
    variance = np.var(imdata)
    mean_abs = np.mean(np.abs(imdata))
    ratio = variance / mean_abs**2
    distances = np.abs(1/prec_gammas - ratio)
    return gamma_range[np.argmin(distances)], variance
def paired_product(new_im):
    """Multiply each coefficient by one of its four neighbours.

    Neighbours are obtained with ``np.roll``, so they wrap around the image
    borders (circular shift).

    Args:
        new_im: array whose axes 0 and 1 are shifted (expected to be a 2-D
            coefficient map, rows along axis 0).

    Returns:
        Tuple ``(H_img, V_img, D1_img, D2_img)`` of arrays shaped like
        *new_im*: products with the horizontal, vertical, main-diagonal and
        anti-diagonal neighbour, respectively.
    """
    # np.roll always returns a new array, so the defensive .copy() calls were
    # redundant full-size allocations.
    H_img = np.roll(new_im, 1, axis=1) * new_im
    V_img = np.roll(new_im, 1, axis=0) * new_im
    D1_img = np.roll(new_im, (1, 1), axis=(0, 1)) * new_im
    D2_img = np.roll(new_im, (1, -1), axis=(0, 1)) * new_im
    return (H_img, V_img, D1_img, D2_img)
def calculate_mscn(dis_image):
    """Return the mean-subtracted contrast-normalized (MSCN) coefficients of an image.

    The local mean and local standard deviation are estimated with a 7x7
    Gaussian window (sigma = 7/6). The stabilizing constant in the
    denominator is 1: ``(I - mu) / (1 + sigma)``.

    Args:
        dis_image: image array accepted by ``cv2.GaussianBlur``.

    Returns:
        float32 array shaped like *dis_image*.
    """
    # The cast matters: on an integer image (e.g. uint8), img ** 2 would wrap
    # around instead of producing the true squares.
    img = dis_image.astype(np.float32)
    window, window_sigma = (7, 7), 7/6
    mu = cv2.GaussianBlur(img, window, window_sigma)
    # abs() guards against slightly negative variances caused by rounding.
    local_var = np.abs(cv2.GaussianBlur(img**2, window, window_sigma) - mu*mu)
    return (img - mu)/(1 + np.sqrt(local_var))
# ggd_features is defined once, next to aggd_features. A byte-identical copy
# that used to sit here (silently redefining it) has been removed.
def extract_brisque_feats(mscncoefs):
    """Compute the 18 BRISQUE features of one scale from its MSCN coefficients.

    Feature order:
      * GGD shape and variance of the MSCN coefficients (2 values);
      * for each neighbour product returned by ``paired_product`` -- horizontal,
        vertical, main diagonal, anti-diagonal, in that order -- the AGGD shape,
        the mean parameter, and the left and right mean squares (4 values each).

    Args:
        mscncoefs: MSCN coefficient array (e.g. from ``calculate_mscn``).

    Returns:
        list of 18 scalars.
    """
    # ggd_features does not modify its input, so no defensive copy is needed.
    alpha_m, sigma_sq = ggd_features(mscncoefs)
    feats = [alpha_m, sigma_sq]
    # paired_product yields (H, V, D1, D2); the old inline labels had H and V swapped.
    for product in paired_product(mscncoefs):
        alpha, N, _bl, _br, lsq, rsq = aggd_features(product)
        feats.extend([alpha, N, lsq**2, rsq**2])
    return feats
def brisque(im):
    """Compute the 36-dimensional BRISQUE feature vector of an image.

    18 features are extracted from the MSCN coefficients at the original
    resolution and 18 more from the MSCN coefficients of a half-resolution
    copy.

    Args:
        im: image array accepted by ``cv2.resize`` and ``calculate_mscn``.

    Returns:
        1-D numpy array of 36 features (original scale first).
    """
    features1 = extract_brisque_feats(calculate_mscn(im))
    # NOTE(review): the MATLAB reference downsamples with imresize (bicubic by
    # default); cv2.resize defaults to bilinear -- confirm which is intended.
    low_resolution = cv2.resize(im, (0, 0), fx=0.5, fy=0.5)
    # The second scale must be MSCN-normalized too. Previously the raw
    # downsampled pixels were fed straight to the feature extractor.
    features2 = extract_brisque_feats(calculate_mscn(low_resolution))
    return np.array(features1 + features2)