-
Notifications
You must be signed in to change notification settings - Fork 0
/
spatial_attension_model.py
110 lines (86 loc) · 5 KB
/
spatial_attension_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv3D, MaxPooling3D, concatenate, Conv3DTranspose, BatchNormalization, Dropout, Lambda, Add, Activation, multiply, Reshape
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import MeanIoU
kernel_initializer = 'he_uniform' #Try others if you want
from tensorflow.keras import backend as K
# Channel-wise attention mechanism
import tensorflow as tf
def channel_attention(feature_map):
# Compute the mean and max pooling along the channel dimension
mean_pool = tf.keras.layers.Lambda(lambda x: tf.keras.backend.mean(x, axis=(1,2,3), keepdims=True))(feature_map)
max_pool = tf.keras.layers.Lambda(lambda x: tf.keras.backend.max(x, axis=(1,2,3), keepdims=True))(feature_map)
# Concatenate the mean and max pooled features
concat = tf.keras.layers.Concatenate(axis=-1)([mean_pool, max_pool])
# Compute the channel attention weights using a 3D convolution
weights = tf.keras.layers.Conv3D(filters=1, kernel_size=(1,1,1), strides=(1,1,1), activation='softmax', use_bias=False)(concat)
# Apply the attention weights to the input feature map
attended_feature_map = tf.keras.layers.Multiply()([feature_map, weights])
return attended_feature_map
def improved_unet_model_with_attention(IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS, num_classes):
#Build the model
inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS))
#s = Lambda(lambda x: x / 255)(inputs) #No need for this if we normalize our inputs beforehand
s = inputs
#Contraction path
c1 = Conv3D(32, (3, 3, 3), activation='relu', kernel_initializer=kernel_initializer, padding='same')(s)
c1 = BatchNormalization()(c1)
c1 = Conv3D(32, (3, 3, 3), activation='relu', kernel_initializer=kernel_initializer, padding='same')(c1)
c1 = BatchNormalization()(c1)
p1 = MaxPooling3D((2, 2, 2))(c1)
p1 = channel_attention(p1)
c2 = Conv3D(64, (3, 3, 3), activation='relu', kernel_initializer=kernel_initializer, padding='same')(p1)
c2 = BatchNormalization()(c2)
c2 = Conv3D(64, (3, 3, 3), activation='relu', kernel_initializer=kernel_initializer, padding='same')(c2)
c2 = BatchNormalization()(c2)
p2 = MaxPooling3D((2, 2, 2))(c2)
p2 = channel_attention(p2)
c3 = Conv3D(128, (3, 3, 3), activation='relu', kernel_initializer=kernel_initializer, padding='same')(p2)
c3 = BatchNormalization()(c3)
c3 = Conv3D(128, (3, 3, 3), activation='relu', kernel_initializer=kernel_initializer, padding='same')(c3)
c3 = BatchNormalization()(c3)
p3 = MaxPooling3D((2, 2, 2))(c3)
p3 = channel_attention(p3)
c4 = Conv3D(256, (3, 3, 3), activation='relu', kernel_initializer=kernel_initializer, padding='same')(p3)
c4 = BatchNormalization()(c4)
c4 = Conv3D(256, (3, 3, 3), activation='relu', kernel_initializer=kernel_initializer, padding='same')(c4)
c4 = BatchNormalization()(c4)
p4 = MaxPooling3D(pool_size=(2, 2, 2))(c4)
p4 = channel_attention(p4)
#Expansion path
u6 = Conv3DTranspose(128, (2, 2, 2), strides=(2, 2, 2), padding='same')(p4)
u6 = concatenate([u6, c4])
u6 = channel_attention(u6)
c6 = Conv3D(256, (3, 3, 3), activation='relu', kernel_initializer=kernel_initializer, padding='same')(u6)
c6 = BatchNormalization()(c6)
c6 = Conv3D(256, (3, 3, 3), activation='relu', kernel_initializer=kernel_initializer, padding='same')(c6)
c6 = BatchNormalization()(c6)
u7 = Conv3DTranspose(64, (2, 2, 2), strides=(2, 2, 2), padding='same')(c6)
u7 = concatenate([u7, c3])
u7 = channel_attention(u7)
c7 = Conv3D(128, (3, 3, 3), activation='relu', kernel_initializer=kernel_initializer, padding='same')(u7)
c7 = BatchNormalization()(c7)
c7 = Conv3D(128, (3, 3, 3), activation='relu', kernel_initializer=kernel_initializer, padding='same')(c7)
c7 = BatchNormalization()(c7)
u8 = Conv3DTranspose(32, (2, 2, 2), strides=(2, 2, 2), padding='same')(c7)
u8 = concatenate([u8, c2])
u8 = channel_attention(u8)
c8 = Conv3D(64, (3, 3, 3), activation='relu', kernel_initializer=kernel_initializer, padding='same')(u8)
c8 = BatchNormalization()(c8)
c8 = Conv3D(64, (3, 3, 3), activation='relu', kernel_initializer=kernel_initializer, padding='same')(c8)
c8 = BatchNormalization()(c8)
u9 = Conv3DTranspose(16, (2, 2, 2), strides=(2, 2, 2), padding='same')(c8)
u9 = concatenate([u9, c1])
u9 = channel_attention(u9)
c9 = Conv3D(32, (3, 3, 3), activation='relu', kernel_initializer=kernel_initializer, padding='same')(u9)
c9 = BatchNormalization()(c9)
c9 = Conv3D(32, (3, 3, 3), activation='relu', kernel_initializer=kernel_initializer, padding='same')(c9)
c9 = BatchNormalization()(c9)
outputs = Conv3D(num_classes, (1, 1, 1), activation='softmax')(c9)
model = Model(inputs=[inputs], outputs=[outputs])
model.summary()
return model
#Test if everything is working ok.
model = improved_unet_model_with_attention(128, 128, 128, 3, 4)
print(model.input_shape)
print(model.output_shape)