A PyTorch implementation of discriminative feature generation (DFG) for classification of imbalanced data, as described in the paper:
- Sungho Suh, Paul Lukowicz, and Yong Oh Lee, "Discriminative feature generation for classification of imbalanced data", Pattern Recognition, 2022. [Pattern Recognition] [arXiv]
Abstract
The data imbalance problem is frequently a bottleneck for the performance of neural network classifiers. In this paper, we propose a novel supervised discriminative feature generation (DFG) method for minority-class datasets. DFG is based on a modified Generative Adversarial Network structure consisting of four independent networks: a generator, a discriminator, a feature extractor, and a classifier. To augment the selected discriminative features of minority-class data by adopting an attention mechanism, the generator for the class-imbalanced target task is trained while the feature extractor and classifier are regularized using counterparts pre-trained on large source data. The experimental results show that the DFG generator enhances the augmentation of label-preserved and diverse features, and that classification results on the target task are significantly improved.
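The abstract says the feature extractor and classifier are regularized using pre-trained counterparts, but this README does not give the form of that regularizer. One common choice in transfer learning is an L2-SP ("starting point") penalty, which pulls fine-tuned weights toward their pre-trained values rather than toward zero. The sketch below implements that penalty in plain Python over flat parameter lists; it is an illustrative assumption, not necessarily the loss DFG actually uses, and `starting_point_penalty`, `regularized_loss`, and `alpha` are hypothetical names.

```python
from typing import Sequence


def starting_point_penalty(weights: Sequence[float],
                           pretrained: Sequence[float],
                           alpha: float = 0.01) -> float:
    """L2-SP-style penalty (alpha / 2) * ||w - w0||^2.

    Unlike plain weight decay, which pulls weights toward zero, this pulls
    the fine-tuned weights w toward the pre-trained weights w0.
    """
    if len(weights) != len(pretrained):
        raise ValueError("weights and pretrained must have the same length")
    squared_distance = sum((w - w0) ** 2 for w, w0 in zip(weights, pretrained))
    return 0.5 * alpha * squared_distance


def regularized_loss(task_loss: float,
                     weights: Sequence[float],
                     pretrained: Sequence[float],
                     alpha: float = 0.01) -> float:
    """Total objective: task loss plus the starting-point penalty."""
    return task_loss + starting_point_penalty(weights, pretrained, alpha)


if __name__ == "__main__":
    w0 = [0.5, -1.0, 2.0]  # hypothetical pre-trained parameters
    w = [0.7, -1.0, 1.0]   # hypothetical parameters after target-task updates
    print(starting_point_penalty(w, w0, alpha=0.1))  # approximately 0.052
```

In a PyTorch model the same sum would run over the tensors from `model.parameters()` against a frozen copy of the pre-trained weights; the flat lists here just keep the arithmetic visible.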
The performance of each model
| Backbone (source data) | LeNet-5 (EMNIST) | LeNet-5 (EMNIST) | VGG-16 (CIFAR10) | VGG-16 (CIFAR10) | ResNet-50 (ImageNet) | ResNet-50 (ImageNet) |
|---|---|---|---|---|---|---|
| Target data set | SVHN | F-MNIST | STL-10 | CINIC-10 | CALTECH-256 | FOOD-101 |
| Imbalance ratio (IR) | 10:1 | 40:1 | 1:1 | 10:1 | 1:1 | 5:1 |
| ORIGINAL | 76.57 ± 1.65 | 77.13 ± 3.24 | 66.73 ± 0.67 | 58.14 ± 3.42 | 43.30 ± 0.34 | 30.17 ± 1.72 |
| FINE-TUNING | 76.66 ± 0.65 | 75.39 ± 3.54 | 72.89 ± 0.38 | 64.42 ± 0.81 | 81.99 ± 0.12 | 68.99 ± 0.42 |
| DELTA | 78.47 ± 0.57 | 78.91 ± 2.27 | 79.41 ± 0.27 | 68.89 ± 0.18 | 85.33 ± 0.15 | 72.03 ± 0.35 |
| DIFA + CMWGAN | 76.64 ± 1.32 | 78.27 ± 2.32 | 80.08 ± 0.19 | 71.82 ± 0.35 | 82.23 ± 0.32 | 71.71 ± 0.08 |
| OURS (DFG) | 80.81 ± 0.25 | 82.65 ± 0.28 | 81.09 ± 0.17 | 72.09 ± 0.20 | 86.29 ± 0.19 | 76.00 ± 0.36 |
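The table lists an imbalance ratio (IR) per target data set and reports each result as mean ± standard deviation. This README does not say how the imbalanced splits were built or how many runs were averaged, so the sketch below assumes the common definition IR = (largest class size) / (smallest class size), a hypothetical split procedure that subsamples chosen minority classes down to n_max / IR samples, and the sample (n - 1) standard deviation. `imbalance_ratio`, `subsample_minority`, and `mean_pm_std` are illustrative names, not functions from the repository.

```python
import random
import statistics
from collections import Counter
from typing import Dict, List, Sequence


def imbalance_ratio(labels: Sequence[int]) -> float:
    """IR = size of the largest class / size of the smallest class."""
    counts = Counter(labels)
    return max(counts.values()) / min(counts.values())


def subsample_minority(indices_by_class: Dict[int, List[int]],
                       minority_classes: Sequence[int],
                       ir: float,
                       seed: int = 0) -> Dict[int, List[int]]:
    """Keep all majority-class samples; cut each minority class to n_max / ir."""
    rng = random.Random(seed)
    n_max = max(len(idx) for idx in indices_by_class.values())
    n_min = max(1, int(n_max / ir))
    split = {}
    for cls, idx in indices_by_class.items():
        if cls in minority_classes:
            split[cls] = rng.sample(idx, min(n_min, len(idx)))
        else:
            split[cls] = list(idx)
    return split


def mean_pm_std(accuracies: Sequence[float]) -> str:
    """Format repeated-run accuracies as 'mean ± std' (sample std, n - 1)."""
    return f"{statistics.mean(accuracies):.2f} ± {statistics.stdev(accuracies):.2f}"


if __name__ == "__main__":
    # Hypothetical 3-class pool with 1000 sample indices per class.
    pool = {c: list(range(c * 1000, (c + 1) * 1000)) for c in range(3)}
    split = subsample_minority(pool, minority_classes=[2], ir=10)
    labels = [c for c, idx in split.items() for _ in idx]
    print(imbalance_ratio(labels))  # 10.0
    print(mean_pm_std([80.5, 81.0, 80.9]))
```

With `ir=1` the minority classes keep n_max samples, which corresponds to the balanced 1:1 settings in the table.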
Prerequisites
- Linux (Ubuntu)
- Python >= 3.6
- NVIDIA GPU + CUDA + cuDNN
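The prerequisites above can be checked roughly before installing anything. The sketch below uses only the standard library; the function name `check_environment` is hypothetical, and finding `nvidia-smi` on the PATH is only a proxy for an NVIDIA driver, not proof that CUDA or cuDNN is usable.

```python
import platform
import shutil
import sys


def check_environment() -> dict:
    """Report whether this machine appears to match the listed prerequisites."""
    return {
        "linux": platform.system() == "Linux",
        "python_ok": sys.version_info >= (3, 6),
        # nvidia-smi ships with the NVIDIA driver; its presence says nothing
        # about which CUDA or cuDNN versions are installed.
        "nvidia_smi_found": shutil.which("nvidia-smi") is not None,
    }


if __name__ == "__main__":
    for key, ok in check_environment().items():
        print(f"{key:18s} {'yes' if ok else 'NO'}")
```

Once PyTorch is installed, `torch.cuda.is_available()` and `torch.backends.cudnn.is_available()` give a more direct answer.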
- Clone this repo:
git clone https://github.com/opensuh/DFG/
- Install PyTorch and the other dependencies:
- For pip users, run
pip install -r requirements.txt
- For Conda users, create a new Conda environment with
conda env create -f environment.yml
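After installing, it can help to confirm that every distribution listed in `requirements.txt` actually resolved. The sketch below assumes simple `name==version` style lines (it does not handle URL or VCS requirements), and the helper names are hypothetical. It uses `importlib.metadata`, which needs Python 3.8+; on 3.6/3.7 the `importlib_metadata` backport offers the same functions.

```python
import re
from importlib.metadata import PackageNotFoundError, version  # Python 3.8+
from typing import Iterable, List


def requirement_names(lines: Iterable[str]) -> List[str]:
    """Extract distribution names from requirements.txt-style lines."""
    names = []
    for line in lines:
        line = line.split("#", 1)[0].strip()  # drop comments
        if not line or line.startswith("-"):  # skip blanks and options like -r / -e
            continue
        # The name ends at the first version specifier, extra, marker or space.
        name = re.split(r"[\s<>=!~;\[@]", line, maxsplit=1)[0]
        if name:
            names.append(name)
    return names


def missing_requirements(lines: Iterable[str]) -> List[str]:
    """Return the listed distributions that are not installed."""
    missing = []
    for name in requirement_names(lines):
        try:
            version(name)
        except PackageNotFoundError:
            missing.append(name)
    return missing
```

For example, `missing_requirements(open("requirements.txt"))` returning an empty list means every listed distribution is installed.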
- Train the model:
./scripts/lenet_svhn_train.sh
./scripts/lenet_fmnist_train.sh
./scripts/vgg_stl_train.sh
./scripts/vgg_cinic_train.sh
./scripts/resnet_caltech_train.sh
./scripts/resnet_food_train.sh
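The six training scripts above can also be run back to back from the repository root. The sketch below is a hypothetical driver (`train_all` and the `logs/` directory are assumptions, not part of the repository): it runs each script with `bash`, writes its output to `logs/<script>.log`, and stops at the first failure. The scripts' own arguments and contents are not shown in this README and are left untouched.

```python
import subprocess
from pathlib import Path
from typing import List, Sequence

# Training scripts listed in this README (paths relative to the repo root).
TRAIN_SCRIPTS = [
    "lenet_svhn_train.sh",
    "lenet_fmnist_train.sh",
    "vgg_stl_train.sh",
    "vgg_cinic_train.sh",
    "resnet_caltech_train.sh",
    "resnet_food_train.sh",
]


def train_all(scripts: Sequence[str] = TRAIN_SCRIPTS,
              scripts_dir: str = "scripts",
              log_dir: str = "logs",
              dry_run: bool = False) -> List[List[str]]:
    """Run each training script in turn, logging output to log_dir/<name>.log.

    With dry_run=True, only return the commands that would be executed.
    """
    commands = [["bash", str(Path(scripts_dir) / name)] for name in scripts]
    if dry_run:
        return commands
    Path(log_dir).mkdir(parents=True, exist_ok=True)
    for name, cmd in zip(scripts, commands):
        log_path = Path(log_dir) / (Path(name).stem + ".log")
        with open(log_path, "w") as log:
            # check=True raises CalledProcessError if the script exits non-zero.
            subprocess.run(cmd, stdout=log, stderr=subprocess.STDOUT, check=True)
    return commands
```

Calling `train_all(dry_run=True)` lists the commands without starting any training, which is a cheap way to confirm the paths before a long run.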
- Evaluate the model (our pre-trained models are in ./pretrained_model)
- We plan to upload the pre-trained models to our GitHub page.
./scripts/lenet_svhn_eval.sh
./scripts/lenet_fmnist_eval.sh
./scripts/vgg_stl_eval.sh
./scripts/vgg_cinic_eval.sh
./scripts/resnet_caltech_eval.sh
./scripts/resnet_food_eval.sh
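Because the evaluation scripts rely on pre-trained models in `./pretrained_model`, a pre-flight check avoids six identical failures when that directory is missing or empty. The sketch below is hypothetical (`evaluate_all` is not a repository function), and how the scripts actually locate their checkpoints is not described here; it only checks that the directory exists and is non-empty before running each script with `bash` from the repository root.

```python
import subprocess
from pathlib import Path
from typing import List, Sequence

# Evaluation scripts listed in this README (paths relative to the repo root).
EVAL_SCRIPTS = [
    "lenet_svhn_eval.sh",
    "lenet_fmnist_eval.sh",
    "vgg_stl_eval.sh",
    "vgg_cinic_eval.sh",
    "resnet_caltech_eval.sh",
    "resnet_food_eval.sh",
]


def evaluate_all(scripts: Sequence[str] = EVAL_SCRIPTS,
                 scripts_dir: str = "scripts",
                 model_dir: str = "pretrained_model",
                 dry_run: bool = False) -> List[List[str]]:
    """Run each evaluation script after checking that model_dir is non-empty."""
    commands = [["bash", str(Path(scripts_dir) / name)] for name in scripts]
    if dry_run:
        return commands
    models = Path(model_dir)
    # Fail early with one clear message instead of letting every script fail.
    if not models.is_dir() or not any(models.iterdir()):
        raise FileNotFoundError(f"no pre-trained models found in {models}/")
    for cmd in commands:
        # check=True raises CalledProcessError if a script exits non-zero.
        subprocess.run(cmd, check=True)
    return commands
```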