Official code for "Rich Feature Construction for the Optimization-Generalization Dilemma"
In machine learning, two key problems are defining a goal that generalizes (e.g. the invariance goal in out-of-distribution generalization) and finding a path to that goal (e.g. the many optimization tricks). The two are usually in tension: either the generalization goal is weak, or the optimization process is hard. This optimization-generalization dilemma is especially visible in out-of-distribution (OOD) generalization. This work addresses the dilemma by constructing a RICH and SIMPLE representation that makes the optimization process easier, which in turn lets us pursue a stronger generalization goal.
Two questions come up in many areas: "where is our goal?" and "how do we reach the goal from our current position?". A successful project needs to answer both. The two questions, however, pull against each other in difficulty. When the goal is ambitious, the path to it is usually blurry. When the path is clear and certain, the goal is usually modest. For the goal of "making a cup of espresso", for instance, most people can immediately lay out a clear, precise path. "Building a spacecraft to fly to Jupiter", on the other hand, is an ambitious goal, and most people have no idea how to achieve it.
Can we build a spacecraft purely by thinking about the "spacecraft"? No. A spacecraft is built on progress in diverse areas, such as materials, computing, and engines.
The story above reveals a way to approach hard problems: "Search and develop diverse areas (directions). A clear path may then appear on top of them. If it does not, keep searching."
The rule above is also the key idea of the proposed Rich Feature Construction (RFC) method.
- wilds==2.0.0
- einops=0.4.1
- python=3.6.13
- pytorch=1.10.2
- torch-geometric=2.0.3
- torch-scatter=2.0.9
- torch-sparse=0.6.12
- torchvision=0.11.3
- tqdm=4.62.3
- transformers=4.17.0
OOD methods are sensitive to the network initialization. We test nine OOD methods (IRMv1, VREx, FISH, SD, IGA, LfF, RSC, CLOvE, and Fishr) on the ColoredMNIST benchmark. Fig1 shows the OOD performance for different numbers of ERM pretraining epochs. None of the nine OOD methods works with a random initialization.
Fig1: Test performance of nine penalized OoD methods as a function of the number of epochs used to pre-train the neural network with ERM. The final OoD testing performance is very dependent on choosing the right number of pretraining epochs, illustrating the challenges of these optimization problems.
To reproduce the results, run:
bash script/coloredmnist/coloredmnist_anneal.sh
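To make the role of the ERM pretraining epochs concrete, here is a minimal sketch of a penalized objective with a pretraining phase, using the VREx form (mean risk plus a weighted variance of the per-environment risks). It is not the code run by the script above: the function names `penalty_weight` and `vrex_objective`, the choice of a zero penalty weight during pretraining, and the numeric values are all assumptions made for illustration.

```python
from statistics import mean, pvariance


def penalty_weight(epoch, pretrain_epochs, penalty_lambda):
    """Penalty weight at a given (0-indexed) epoch.

    Assumption: the penalty is switched off (weight 0, i.e. plain ERM)
    for the first `pretrain_epochs` epochs, then set to `penalty_lambda`.
    """
    return 0.0 if epoch < pretrain_epochs else penalty_lambda


def vrex_objective(env_risks, weight):
    """Mean risk over environments plus weight * variance of the risks."""
    return mean(env_risks) + weight * pvariance(env_risks)


# Illustrative per-environment risks (not measured values).
env_risks = [0.30, 0.50]
for epoch in (0, 49, 50):
    w = penalty_weight(epoch, pretrain_epochs=50, penalty_lambda=100.0)
    print(epoch, w, round(vrex_objective(env_risks, w), 4))
```

With a large penalty weight the variance term dominates the objective as soon as pretraining ends, which is one way to see why the point at which the switch happens matters so much.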
Starting from a 'perfect' initialization, where the model uses only the robust feature (so OOD performance is maximized), what happens if we continue training with these OOD methods? Do they maintain the robustness, or do they decay to a spurious/singular solution? Fig2 (top) shows that the latter happens.
Fig2: Test performance of OoD methods as a function of training epochs. Top: Six OoD methods are trained from a ‘perfect’ initialization where only the robust feature is well learned. The blue star indicates the initial test accuracy. Bottom: The OoD methods are trained from the proposed (frozen) RFC representation.
To reproduce the results (top), run:
bash script/coloredmnist/coloredmnist_perfect_initialization_longtrain.sh
The proposed RFC method creates a rich and simple representation to address the optimization-generalization dilemma above. Tab1 compares random initialization (Rand), ERM-pretrained initialization (ERM), and RFC-pretrained initialization (RFC / RFC(cf)). The proposed RFC consistently boosts OOD methods.
Tab1: OoD testing accuracy achieved on COLOREDMNIST. The first six rows of the table show the results achieved by six OoD methods using respectively random initialization (Rand), ERM initialization (ERM), RFC initialization (RFC). The last column, RFC(cf), reports the performance achieved by running the OoD algorithm on top of the frozen RFC representations. The seventh row reports the results achieved using ERM under the same conditions. The last row reminds us of the oracle performance achieved by a network using data from which the spurious feature (color) has been removed.
To reproduce the results, run:
bash script/coloredmnist/coloredmnist_rfc.sh
A line of work seeks OOD generalization by discovering the second easiest-to-find feature, e.g. PI. Here we claim that the second easiest-to-find feature is not the robust solution in general. To showcase this, we create an 'InverseColoredMNIST' dataset where the robust feature (digit) is more predictive than the spurious feature (color).
Tab2: OoD test accuracy of PI and OOD/ERM methods on COLOREDMNIST and INVERSECOLOREDMNIST. The OOD/ERM methods are trained on top of a frozen RFC representation.
To reproduce the results, run:
bash script/coloredmnist/inversecoloredmnist_rfc.sh
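The difference between the two datasets can be illustrated with a small simulation of how well each feature predicts the label. This is a sketch only and does not reproduce the repository's data pipeline: the helper `feature_agreement`, the binary stand-in for the digit feature, and the noise rates below are assumptions chosen for illustration.

```python
import random


def feature_agreement(n, label_noise, color_noise, seed=0):
    """Return (digit_agreement, color_agreement) with the label.

    Assumed generative process: the label is the digit bit flipped with
    probability `label_noise`; the color is the label flipped with
    probability `color_noise`.
    """
    rng = random.Random(seed)
    digit_hits = 0
    color_hits = 0
    for _ in range(n):
        digit_bit = rng.random() < 0.5  # stands in for the digit class
        label = digit_bit ^ (rng.random() < label_noise)
        color = label ^ (rng.random() < color_noise)
        digit_hits += digit_bit == label
        color_hits += color == label
    return digit_hits / n, color_hits / n


# ColoredMNIST-like setting: color predicts the label better than digit.
print(feature_agreement(20000, label_noise=0.25, color_noise=0.15))
# Inverse-like setting (hypothetical rates): digit is the more predictive one.
print(feature_agreement(20000, label_noise=0.15, color_noise=0.25))
```

In the first setting the spurious color feature is the more predictive one on the training distribution; in the second the order is reversed, so a method that looks for the second most predictive feature would land on color rather than on the robust digit feature.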
Network Initialization | Methods | Test Acc IID Tune | Test Acc OOD Tune | scripts | comments |
---|---|---|---|---|---|
- | ERM | 66.6±9.8 | 70.2±8.7 | A | |
ERM | IRMv1 | 68.6±6.8 | 68.5±6.2 | B | |
ERM | vREx | 69.1±8.1 | 69.1±13.2 | C | |
ERM | ERM(cf) | - | - | - | |
ERM | IRMv1(cf) | 69.6±10.5 | 70.7±10.0 | A,D | |
ERM | vREx(cf) | 69.6±10.5 | 70.6±10.0 | A,E | |
ERM | CLOvE(cf) | 69.6±10.5 | 69.2±9.5 | A,F | |
2-RFC | ERM | 72.8±3.2 | 74.7±4.3 | A,G, H, I | set lambda=0 in I |
2-RFC | IRMv1 | 71.6±4.2 | 75.3±4.8 | A,G, H, I | |
2-RFC | vREx | 73.4±3.3 | 76.4±5.3 | A,G, H, J | |
2-RFC | CLOvE | 74.0±4.6 | 76.6±5.3 | A,G, H, K | |
2-RFC | ERM(cf) | 78.2±2.6 | 78.6±2.6 | A,G, H, L | |
2-RFC | IRMv1(cf) | 78.0±2.1 | 79.1±2.1 | A,G, H, L | set lambda=0 in L |
2-RFC | vREx(cf) | 77.9±2.7 | 79.5±2.7 | A,G, H, M | |
2-RFC | CLOvE(cf) | 77.8±2.2 | 78.6±2.6 | A,G, H, N | |
3-RFC | ERM(cf) | 72.9±5.3 | 73.3±5.3 | A,G, O, P, Q | set lambda=0 in Q |
3-RFC | IRMv1(cf) | 72.7±5.5 | 75.5±3.8 | A,G, O, P, Q | |
3-RFC | vREx(cf) | 72.7±5.4 | 75.1±5.3 | A,G, O, P, R | |
3-RFC | vREx(cf) | 72.8±5.4 | 73.2±7.1 | A,G, O, P, S |
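
The two accuracy columns can be read with the following sketch, which assumes that "IID Tune" and "OOD Tune" mean selecting hyperparameters on an IID or an OOD validation set respectively, and that each entry is a mean ± standard deviation over seeds. The helper names and every number below are placeholders for illustration, not results from the table, and the use of the sample standard deviation is also an assumption.

```python
from statistics import mean, stdev


def pick_test_acc(configs, tune_on):
    """Test accuracy of the config with the best `tune_on` validation accuracy."""
    best = max(configs, key=lambda c: c[tune_on])
    return best["test"]


def summarize(per_seed_acc):
    """Format accuracies over seeds as mean±std (sample standard deviation)."""
    return f"{mean(per_seed_acc):.1f}±{stdev(per_seed_acc):.1f}"


# Placeholder runs: one list of hyperparameter configs per seed.
seed_runs = [
    [{"iid_val": 90.0, "ood_val": 70.0, "test": 68.0},
     {"iid_val": 88.0, "ood_val": 75.0, "test": 74.0}],
    [{"iid_val": 91.0, "ood_val": 72.0, "test": 70.0},
     {"iid_val": 89.0, "ood_val": 78.0, "test": 76.0}],
]

iid_tuned = [pick_test_acc(run, "iid_val") for run in seed_runs]
ood_tuned = [pick_test_acc(run, "ood_val") for run in seed_runs]
print("IID Tune:", summarize(iid_tuned))
print("OOD Tune:", summarize(ood_tuned))
```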
If you find our code useful, please consider citing our work using the following BibTeX entry:
@inproceedings{zhang2022rich,
title={Rich feature construction for the optimization-generalization dilemma},
author={Zhang, Jianyu and Lopez-Paz, David and Bottou, L{\'e}on},
booktitle={International Conference on Machine Learning},
pages={26397--26411},
year={2022},
organization={PMLR}
}