MOKD: Cross-domain Finetuning for Few-shot Classification via Maximizing Optimized Kernel Dependence
This repository contains the source codes for reproducing the results of ICML'24 paper: MOKD: Cross-domain Finetuning for Few-shot Classification via Maximizing Optimized Kernel Dependence.
Author List: Hongduan Tian, Feng Liu, Tongliang Liu, Bo Du, Yiu-ming Cheung, Bo Han.
Current works on cross-domain few-shot classification mainly focus on adapting a simple transformation head on top of a frozen pretrained backbone (e.g., ResNet-18) by optimizing the nearest centroid classifier loss (a.k.a. NCC-based loss). However, during the adaptation phase we observe an undesirable phenomenon: samples from different classes can share high similarity. Such high similarity induces uncertainty and can further result in misclassification of data samples.
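For reference, below is a minimal sketch of an NCC-based loss in PyTorch. It assumes cosine similarity to class centroids and a hypothetical temperature `temp`; the exact variant used by URL/MOKD may differ.

```python
import torch
import torch.nn.functional as F

def ncc_loss(support_feats, support_labels, query_feats, query_labels, temp=10.0):
    """Nearest-centroid-classifier (NCC) loss: classify each query by its
    cosine similarity to the class centroids of the support embeddings."""
    n_way = int(support_labels.max()) + 1
    feats = F.normalize(support_feats, dim=-1)
    # Centroid (prototype) of each class from the normalized support features.
    centroids = torch.stack(
        [feats[support_labels == c].mean(dim=0) for c in range(n_way)]
    )
    centroids = F.normalize(centroids, dim=-1)
    # Cosine similarities serve as logits; `temp` is an assumed scaling factor.
    logits = temp * F.normalize(query_feats, dim=-1) @ centroids.t()
    return F.cross_entropy(logits, query_labels)
```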
To solve this problem, we propose maximizing optimized kernel dependence (MOKD), a bi-level optimization framework that learns a set of class-specific representations matching the cluster structures indicated by the label information. Specifically, MOKD first optimizes the kernel used in the Hilbert-Schmidt Independence Criterion (HSIC) to obtain an optimized-kernel HSIC whose test power is maximized for precise detection of dependence. Then, the optimized-kernel HSIC is further optimized to simultaneously maximize the dependence between representations and labels while minimizing the dependence among all samples.
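For intuition, here is a minimal sketch of the dependence objective. It assumes a biased HSIC estimator, a Gaussian (RBF) kernel on features, a linear kernel on one-hot labels, and a hypothetical trade-off weight `lam`; the actual implementation (including the inner-loop kernel optimization that maximizes test power) lives in hsic_loss.py and may differ in details.

```python
import torch

def rbf_kernel(x, sigma):
    """Gaussian kernel matrix from pairwise squared Euclidean distances."""
    d2 = torch.cdist(x, x).pow(2)
    return torch.exp(-d2 / (2 * sigma ** 2))

def hsic(K, L):
    """Biased HSIC estimator: trace(K H L H) / (n - 1)^2 with centering H."""
    n = K.shape[0]
    H = torch.eye(n, device=K.device) - torch.ones(n, n, device=K.device) / n
    return torch.trace(K @ H @ L @ H) / (n - 1) ** 2

def mokd_objective(feats, one_hot_labels, sigma, lam=1.0):
    """Outer-level objective (sketch): maximize dependence between features
    and labels while minimizing dependence among all samples. In MOKD the
    kernel parameter `sigma` is itself tuned in an inner loop to maximize
    the test power of the HSIC statistic; that step is omitted here."""
    Kx = rbf_kernel(feats, sigma)
    Ly = one_hot_labels.float() @ one_hot_labels.float().t()  # assumed label kernel
    return -(hsic(Kx, Ly) - lam * hsic(Kx, Kx))  # minimize the negative
```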
In our experiments, the main dependencies are the following libraries:
- Python 3.6 or greater (Ours: Python 3.8)
- PyTorch 1.0 or greater (Ours: torch==1.7.1, torchvision==0.8.2)
- TensorFlow 1.14 or greater (Ours: TensorFlow 2.10)
- tqdm (Ours: 4.64.1)
- tabulate (Ours: 0.8.10)
- Follow the Meta-Dataset repository to prepare the ILSVRC_2012, Omniglot, Aircraft, CU_Birds, Textures (DTD), Quick Draw, Fungi, VGG_Flower, Traffic_Sign and MSCOCO datasets.
- Follow the CNAPs repository to prepare the MNIST, CIFAR-10 and CIFAR-100 datasets.
In this paper, we follow URL and use ResNet-18 as the frozen backbone in all our experiments. For reproduction, two ways are provided:
Train your own backbone. You can train the ResNet-18 backbone from scratch. The pretraining contains two phases: domain-specific pretraining and universal backbone distillation.
To train the single-domain learning (SDL) backbones on the 8 seen domains, run:
./scripts/train_resnet18_sdl.sh
Then, distill the model by running:
./scripts/train_resnet18_url.sh
Use the released backbones. The URL repository has released both the universal backbone and the single-domain backbones. For simplicity, you can directly use the released models.
The backbones can be downloaded via the links above. Alternatively, one can use gdown (installed via pip install gdown) and execute the following commands in the root directory of this project:
gdown https://drive.google.com/uc?id=1MvUcvQ8OQtoOk1MIiJmK6_G8p4h8cbY9 && md5sum sdl.zip && unzip sdl.zip -d ./saved_results/ && rm sdl.zip # Single-domain backbones
gdown https://drive.google.com/uc?id=1Dv8TX6iQ-BE2NMpfd0sQmH2q4mShmo1A && md5sum url.zip && unzip url.zip -d ./saved_results/ && rm url.zip # Universal backbone
In this way, the backbones are downloaded. Please create the ./saved_results directory beforehand (if it does not exist) and make sure the backbone weights are placed in it.
To evaluate MOKD, you can run:
./scripts/test_hsic_pa.sh
Specifically, the running command is:

```
python hsic_loss.py --model.name=url \
    --model.dir ./url \
    --data.imgsize=84 \
    --seed=41 \
    --test_size=600 \
    --kernel.type=rbf \
    --epsilon=1e-5 \
    --test.type=standard \
    --experiment_name=mokd_seed41
```
The hyperparameters can be modified for different experiments:
- model.name ['sdl', 'url']: sdl means using the single-domain backbone; url means using the universal backbone.
- model.dir: path to the backbone weights.
- seed: the random seed. All our results are averaged over seeds 41-45 (see the driver sketch below).
- kernel.type ['linear', 'rbf', 'imq']: selects the kernel used to run MOKD.
- test.type ['standard', '5shot', '1shot']: the task mode. standard means vary-way vary-shot tasks; 5shot means vary-way 5-shot tasks; 1shot means 5-way 1-shot tasks.
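Since the reported numbers are averages over seeds 41-45, a small driver script can run all five evaluations in sequence. This is a convenience sketch, not part of the repository; it only reuses the flags from the command shown above.

```python
import subprocess

# Run the evaluation for each of the five seeds used in the paper;
# average the resulting accuracies afterwards.
for seed in range(41, 46):
    subprocess.run([
        "python", "hsic_loss.py",
        "--model.name=url",
        "--model.dir", "./url",
        "--data.imgsize=84",
        f"--seed={seed}",
        "--test_size=600",
        "--kernel.type=rbf",
        "--epsilon=1e-5",
        "--test.type=standard",
        f"--experiment_name=mokd_seed{seed}",
    ], check=True)
```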
To evaluate Pre-classifier Alignment (PA), the typical adaptation method of URL, run:
./scripts/test_resnet18_pa.sh
The repository is built mainly upon these repositories:
[1] Li et al. Universal representation learning from multiple domains for few-shot classification, ICCV 2021.
[2] Triantafillou et al. Meta-dataset: A dataset of datasets for learning to learn from few examples, ICLR 2020.
@inproceedings{tian2024mokd,
title={MOKD: Cross-domain Finetuning for Few-shot Classification via Maximizing Optimized Kernel Dependence},
author={Hongduan Tian and Feng Liu and Tongliang Liu and Bo Du and Yiu-ming Cheung and Bo Han},
booktitle={International Conference on Machine Learning (ICML)},
year={2024}
}