from keras.layers import Activation, Conv3D, Layer
import keras.backend as K
import tensorflow as tf
# from tensorflow.keras.layers import Input, Layer, Conv3D, Activation
# import tensorflow.keras.backend as K

tf.config.run_functions_eagerly(True)

class PAM(Layer):
    """Position attention module (PAM): spatial self-attention over a 5-D
    feature map of shape (batch, h, w, z, filters). The attended features are
    scaled by a learnable scalar `gamma` (initialised to zero) and added back
    to the input, so the block starts out as an identity mapping."""

    def __init__(self,
                 gamma_initializer=tf.zeros_initializer(),
                 gamma_regularizer=None,
                 gamma_constraint=None,
                 **kwargs):
        super(PAM, self).__init__(**kwargs)
        self.gamma_initializer = gamma_initializer
        self.gamma_regularizer = gamma_regularizer
        self.gamma_constraint = gamma_constraint

    def build(self, input_shape):
        # Single learnable scalar that weights the attention branch.
        self.gamma = self.add_weight(shape=(1,),
                                     initializer=self.gamma_initializer,
                                     name='gamma',
                                     regularizer=self.gamma_regularizer,
                                     constraint=self.gamma_constraint)
        self.built = True

    def compute_output_shape(self, input_shape):
        return input_shape

    def call(self, inputs):
        _, h, w, z, filters = inputs.get_shape().as_list()

        # Query, key and value projections (1x1x1 convolutions). Note that
        # these layers are instantiated inside call(), so new layer objects
        # (and weights) are created each time call() runs.
        b = Conv3D(filters // 8, 1, use_bias=False, kernel_initializer='he_normal')(inputs)
        c = Conv3D(filters // 8, 1, use_bias=False, kernel_initializer='he_normal')(inputs)
        d = Conv3D(filters, 1, use_bias=False, kernel_initializer='he_normal')(inputs)

        # Flatten the spatial dimensions and compute the (h*w*z) x (h*w*z)
        # position-affinity matrix.
        vec_b = K.reshape(b, (-1, h * w * z, filters // 8))
        vec_cT = tf.transpose(K.reshape(c, (-1, h * w * z, filters // 8)), (0, 2, 1))
        bcT = K.batch_dot(vec_b, vec_cT)
        softmax_bcT = Activation('softmax')(bcT)

        # Aggregate the value features with the attention weights and restore
        # the original 5-D shape.
        vec_d = K.reshape(d, (-1, h * w * z, filters))
        bcTd = K.batch_dot(softmax_bcT, vec_d)
        bcTd = K.reshape(bcTd, (-1, h, w, z, filters))

        return self.gamma * bcTd + inputs

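
# --- Illustrative usage sketch (not part of the original file) --------------
# A minimal eager check that PAM preserves its input shape. The batch size,
# volume size, and channel count below are arbitrary assumptions for the demo.
def _pam_demo():
    x = tf.random.uniform((1, 8, 8, 8, 16))
    y = PAM()(x)        # spatial self-attention over the 8*8*8 positions
    return y.shape      # expected: (1, 8, 8, 8, 16)
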
class CAM(Layer):
    """Channel attention module (CAM): channel-wise self-attention over a 5-D
    feature map of shape (batch, h, w, z, filters). As in PAM, the attended
    features are scaled by a learnable scalar `gamma` (initialised to zero)
    and added back to the input."""

    def __init__(self,
                 gamma_initializer=tf.zeros_initializer(),
                 gamma_regularizer=None,
                 gamma_constraint=None,
                 **kwargs):
        super(CAM, self).__init__(**kwargs)
        self.gamma_initializer = gamma_initializer
        self.gamma_regularizer = gamma_regularizer
        self.gamma_constraint = gamma_constraint

    def build(self, input_shape):
        # Single learnable scalar that weights the attention branch.
        self.gamma = self.add_weight(shape=(1,),
                                     initializer=self.gamma_initializer,
                                     name='gamma',
                                     regularizer=self.gamma_regularizer,
                                     constraint=self.gamma_constraint)
        self.built = True

    def compute_output_shape(self, input_shape):
        return input_shape

    def call(self, inputs):
        _, h, w, z, filters = inputs.get_shape().as_list()

        # Flatten the spatial dimensions and compute the filters x filters
        # channel-affinity matrix.
        vec_a = K.reshape(inputs, (-1, h * w * z, filters))
        vec_aT = tf.transpose(vec_a, (0, 2, 1))
        aTa = K.batch_dot(vec_aT, vec_a)
        softmax_aTa = Activation('softmax')(aTa)

        # Re-weight the channels with the attention map and restore the
        # original 5-D shape.
        aaTa = K.batch_dot(vec_a, softmax_aTa)
        aaTa = K.reshape(aaTa, (-1, h, w, z, filters))

        return self.gamma * aaTa + inputs
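
# --- Illustrative smoke test (not part of the original file) ----------------
# Runs both attention blocks eagerly on the same random feature map and
# checks that each preserves the input shape. Summing the two branches is a
# common DANet-style combination and is shown here only as an assumption,
# not as the original model definition.
if __name__ == "__main__":
    feat = tf.random.uniform((2, 8, 8, 8, 16))
    pos_out = PAM()(feat)   # position (spatial) attention branch
    chn_out = CAM()(feat)   # channel attention branch
    fused = pos_out + chn_out
    print(pos_out.shape, chn_out.shape, fused.shape)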