This repo contains the sample code for reproducing the results of our NeurIPS 2023 paper: Understanding and Improving Feature Learning for Out-of-Distribution Generalization, which was also presented as a spotlight at the ICLR DG workshop and at the ICML SCIS Workshop. 😆😆😆
Updates:
- The camera-ready version of the paper is now available at the link!
- Detailed running instructions will be released soon!
Empirical risk minimization (ERM) is the de facto objective adopted in machine learning and achieves impressive generalization performance. Nevertheless, ERM is prone to spurious correlations and is suspected to learn predictive but spurious features while minimizing the empirical risk. However, recent work (Rosenfeld et al., 2022; Kirichenko et al., 2022) empirically shows that ERM already learns invariant features, i.e., features that hold an invariant relation with the label, for both in-distribution and Out-of-Distribution (OOD) generalization.
We resolve the puzzle by theoretically proving that ERM essentially learns both spurious and invariant features. Meanwhile, we also find that OOD objectives such as IRMv1 can hardly learn new features, even at the beginning of the optimization. Therefore, when optimizing OOD objectives such as IRMv1, pre-training the model with ERM is usually necessary for satisfactory performance. As shown in the right subfigure, the OOD performance of various OOD objectives first grows with more ERM pre-training epochs.
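To make the two-stage recipe concrete, below is a minimal PyTorch sketch (not the repo's code) of ERM pre-training followed by IRMv1 fine-tuning. The penalty follows the standard IRMv1 formulation (Arjovsky et al., 2019); `envs`, `pretrain_epochs`, and `irm_weight` are illustrative names.

```python
import torch
import torch.nn.functional as F

def irmv1_penalty(logits, y):
    # IRMv1 penalty (Arjovsky et al., 2019): squared gradient of the risk
    # w.r.t. a frozen dummy classifier scale w = 1.0.
    scale = torch.ones(1, device=logits.device, requires_grad=True)
    loss = F.cross_entropy(logits * scale, y)
    grad = torch.autograd.grad(loss, [scale], create_graph=True)[0]
    return grad.pow(2).sum()

def training_step(model, envs, optimizer, epoch,
                  pretrain_epochs=100, irm_weight=100.0):
    # Two-stage recipe discussed above: plain ERM for the first
    # `pretrain_epochs`, then ERM plus the IRMv1 penalty. `envs` is a list
    # of (x, y) batches, one per training environment.
    losses, penalties = [], []
    for x, y in envs:
        logits = model(x)
        losses.append(F.cross_entropy(logits, y))
        penalties.append(irmv1_penalty(logits, y))
    loss = torch.stack(losses).mean()
    if epoch >= pretrain_epochs:
        loss = loss + irm_weight * torch.stack(penalties).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```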
However, which features ERM prefers to learn depends on the inductive biases of the dataset and the architecture, and this limited feature learning can pose a bottleneck for OOD generalization. Therefore, we propose Feature Augmented Training (FeAT), which aims to learn all features so long as they are useful for generalization. Iteratively, FeAT divides the training data into subsets according to whether the current model predicts them correctly, then augments the model to learn new features on the mispredicted subset while retaining the already learned features on the correctly predicted subset.
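As a rough illustration of the partition step, here is a minimal sketch assuming a standard PyTorch classifier and data loader; `split_by_prediction` is a hypothetical helper, not the repo's API:

```python
import torch

@torch.no_grad()
def split_by_prediction(model, loader, device="cpu"):
    # Partition the training data into a retention set (samples the current
    # model already classifies correctly) and an augmentation set (samples
    # it still gets wrong), as in the FeAT round described above.
    retention, augmentation = [], []
    model.eval()
    for x, y in loader:
        preds = model(x.to(device)).argmax(dim=1).cpu()
        for xi, yi, pi in zip(x, y, preds):
            (retention if pi.item() == yi.item() else augmentation).append((xi, yi))
    return retention, augmentation
```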
For more interesting stories of rich feature learning, please see the repositories Bonsai and RRL, and the blog by Jianyu. 😆
The code base contains the following parts, corresponding to the experiments presented in the paper:
- `ColoredMNIST`: Proof of Concept on ColoredMNIST
- `WILDS`: Verification of FeAT in WILDS
We run with CUDA 10.2 and Python 3.8.12, with the following key libraries (a quick environment sanity check follows the list):
- `wilds==2.0.0`
- `torch==1.9.0`
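A quick way to check that the pinned versions are installed (assuming `wilds` exposes `__version__`, which it does in recent releases):

```python
import torch
import wilds

# Confirm the pinned versions from the list above.
print("torch:", torch.__version__)   # expect 1.9.0
print("cuda :", torch.version.cuda)  # expect 10.2
print("wilds:", wilds.__version__)   # expect 2.0.0
```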
The ColoredMNIST code is in the folder `ColoredMNIST` and is modified from RFC.
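For reference, the standard ColoredMNIST construction (Arjovsky et al., 2019) that this experiment builds on looks roughly like the sketch below; the exact script in `ColoredMNIST` may differ in details such as image size and environment probabilities.

```python
import torch
from torchvision import datasets

def make_colored_mnist(env_color_flip_p, label_flip_p=0.25, root="./data"):
    # Binarize the digit label, flip it with probability label_flip_p, then
    # color each image so the color agrees with the (noisy) label except
    # with probability env_color_flip_p. Generic sketch, not the repo's script.
    mnist = datasets.MNIST(root, train=True, download=True)
    images = mnist.data.float() / 255.0            # (N, 28, 28)
    labels = (mnist.targets >= 5).float()          # binary label
    labels = torch.where(torch.rand(len(labels)) < label_flip_p,
                         1.0 - labels, labels)     # label noise
    colors = torch.where(torch.rand(len(labels)) < env_color_flip_p,
                         1.0 - labels, labels)     # spurious color
    x = torch.stack([images, images], dim=1)       # two color channels
    x[torch.arange(len(x)), (1 - colors).long(), :, :] = 0.0
    return x, labels
```

In the standard setup, the two training environments use color-flip probabilities of 0.1 and 0.2, while the test environment uses 0.9, so the color correlation reverses at test time.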
The WILDS code is in the folder `WILDS` and is modified from PAIR and spurious-feature-learning.
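For orientation, loading a WILDS benchmark with the standard `wilds==2.0.0` API looks like the sketch below (Camelyon17 is chosen only as an example; the repo's scripts handle this with their own configs):

```python
import torchvision.transforms as transforms
from wilds import get_dataset
from wilds.common.data_loaders import get_train_loader

# Download and wrap a WILDS benchmark, then iterate over (x, y, metadata).
dataset = get_dataset(dataset="camelyon17", download=True)
train_data = dataset.get_subset("train", transform=transforms.ToTensor())
train_loader = get_train_loader("standard", train_data, batch_size=16)

for x, y, metadata in train_loader:
    break  # one batch as a smoke test
```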
If you find our paper and repo useful, please cite our paper:
@inproceedings{chen2023FeAT,
  title={Understanding and Improving Feature Learning for Out-of-Distribution Generalization},
  author={Yongqiang Chen and Wei Huang and Kaiwen Zhou and Yatao Bian and Bo Han and James Cheng},
  booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
  year={2023},
  url={https://openreview.net/forum?id=eozEoAtjG8}
}