-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy path se_block.py
25 lines (22 loc) · 1.17 KB
/
se_block.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# --------------------------------------------------------
# Re-parameterizing Your Optimizers rather than Architectures (https://arxiv.org/abs/2205.15242)
# Github source: https://github.com/DingXiaoH/RepOptimizers
# Licensed under The MIT License [see LICENSE for details]
# The training script is based on the code of Swin Transformer (https://github.com/microsoft/Swin-Transformer)
# --------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F
class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel-gating block.

    Globally average-pools the input to one value per channel, passes that
    through a 1x1-conv bottleneck (``input_channels -> internal_neurons ->
    input_channels``) with a ReLU in between, squashes the result with a
    sigmoid, and multiplies each input channel by its gate value.

    Args:
        input_channels: Number of channels of the input (and output) tensor.
        internal_neurons: Number of channels in the bottleneck between the
            two 1x1 convolutions.
    """

    def __init__(self, input_channels, internal_neurons):
        super().__init__()
        self.down = nn.Conv2d(in_channels=input_channels, out_channels=internal_neurons,
                              kernel_size=1, stride=1, bias=True)
        self.up = nn.Conv2d(in_channels=internal_neurons, out_channels=input_channels,
                            kernel_size=1, stride=1, bias=True)
        # Not needed by forward(); retained so existing code reading it keeps working.
        self.input_channels = input_channels

    def forward(self, inputs):
        """Return ``inputs`` with every channel scaled by a sigmoid gate.

        Args:
            inputs: Tensor shaped ``(N, C, H, W)`` or unbatched ``(C, H, W)``,
                with ``C == input_channels``.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        # Squeeze: collapse the spatial dims to 1x1, one descriptor per channel.
        gate = F.adaptive_avg_pool2d(inputs, output_size=(1, 1))
        # Excite: bottleneck MLP implemented with 1x1 convolutions.
        gate = F.relu(self.down(gate), inplace=True)
        # Tensor.sigmoid() instead of the deprecated F.sigmoid alias.
        gate = self.up(gate).sigmoid()
        # gate is already (..., C, 1, 1), so it broadcasts over H and W as-is.
        # The former view(-1, C, 1, 1) was a no-op for batched input and
        # wrongly prepended a batch dim for unbatched (C, H, W) input.
        return inputs * gate