Guanwenjie Zou (邹广文杰)*, Liang Yao (姚亮)*, Fan Liu (刘凡) ✉, Chuanyi Zhang (张传一), Xin Li (李鑫), Ning Chen (陈宁), Shengxiang Xu (徐圣翔), Jun Zhou (周峻)
- 2024/12/21: The paper has been accepted to ICASSP 2025!
- 2024/07/29: We propose an efficient structural pruning method for remote sensing image classification. Code and models will be open-sourced in this repository.
For questions, please contact yaoliang@hhu.edu.cn.
Our pruning method is built on the torch-pruning framework, which supports both PyTorch 1.x and 2.x.
pip install torch-pruning
We use the intermediate outputs of SENet and the scaling factors of the BN layers to map channel importance into the attention space. For fine-tuning after pruning, we design a lateral inhibition loss function that emphasizes difficult samples. Together, these address two key challenges in pruning remote sensing models: channel importance that is hard to distinguish, and the prevalence of difficult samples.
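As a rough illustration of mapping channel importance into the attention space, the sketch below scores each channel of one layer by combining the average SE attention it receives with the magnitude of its BN scaling factor. The product form `mean(attention) * |gamma|` and the function name `attention_space_importance` are assumptions for illustration; the exact mapping used in this repository may differ.

```python
import numpy as np

def attention_space_importance(attn, gamma):
    """Hypothetical per-channel importance score for one layer.

    attn:  (num_samples, C) SE excitation outputs, each in (0, 1)
    gamma: (C,) BN scaling factors of the same layer
    Assumption: importance = mean attention * |gamma|, normalized to sum to 1.
    """
    attn = np.asarray(attn, dtype=float)
    gamma = np.asarray(gamma, dtype=float)
    score = attn.mean(axis=0) * np.abs(gamma)  # average over samples, weight by BN scale
    return score / score.sum()                 # normalize within the layer

attn = np.array([[0.9, 0.1, 0.5],
                 [0.7, 0.3, 0.5]])
gamma = np.array([1.0, 2.0, -0.4])
scores = attention_space_importance(attn, gamma)
print(scores)  # roughly [0.571, 0.286, 0.143]
```

Normalizing per layer is one way to make scores comparable if channels are later ranked across layers; it is a design choice of this sketch, not a documented detail of the method.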
Load the ResNet-18 model with SENet (squeeze-and-excitation) blocks:
import torch
from resnet18_SE import resnet18_SE  # module defined in resnet18_SE.py
class_num = 10  # number of classes in the dataset (EuroSAT has 10)
model = resnet18_SE(class_num)
Train the model on the EuroSAT dataset:
python train.py
Extract the intermediate outputs of SENet:
python get_attention.py
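For reference, a standard squeeze-and-excitation block computes a per-channel weight vector by global average pooling, a two-layer fully connected bottleneck, and a sigmoid, then rescales the feature map with it. That weight vector is presumably the intermediate output collected in this step. The NumPy sketch below is framework-agnostic; the weight names `w1` and `w2` and the omission of biases are simplifications, not the repository's code.

```python
import numpy as np

def se_forward(x, w1, w2):
    """Squeeze-and-excitation on a single feature map.

    x:  (C, H, W) feature map
    w1: (C // r, C) weights of the reduction FC layer (biases omitted)
    w2: (C, C // r) weights of the expansion FC layer (biases omitted)
    Returns (rescaled feature map, channel attention vector of shape (C,)).
    """
    squeeze = x.mean(axis=(1, 2))                # global average pooling -> (C,)
    hidden = np.maximum(w1 @ squeeze, 0.0)       # FC + ReLU -> (C // r,)
    attn = 1.0 / (1.0 + np.exp(-(w2 @ hidden)))  # FC + sigmoid -> (C,), values in (0, 1)
    return x * attn[:, None, None], attn         # excitation: scale each channel

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 4, 4))
w1 = rng.standard_normal((2, 8))  # reduction ratio r = 4
w2 = rng.standard_normal((8, 2))
y, attn = se_forward(x, w1, w2)
print(attn.shape)  # (8,)
```

Collecting `attn` for many images gives a (num_samples, C) matrix per SE block, which can then be averaged or otherwise summarized per channel.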
Remove the SE blocks from ResNet-18 and update the BN layer parameters:
python delete_SE_resnet18.py
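One way to remove an SE block while keeping its average effect is to replace its input-dependent weights with a fixed per-channel scale `s` (for example, the attention averaged over the training set) and fold that scale into the BN layer it follows, as in the standard SE-ResNet block where SE directly follows the second BN. Because inference-mode BN computes `y = gamma * (x - mean) / sqrt(var + eps) + beta`, multiplying `y` by `s` is equivalent to using `gamma' = s * gamma` and `beta' = s * beta`. Whether delete_SE_resnet18.py uses exactly this rule is an assumption, and `fold_scale_into_bn` is a hypothetical helper.

```python
import numpy as np

def bn_forward(x, gamma, beta, mean, var, eps=1e-5):
    """Inference-mode BatchNorm over channel axis 0 of a (C, H, W) tensor."""
    shape = (-1, 1, 1)
    x_hat = (x - mean.reshape(shape)) / np.sqrt(var.reshape(shape) + eps)
    return gamma.reshape(shape) * x_hat + beta.reshape(shape)

def fold_scale_into_bn(gamma, beta, s):
    """Hypothetical helper: absorb a fixed per-channel scale s into BN's affine parameters."""
    return gamma * s, beta * s

rng = np.random.default_rng(1)
C = 4
x = rng.standard_normal((C, 3, 3))
gamma, beta = rng.standard_normal(C), rng.standard_normal(C)
mean, var = rng.standard_normal(C), rng.random(C) + 0.5
s = rng.random(C)  # fixed channel scale, e.g. dataset-averaged SE attention

before = bn_forward(x, gamma, beta, mean, var) * s.reshape(-1, 1, 1)  # BN, then fixed scaling
g2, b2 = fold_scale_into_bn(gamma, beta, s)
after = bn_forward(x, g2, b2, mean, var)                              # updated BN alone
print(np.allclose(before, after))  # True
```

Since SE weights come from a sigmoid and are positive, `|gamma'| = s * |gamma|`, so the updated BN scaling factors already carry the attention information that a BN-based importance criterion can read.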
Prune the model:
python pruning.py
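Given per-channel scores for a layer, the pruning step has to decide which output channels to drop. The sketch below picks the lowest-scoring channels for a given ratio. Indices like these are the kind of input torch-pruning's dependency graph accepts (the `idxs` argument of `DependencyGraph.get_pruning_group`), which then also removes the coupled channels in dependent layers. The uniform per-layer ratio and the helper name `channels_to_prune` are assumptions; pruning.py may rank channels globally or use different ratios.

```python
import numpy as np

def channels_to_prune(scores, ratio, min_keep=1):
    """Hypothetical helper: indices of the lowest-scoring channels of one layer.

    scores: (C,) importance scores; ratio: fraction of channels to remove.
    At least `min_keep` channels are always kept.
    """
    scores = np.asarray(scores, dtype=float)
    n_prune = min(int(len(scores) * ratio), len(scores) - min_keep)
    order = np.argsort(scores)              # ascending: least important first
    return sorted(order[:n_prune].tolist())

scores = [0.30, 0.05, 0.20, 0.01, 0.44]
pruned = channels_to_prune(scores, ratio=0.4)
print(pruned)  # [1, 3]
```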
Fine-tune the pruned model with the Adaptive Mining Loss:
from Adaptive_Mining_Loss import MyLoss
loss_fn = MyLoss(r1=1, r2=1)
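The exact form of `MyLoss` is not described here, so the sketch below is only a stand-in that illustrates the two stated goals: it is NOT the Adaptive Mining Loss. It combines a focal-style weight `(1 - p_true) ** gamma`, which up-weights difficult samples, with a penalty on the strongest competing class as one possible reading of "lateral inhibition". The parameters `gamma` and `lam` are hypothetical and do not correspond to `r1` and `r2`.

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def hard_sample_loss(logits, labels, gamma=2.0, lam=0.5):
    """Hypothetical stand-in, NOT the repository's Adaptive Mining Loss.

    - (1 - p_true) ** gamma * CE up-weights poorly classified (difficult) samples.
    - lam * max non-target probability suppresses the strongest competing class.
    """
    probs = softmax(np.asarray(logits, dtype=float))
    labels = np.asarray(labels)
    idx = np.arange(len(labels))
    p_true = probs[idx, labels]
    ce = -np.log(p_true + 1e-12)
    competitors = probs.copy()
    competitors[idx, labels] = 0.0                  # mask out the target class
    inhibition = competitors.max(axis=1)            # strongest competitor per sample
    per_sample = (1.0 - p_true) ** gamma * ce + lam * inhibition
    return per_sample.mean()

logits = [[4.0, 0.0, 0.0]]
easy = hard_sample_loss(logits, [0])  # confidently correct sample
hard = hard_sample_loss(logits, [1])  # confidently wrong sample
print(easy < hard)  # True
```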