This is the repository accompanying the paper “Robust Meta-Representation Learning via Global Label Inference and Classification” containing the code necessary to run the different steps of the pipeline as outlined in the appendix.
The code relies on the shell variable WORKSPACE being set and exported to point
to a directory that houses the saved models and datasets. We suggest putting
this directory either in your home directory or wherever you keep your
datasets. The workspace directory has two sub-directories: checkpoint, in which
we keep the model weights for the pipeline, and metaL_data, which stores all of
the base datasets we use.
Below is an example of the populated workspace
directory output from tree -d
You can initialize this directory in your home directory by running
make initialize_workspace
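As a quick sanity check before running anything, the layout described above can be verified from Python. This is a minimal sketch, not part of the repository: the helper names `resolve_workspace` and `missing_subdirs` are hypothetical, and the sub-directory names are taken as written in this README.

```python
import os
from pathlib import Path

# Sub-directory names as written in this README; adjust if your checkout differs.
SUBDIRS = ("checkpoint", "metaL_data")


def resolve_workspace(env=None):
    """Return the workspace path from WORKSPACE, raising if it is unset."""
    env = os.environ if env is None else env
    root = env.get("WORKSPACE")
    if not root:
        raise RuntimeError("WORKSPACE is not set; export it before running the pipeline")
    return Path(root).expanduser()


def missing_subdirs(root, subdirs=SUBDIRS):
    """List the expected sub-directories that do not exist under root."""
    return [name for name in subdirs if not (Path(root) / name).is_dir()]


if __name__ == "__main__":
    try:
        print("missing sub-directories:", missing_subdirs(resolve_workspace()))
    except RuntimeError as error:
        print(error)
```

An empty list means both sub-directories are in place.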
The pipeline that implements the model is run by executing the following files, in order, from the top directory of the repo:
learn_meta_repr.py
learn_labeler.py
fine_tune.py
To evaluate the models, run meta_eval.py, and to train the supervised oracle
models, run sup_baseline.py. See Pipeline steps for details on how to do this
correctly.
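The ordering above can be sketched as a small driver that builds one command per step. This is an illustration only: `pipeline_commands` is a hypothetical helper, the `key=value` form is hydra's command-line override syntax, and each step also needs its own step-specific arguments (for example the file name of the model saved by the previous step), which are not filled in here.

```python
import sys

# Order of the pipeline steps as listed in this README.
PIPELINE_STEPS = ("learn_meta_repr.py", "learn_labeler.py", "fine_tune.py")


def pipeline_commands(shared_overrides):
    """Build one command per step, passing the same overrides to every step."""
    overrides = [f"{key}={value}" for key, value in shared_overrides.items()]
    return [[sys.executable, step, *overrides] for step in PIPELINE_STEPS]


if __name__ == "__main__":
    # Print the commands instead of running them; the values are examples.
    for command in pipeline_commands({"dataset": "mixed", "model": "resnet12"}):
        print(" ".join(command))
```

Building all commands from one dictionary is one way to keep arguments such as the dataset and the architecture identical across steps.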
The default arguments of each pipeline step are defined in the yaml file in the
config directory that has the same base name as the step's python file. We use
hydra as the configuration framework.
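The naming rule for the default-argument files can be written down in a few lines. `default_config_path` is a hypothetical helper shown only to make the rule concrete; it assumes the yaml extension is spelled `.yaml` and that the path is relative to the top directory of the repo.

```python
from pathlib import Path


def default_config_path(script, config_dir="config"):
    """Map a pipeline script to the yaml file holding its default arguments."""
    return Path(config_dir) / f"{Path(script).stem}.yaml"


if __name__ == "__main__":
    for script in ("learn_meta_repr.py", "learn_labeler.py", "fine_tune.py"):
        print(script, "->", default_config_path(script))
```

Since hydra is used, individual defaults can also be overridden on the command line with `key=value` pairs instead of editing the yaml file.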
Contains the dataset functionality including generating the datasets.
Defines the models and loading functionality.
Contains log files, which are written to when running a pipeline step.
We provide a makefile with rules for downloading and untarring all of the
datasets used. The datasets are provided in a form that is ready to use with
the code as long as they can be found in the right directories. The makefile
assumes that the command line tools wget and tar are available on your
platform. If this is not the case, you can download and untar the datasets
manually using this link to the data of the repository.
To download the datasets, please run
make download_all
and to untar them into the right directory, run
make unpack_datasets
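If tar is not available, the unpacking step can be approximated with Python's standard library. This is a sketch under assumptions: `unpack_archives` is a hypothetical helper, the location of the downloaded archives is whatever you chose when downloading manually, and the target is presumably the dataset sub-directory of the workspace. Only unpack archives from a source you trust.

```python
import tarfile
from pathlib import Path


def unpack_archives(archive_dir, target_dir):
    """Extract every tar archive found in archive_dir into target_dir."""
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)
    unpacked = []
    for archive in sorted(Path(archive_dir).iterdir()):
        # is_tarfile also recognises compressed archives such as .tar.gz
        if archive.is_file() and tarfile.is_tarfile(archive):
            with tarfile.open(archive) as tar:
                tar.extractall(target)
            unpacked.append(archive.name)
    return unpacked
```

The function returns the names of the archives it extracted, so an empty list signals that nothing was found in `archive_dir`.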
The following datasets were generated using the code in
/dataset/ximagenet-tools
which can produce few-shot base datasets derived from
ImageNet on the fly (built off mini-imagenet-tools). In particular, we made sure
that the meta-train sets of mini-60 and tiered-780 do not overlap with the
test sets of miniImageNet and tieredImageNet. If you want to generate these
datasets or other ImageNet-derived datasets, see the sub-project README.md.
To generate the datasets in mixed, we rely on the initial data preparation
pipeline of the meta-dataset repo. To generate the necessary sub-datasets,
follow that repo's instructions to produce each dataset as a directory of
tfrecords files. Then move the resulting datasets from $RECORDS to
$WORKSPACE/metaL_data using
mv -v $RECORDS/* $WORKSPACE/metaL_data
Now, the datasets may be processed into the right form by running
python $MELA/dataset/tfrecords_to_pickle.py --dataset $DATASET
To create the H-aircraft dataset, you just need to run tfrecords_to_pickle.py
with the argument $DATASET set to aircraft. The resulting datasets will be
found under the respective directories in $WORKSPACE as pickle files, which
will be loaded by the pipeline.
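To check that a conversion produced something sensible, a generated pickle file can be opened and its top-level structure inspected. The internal layout of these files is not documented here, so this sketch assumes nothing about it: `summarize_pickle` is a hypothetical helper that only reports the container type and, for dictionaries, the type of each value. Only load pickle files you generated yourself or otherwise trust.

```python
import pickle
from pathlib import Path


def summarize_pickle(path):
    """Load one pickle file and describe its top-level structure."""
    with Path(path).open("rb") as handle:
        obj = pickle.load(handle)
    if isinstance(obj, dict):
        # Report the type of each entry without printing the data itself.
        return {key: type(value).__name__ for key, value in obj.items()}
    return type(obj).__name__
```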
The makefile contains functionality for initializing the workspace directory and
downloading and unpacking all of the datasets. It assumes that your OS has the
command line tools wget and tar installed.
We provide a conda env.yml file that can be used to install the necessary
packages. Note that the code requires a GPU and a CPU with enough RAM to run.
Running the pipeline steps relies on using the same dataset at each step,
together with other arguments that should be consistent throughout (for
example, the architecture used). Most steps produce artifacts in the form of
log files, found in logs, and models saved with torch, which are output to
$WORKSPACE/checkpoints/label_learn/saves/.
The arguments shared between the pipeline steps are as follows:
- logger_name: (string) Base name of the logger file; usually this should be fixed
- trial: (string) Numbering scheme of runs (if you run things several times with similar arguments)
- dataset: (string) Dataset to use. Options: mixed, h_aircraft, miniImageNet, tieredImageNet, mini60, tiered780
- sample_shape: (few_shot / flat) Whether the dataset should be a few-shot dataset with tasks or a flat supervised learning dataset
- fixed_db: (bool) Deterministically sample the tasks of the dataset
- no_replacement: (bool) Whether to use the no-replacement dataset sampling (called GFSL in the paper)
- sim_imbalance: Not used
- n_ways: (int) Number of classes in each task
- n_shots: (int) Number of samples per class in the support set
- n_queries: (int) Number of samples per class in the query set
- val_n_ways: (int) Number of classes during test / validation time. Normally set to n_ways
- val_n_shots: (int) Number of samples per class during test / validation time. Normally set to n_shots
- model: (string) Architecture to use: resnet12 or resnet18
- feat_dim: (int) Dimensionality of the feature space; needs to be set according to the model used (usually 640)
- lam: (float) Regularization strength used in Ridge Regression
- train_db_size: (int) Number of tasks / samples in the dataset (overridden by no_replacement or if bigger than the underlying size of the dataset)
- test_db_size: (int) Number of tasks to validate over for each sub-dataset
- num_workers: (int) Number of workers used
- epochs: (int) Number of epochs to train for
- normalize_lam: Not used
- data_aug: (bool) Whether to use data augmentation
- rotate_aug: (bool) Whether to use rotation augmentation
- SGD: (bool) Whether to use SGD instead of Adam / AdamW
- learning_rate: (float) Learning rate
- lr_decay_epochs: (string) String of the form “e1,e2,…,em” where each “ei” is an epoch at which we multiply the learning rate by lr_decay_rate (below)
- lr_decay_rate: (float) Decay rate used for annealing the learning rate
- weight_decay: (float) Weight decay to use
- momentum: (float) Momentum parameter of the optimization algorithm
- test_C: (float) C (inverse regularization strength) to use in the logistic regression classifier during test / validation time
- use_bias: (bool) Include a bias in the logistic regression classifier
- is_norm: (bool) Normalize each feature-mapped instance to have unit norm at test time
- progress: (bool) Show progress bar
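The interaction between learning_rate, lr_decay_epochs and lr_decay_rate can be illustrated with a step-decay sketch. The function names are hypothetical, and it is an assumption that the decay takes effect from epoch "ei" onwards (the repository may count epochs from 1 or apply the decay after the epoch finishes).

```python
def parse_decay_epochs(lr_decay_epochs):
    """Turn a string such as "60,80" into the list [60, 80]."""
    return [int(part) for part in lr_decay_epochs.split(",") if part.strip()]


def learning_rate_at(epoch, learning_rate, lr_decay_epochs, lr_decay_rate):
    """Learning rate in effect at `epoch` under a step-decay schedule."""
    # Count how many decay boundaries have been reached so far.
    steps = sum(1 for boundary in parse_decay_epochs(lr_decay_epochs) if epoch >= boundary)
    return learning_rate * lr_decay_rate ** steps


if __name__ == "__main__":
    # Example values only; they are not the defaults of this repository.
    for epoch in (0, 60, 80):
        print(epoch, learning_rate_at(epoch, 0.05, "60,80", 0.1))
```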
learn_meta_repr.py: Learn a representation using only the locally available labels of each task. Specific arguments:
- pretrained_model: (string) Name of the file of the saved model to load (normally not used in this step)
This step outputs a log file to the log directory and a saved model file to the
save directory in $WORKSPACE.
learn_labeler.py: Using the learned representation (in the form of the saved model file) from learn_meta_repr.py, learn a labeler that infers labels for the few-shot dataset, turning it into a flat dataset, and use this flat dataset to train a model using supervised learning. Specific arguments:
- label_recovery_model: (string) Same as model, the model used in the learn_meta_repr.py step
- train_model: (string) The model to use when training after we have inferred a flat dataset
- pretrained_labeler: (string) Name of the file of the saved feature map model learned in learn_meta_repr.py, loaded for inferring labels using the labeler
- pretrained_centroids: (string) Name of the file of saved centroids, if we are to load these directly from a previous run of this step
- pretrained_model: Not used
- K: (int) Number of initial centroids to use for the labelling algorithm
- std_factor: (float) Aggression factor (in terms of standard deviation) for how aggressively we should prune centroids in the labelling algorithm
- data_aug: (bool) Whether to use data augmentation for the labelling step (should be set to the value used to produce the saved model from learn_meta_repr.py, usually false)
- rotate_aug: (bool) Whether to use rotation augmentation for the labelling step (should be set to the value used to produce the saved model from learn_meta_repr.py, usually false)
- sup_data_aug: (bool) Whether to use data augmentation when doing supervised training using the inferred flat dataset
- sup_rotate_aug: (bool) Whether to use rotation augmentation when doing supervised training using the inferred flat dataset
This step outputs a log file to the log directory and a saved model and
centroids file to the save directory in $WORKSPACE.
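To give a feel for what K and std_factor control, here is a toy sketch of centroid-based labelling. It is not the algorithm from the paper or this repository: assigning each sample to its nearest centroid is a generic technique, and the pruning rule (drop centroids whose sample count falls more than std_factor standard deviations below the mean count) is a hypothetical stand-in chosen only to show how such a factor could act. All function names are made up for this example.

```python
import math
import statistics


def nearest_centroid(feature, centroids):
    """Index of the centroid closest to `feature` in Euclidean distance."""
    return min(range(len(centroids)), key=lambda i: math.dist(feature, centroids[i]))


def assign_labels(features, centroids):
    """Give every sample the index of its nearest centroid as a flat label."""
    return [nearest_centroid(feature, centroids) for feature in features]


def prune_centroids(labels, n_centroids, std_factor):
    """Hypothetical rule: keep centroids whose sample count is at least
    mean(count) - std_factor * std(count)."""
    counts = [labels.count(i) for i in range(n_centroids)]
    threshold = statistics.mean(counts) - std_factor * statistics.pstdev(counts)
    return [i for i, count in enumerate(counts) if count >= threshold]


if __name__ == "__main__":
    features = [(0.0, 0.0), (0.1, 0.0), (0.0, 0.2), (5.0, 5.0), (5.1, 5.0), (4.9, 5.0)]
    centroids = [(0.0, 0.0), (5.0, 5.0), (10.0, 10.0)]  # K = 3 initial centroids
    labels = assign_labels(features, centroids)
    print(labels, prune_centroids(labels, len(centroids), std_factor=1.0))
```

In this toy rule a smaller std_factor prunes more aggressively, because the count threshold moves closer to the mean.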
fine_tune.py: Fine-tune the model from learn_labeler.py using a residual MLP on
top of the frozen feature map output by learn_labeler.py. Specific arguments
This step outputs a log file to the log directory and a saved model file to the
save directory in $WORKSPACE.
meta_eval.py: Evaluate a saved model in the 5-way, 1-shot / 5-shot few-shot setting of a dataset. Specific arguments:
- pretrained_model: (string) Name of the file of the saved feature map model to evaluate
Outputs the results to a log file.
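For readers new to the N-way, K-shot terminology, the sketch below samples one evaluation task from a flat list of class labels. It is a generic illustration, not the repository's sampler: `sample_episode` is a hypothetical function, and the seed argument only mimics the idea of deterministic task sampling.

```python
import random


def sample_episode(labels, n_ways, n_shots, n_queries, seed=None):
    """Sample one task: (index, class) pairs for the support and query sets."""
    rng = random.Random(seed)
    by_class = {}
    for index, label in enumerate(labels):
        by_class.setdefault(label, []).append(index)
    # Pick n_ways classes, then disjoint support / query samples for each class.
    classes = rng.sample(sorted(by_class), n_ways)
    support, query = [], []
    for cls in classes:
        picked = rng.sample(by_class[cls], n_shots + n_queries)
        support += [(i, cls) for i in picked[:n_shots]]
        query += [(i, cls) for i in picked[n_shots:]]
    return support, query


if __name__ == "__main__":
    labels = [c for c in range(10) for _ in range(20)]  # 10 classes, 20 samples each
    support, query = sample_episode(labels, n_ways=5, n_shots=1, n_queries=15, seed=0)
    print(len(support), len(query))
```

A classifier is then fit on the support set and scored on the query set; repeating this over many tasks gives the reported few-shot accuracy.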
sup_baseline.py: Train a supervised baseline that has access to the true labels.
@article{wang_2023,
doi = {10.1109/tpami.2023.3328184},
url = {https://doi.org/10.1109%2Ftpami.2023.3328184},
year = 2023,
publisher = {Institute of Electrical and Electronics Engineers ({IEEE})},
pages = {1--16},
author = {Ruohan Wang and John Isak Texas Falk and Massimiliano Pontil and Carlo Ciliberto},
title = {Robust Meta-Representation Learning via Global Label Inference and Classification},
journal = {{IEEE} Transactions on Pattern Analysis and Machine Intelligence}
}
MIT