Discriminative Feature Generation for Classification of Imbalanced Data in PyTorch

Figure 1

A PyTorch implementation of discriminative feature generation (DFG) for classification of imbalanced data, as described in the paper:

  • Sungho Suh, Paul Lukowicz, and Yong Oh Lee, "Discriminative feature generation for classification of imbalanced data", Pattern Recognition, 2022. [Pattern Recognition] [arXiv]

Abstract

Data imbalance is frequently a bottleneck for neural network performance in classification. In this paper, we propose a novel supervised discriminative feature generation (DFG) method for minority-class data. DFG is based on a modified Generative Adversarial Network structure consisting of four independent networks: a generator, a discriminator, a feature extractor, and a classifier. To augment the selected discriminative features of minority-class data using an attention mechanism, the generator for the class-imbalanced target task is trained while the feature extractor and classifier are regularized with their pre-trained counterparts from large source data. The experimental results show that the DFG generator enhances the augmentation of label-preserved and diverse features, and that classification results on the target task improve significantly.

Models

The performance of each model:

| Network | LeNet-5 (EMNIST) | LeNet-5 (EMNIST) | VGG-16 (CIFAR10) | VGG-16 (CIFAR10) | ResNet-50 (ImageNet) | ResNet-50 (ImageNet) |
|---|---|---|---|---|---|---|
| Data set | SVHN | F-MNIST | STL-10 | CINIC-10 | CALTECH-256 | FOOD-101 |
| Imbalance ratio (IR) | 10:1 | 40:1 | 1:1 | 10:1 | 1:1 | 5:1 |
| ORIGINAL | 76.57 ± 1.65 | 77.13 ± 3.24 | 66.73 ± 0.67 | 58.14 ± 3.42 | 43.30 ± 0.34 | 30.17 ± 1.72 |
| FINE-TUNING | 76.66 ± 0.65 | 75.39 ± 3.54 | 72.89 ± 0.38 | 64.42 ± 0.81 | 81.99 ± 0.12 | 68.99 ± 0.42 |
| DELTA | 78.47 ± 0.57 | 78.91 ± 2.27 | 79.41 ± 0.27 | 68.89 ± 0.18 | 85.33 ± 0.15 | 72.03 ± 0.35 |
| DIFA + CMWGAN | 76.64 ± 1.32 | 78.27 ± 2.32 | 80.08 ± 0.19 | 71.82 ± 0.35 | 82.23 ± 0.32 | 71.71 ± 0.08 |
| OURS (DFG) | 80.81 ± 0.25 | 82.65 ± 0.28 | 81.09 ± 0.17 | 72.09 ± 0.20 | 86.29 ± 0.19 | 76.00 ± 0.36 |
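
The imbalance ratio (IR) row above gives the majority-to-minority sample ratio for each target data set. This README does not say how the imbalanced splits were built, so the following is only a minimal sketch of one way such a split could be simulated from a list of labels. The helper name `make_imbalanced_indices`, the choice of minority classes, and the toy labels are all hypothetical and are not taken from this repository.

```python
import random
from collections import Counter, defaultdict


def make_imbalanced_indices(labels, minority_classes, ratio, seed=0):
    """Return sorted sample indices in which every minority class is cut to
    1/ratio of the largest majority-class count (hypothetical helper)."""
    rng = random.Random(seed)

    # Group sample indices by class label.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    minority = set(minority_classes)
    majority_sizes = [len(v) for c, v in by_class.items() if c not in minority]
    if not majority_sizes:
        raise ValueError("at least one majority class is required")

    # Target size for each minority class, never below one sample.
    target = max(1, int(max(majority_sizes) / ratio))

    kept = []
    for cls, idxs in by_class.items():
        if cls in minority and len(idxs) > target:
            idxs = rng.sample(idxs, target)  # random subsample without replacement
        kept.extend(idxs)
    return sorted(kept)


# Toy data: 3 classes with 100 samples each; class 2 is made the minority.
labels = [c for c in range(3) for _ in range(100)]
subset = make_imbalanced_indices(labels, minority_classes=[2], ratio=10)
counts = Counter(labels[i] for i in subset)
print(counts)
```

The returned indices could be passed to something like `torch.utils.data.Subset` to wrap an existing dataset, although the data loading code in this repository may work differently.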

Prerequisites

  • Linux (Ubuntu)
  • Python >= 3.6
  • NVIDIA GPU + CUDA CuDNN
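
A quick way to check these prerequisites before training is a short script like the one below. It is a sketch rather than part of this repository: the function name `check_environment` is hypothetical, and the check only reports what it finds without enforcing anything. PyTorch is imported lazily so the script also runs before PyTorch is installed.

```python
import importlib.util
import platform
import sys


def check_environment(min_python=(3, 6)):
    """Collect basic facts about the OS, interpreter, and PyTorch install."""
    report = {
        "os": platform.system(),
        "python_ok": sys.version_info >= min_python,
        "torch_installed": importlib.util.find_spec("torch") is not None,
        "cuda_available": False,
        "cudnn_available": False,
    }
    if report["torch_installed"]:
        import torch  # imported lazily so the check also runs without PyTorch

        report["cuda_available"] = torch.cuda.is_available()
        report["cudnn_available"] = torch.backends.cudnn.is_available()
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```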

Installation

  • Clone this repo:
git clone https://github.com/opensuh/DFG/
  • Install PyTorch
    • For pip users, please type the command pip install -r requirements.txt.
    • For Conda users, you can create a new Conda environment using conda env create -f environment.yml.

DFG train/eval

  • Train the model
./scripts/lenet_svhn_train.sh
./scripts/lenet_fmnist_train.sh
./scripts/vgg_stl_train.sh
./scripts/vgg_cinic_train.sh
./scripts/resnet_caltech_train.sh
./scripts/resnet_food_train.sh
  • Evaluate the model (our pre-trained models are in ./pretrained_model)
  • We plan to upload the pre-trained models to our GitHub page.
./scripts/lenet_svhn_eval.sh
./scripts/lenet_fmnist_eval.sh
./scripts/vgg_stl_eval.sh
./scripts/vgg_cinic_eval.sh
./scripts/resnet_caltech_eval.sh
./scripts/resnet_food_eval.sh
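
The scripts listed above follow a `<network>_<dataset>_<mode>.sh` naming pattern, so they can be run in sequence from a small driver. The sketch below assumes it is started from the repository root and that the scripts can be launched with `bash`; the helper names `script_paths` and `run_all` are hypothetical and not part of this repository. It defaults to a dry run that only prints what would be executed.

```python
import subprocess
from pathlib import Path

# (network, dataset) pairs taken from the script names listed above.
EXPERIMENTS = [
    ("lenet", "svhn"), ("lenet", "fmnist"),
    ("vgg", "stl"), ("vgg", "cinic"),
    ("resnet", "caltech"), ("resnet", "food"),
]


def script_paths(mode, scripts_dir="scripts"):
    """Build the script paths for mode 'train' or 'eval'."""
    if mode not in ("train", "eval"):
        raise ValueError("mode must be 'train' or 'eval'")
    return [Path(scripts_dir) / f"{net}_{data}_{mode}.sh" for net, data in EXPERIMENTS]


def run_all(mode, dry_run=True):
    """Run every script for the given mode in order; a dry run only lists them."""
    paths = script_paths(mode)
    for path in paths:
        print(("would run: " if dry_run else "running: ") + str(path))
        if not dry_run:
            # check=True stops the loop at the first failing script.
            subprocess.run(["bash", str(path)], check=True)
    return paths


run_all("train")  # dry run by default, nothing is executed
```

Calling `run_all("eval", dry_run=False)` would execute the six evaluation scripts one after another and stop at the first failure.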
