本文参照https://github.com/yinboc/prototypical-network-pytorch的pytorch版本进行复现
Prototypical Networks for Few-shot Learning
摘要: 我们提出了原型网络来解决小样本分类问题:分类器能够通过很少的几个在测试集上未出现新的类别的样本来获得在训练集上未出现过的新的类别的分类能力。 原型网络学习一个度量空间,在这个度量空间中,计算到每个类别的原型表征的距离来进行分类。和最近的小样本学习方法相比,它们体现出了一种更简单的归纳偏差,这种归纳偏差在有限的数据集上是有益的,并且实现了优秀的性能。我们的分析指出,相比于最近涉及复杂架构选择和元学习的方法,一些简单的设计决策可以产生实质性的改进。我们进一步将原型网络扩展到零样本学习上,并在CU-Birds数据集上取得了SOTA的结果。
本项目是 prototypical Network 在 Paddle 2.1.2上的开源实现。
名称 | 值 |
---|---|
python | 3.7.11 |
GPU | RTX3090 |
框架 | PaddlePaddle2.1.2 |
Cuda | 11.2 |
Cudnn | 8.2 |
# 克隆本仓库
git clone git@github.com:skingorz/prototypeNet-paddle.git
# 进入项目文件夹
cd prototypeNet-paddle
# 修改文件描述符限制
ulimit -n 20480
# 本地安装
conda env create -f environment.yml
数据集下载:miniImageNet
为了避免训练过程中对文件的频繁读取,首先将数据集读取并保存到下来
python data/dataprocess.py
exps/exp-v1/config.yaml
为5way-1shot的配置文件,
exps/exp-v2/config.yaml
为5way-5shot的配置文件,其中,需要将datapath
值改为上一步生成的pkl所在路径
max_epoch
训练epoch个数save_epoch
每隔多少个epoch保存一次模型shot
每个类别支持集的样本数query
每个类别查询集的样本上train_way
训练时每个episode包含的类别数test_way
测试时每个episode包含的类别数datasets
数据集,目前仅有mini-imageNetdatapath
处理后的数据pkl所在目录save_path
模型保存路径gpu
gpu选择seed
随机种子的设置lr
学习率stepSize
学习率衰减间隔gamma
衰减率load
加载模型的路径result
最终测试结果保存路径batch
测试时的episode数量
FLAGS_cudnn_deterministic=True python tools/train.py --config "exps/exp-v1/config.yaml"
FLAGS_cudnn_deterministic=True python tools/train.py --config "exps/exp-v2/config.yaml"
FLAGS_cudnn_deterministic=True python tools/test.py --config "exps/exp-v1/config.yaml"
FLAGS_cudnn_deterministic=True python tools/test.py --config "exps/exp-v2/config.yaml"
下表展示了我们的复现代码和paperwithcode上已有的结果对比,在5-way-1-shot上我们的方法超过了论文的性能0.87%,也高于大部分复现结果。在5-way-5-shot上,我们比论文性能略高0.02%,且paperwithcode上目前并无达到论文性能的代码。
method | 5-way-1-shot | 5-way-5-shot |
---|---|---|
paper | 49.42 ± 0.78 | 68.20 ± 0.66 |
oscarknagg/few-shot | 48.0 | 66.2 |
yinboc/prototypical-network-pytorch | 49.1 | 66.9 |
schatty/prototypical-networks-tf | 43.5 | 66.0 |
minseop-aitrics/FewshotLearning | 52.547 ± 0.766 | 67.673 ± 0.648 |
KamalM8/Few-Shot-learning-Fashion | 48.0 | 66.2 |
WangTianduo/Prototypical-Networks | 42.48 | 64.7 |
Michedev/Prototypical-Networks-Few-Zero-Shot | 49.75 | 66.60 |
This repo(paddlepaddle) | 50.29 ± 0.64 | 68.22 ± 0.49 |
@inproceedings{snell2017prototypical,
title={Prototypical networks for few-shot learning},
author={Snell, Jake and Swersky, Kevin and Zemel, Richard},
booktitle={Proceedings of the 31st International Conference on Neural Information Processing Systems},
pages={4080--4090},
year={2017}
}