- Authors: Adyasha Maharana, Prateek Yadav, and Mohit Bansal
- Paper: arXiv
- Create a virtual environment and activate it.
python3 -m venv env
source env/bin/activate
- Install dependencies for all datasets except DataComp
python -m pip install -r requirements.txt -f https://download.pytorch.org/whl/cu113/torch_stable.html
For DataComp, install the requirements of the DataComp codebase, then additionally install faiss.
The following commands are usage examples. See our paper for the hyperparameters of each dataset.
# Train model on full dataset to extract training dynamics
python train.py --dataset cifar10 --gpuid 0 --epochs 200 --lr 0.1 --network resnet18 --batch-size 256 --task-name all-data --base-dir ./data-model/cifar10
# Get importance scores and sample embeddings
python generate_importance_score.py --gpuid 0 --base-dir ./data-model/cifar10 --task-name all-data --feature
# Select samples using D2 pruning and train ResNet 18 on the selected coreset
python train.py --dataset cifar10 --gpuid 1 --iterations 40000 --task-name class-lb-graph-n=$N_NEIGHBOR-g=$GAMMA-$CORESET_RATIO \
--base-dir ./data-model/cifar10/class/ --coreset --coreset-mode class --budget-mode uniform --sampling-mode graph \
--data-score-path ./data-model/cifar10/all-data/data-score-all-data.pickle \
--feature-path ./data-model/cifar10/all-data/train-features-all-data.npy \
--coreset-key forgetting --coreset-ratio $CORESET_RATIO --mis-ratio 0.4 --label-balanced \
--n-neighbor $N_NEIGHBOR --gamma $GAMMA --stratas 25 --graph-mode sum --graph-sampling-mode weighted
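The command above leaves `$N_NEIGHBOR`, `$GAMMA`, and `$CORESET_RATIO` for you to set. As a rough sketch, the helper below assembles the same CIFAR10 command for one setting, which can be handy when sweeping several coreset ratios. The function name `d2_cifar10_command` and the numeric values in the loop are hypothetical and chosen only for illustration; see the paper for the values used per dataset. Only flags that already appear in the command above are used.

```python
import shlex


def d2_cifar10_command(n_neighbor, gamma, coreset_ratio, gpuid=1):
    """Build the CIFAR10 D2 pruning command from this README for one setting.

    Hypothetical helper: it only substitutes the three shell variables
    ($N_NEIGHBOR, $GAMMA, $CORESET_RATIO) into the documented command.
    """
    task_name = f"class-lb-graph-n={n_neighbor}-g={gamma}-{coreset_ratio}"
    args = [
        "python", "train.py",
        "--dataset", "cifar10",
        "--gpuid", str(gpuid),
        "--iterations", "40000",
        "--task-name", task_name,
        "--base-dir", "./data-model/cifar10/class/",
        "--coreset",
        "--coreset-mode", "class",
        "--budget-mode", "uniform",
        "--sampling-mode", "graph",
        "--data-score-path", "./data-model/cifar10/all-data/data-score-all-data.pickle",
        "--feature-path", "./data-model/cifar10/all-data/train-features-all-data.npy",
        "--coreset-key", "forgetting",
        "--coreset-ratio", str(coreset_ratio),
        "--mis-ratio", "0.4",
        "--label-balanced",
        "--n-neighbor", str(n_neighbor),
        "--gamma", str(gamma),
        "--stratas", "25",
        "--graph-mode", "sum",
        "--graph-sampling-mode", "weighted",
    ]
    # shlex.join quotes any argument that needs it for a POSIX shell.
    return shlex.join(args)


# Placeholder values for illustration only, not the paper's hyperparameters.
for ratio in (0.1, 0.3):
    print(d2_cifar10_command(n_neighbor=5, gamma=1.0, coreset_ratio=ratio))
```

Each printed line can be pasted into a shell, or passed to a job scheduler of your choice.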
# Train model on full dataset to extract training dynamics
python train_imagenet.py --epochs 60 --lr 0.1 --scheduler cosine --task-name all-data --base-dir ./data-model/imagenet --data-dir /dir/to/data/imagenet --network resnet34 --batch-size 256 --gpuid 0,1
# Get importance scores and sample embeddings
python generate_importance_score.py --gpuid 0 --base-dir ./data-model/imagenet --task-name all-data --feature
# Select samples using D2 pruning and train ResNet 18 on the selected coreset
python train_imagenet.py --dataset imagenet --gpuid 1 --iterations 40000 --task-name class-lb-graph-n=$N_NEIGHBOR-g=$GAMMA-$CORESET_RATIO \
--base-dir ./data-model/imagenet/graph/ --coreset --coreset-mode graph --budget-mode uniform --sampling-mode graph \
--data-score-path ./data-model/imagenet/all-data/data-score-all-data.pickle \
--feature-path ./data-model/imagenet/all-data/train-features-all-data.npy \
--coreset-key accumulated_margin --coreset-ratio $CORESET_RATIO --mis-ratio 0.4 --label-balanced \
--n-neighbor $N_NEIGHBOR --gamma $GAMMA --stratas 25 --graph-mode sum --graph-sampling-mode weighted
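Before launching a selection run, it can help to check that the score file and the feature array passed via `--data-score-path` and `--feature-path` line up. The sketch below is not part of this codebase. It assumes the pickle holds a dict that maps score names (such as the `--coreset-key` values `forgetting` and `accumulated_margin`) to per-sample arrays, and that the `.npy` file is a 2-D array with one row per training sample. Both are assumptions, and `summarize_artifacts` is a hypothetical name. The demo writes small synthetic files so the snippet runs on its own; point the function at your real paths instead.

```python
import os
import pickle
import tempfile

import numpy as np


def summarize_artifacts(score_path, feature_path):
    """Report shapes of the score entries and the feature matrix.

    Assumes (not confirmed by this README) that the pickle is a dict of
    score name -> per-sample values. If the real file stores torch tensors,
    torch must be installed for pickle.load to succeed.
    """
    with open(score_path, "rb") as f:
        scores = pickle.load(f)
    features = np.load(feature_path)
    summary = {
        "n_features": features.shape[0],
        "feature_dim": features.shape[1],
        "scores": {},
    }
    for key, value in scores.items():
        summary["scores"][key] = np.asarray(value).shape
    return summary


# Demo on synthetic stand-in files (8 samples, 4-dimensional features).
with tempfile.TemporaryDirectory() as tmp:
    score_path = os.path.join(tmp, "data-score-all-data.pickle")
    feature_path = os.path.join(tmp, "train-features-all-data.npy")
    rng = np.random.default_rng(0)
    with open(score_path, "wb") as f:
        pickle.dump(
            {"forgetting": rng.random(8), "accumulated_margin": rng.random(8)}, f
        )
    np.save(feature_path, rng.random((8, 4)))
    demo_summary = summarize_artifacts(score_path, feature_path)
    print(demo_summary)
```

If the number of feature rows differs from the length of the score entry you plan to use as `--coreset-key`, the two files probably come from different runs.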
First, download the DataComp [small] dataset by following the instructions in the DataComp codebase.
# Select samples using D2 pruning from the DataComp dataset
python select_d2_datacomp.py \
--metadata-dir ./datacomp/metadata/ --out-dir ./datacomp/d2/ \
--n-neighbors $N_NEIGHBOR --gamma $GAMMA --fraction $FRACTION \
--feature-type image  # or: text
This script generates a numpy file containing the UIDs of the DataComp subset that can then be used to reshard the DataComp data for training.
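As a quick sanity check before resharding, the sketch below loads a UID file and reports how many entries it holds and whether any are duplicated. It is not part of this codebase: `describe_uid_file` is a hypothetical name, and the output file name and array dtype written by `select_d2_datacomp.py` are not stated here. The synthetic array uses a two-field unsigned 64-bit dtype purely as a stand-in so the snippet is runnable.

```python
import os
import tempfile

import numpy as np


def describe_uid_file(path):
    """Return entry count, dtype, and number of distinct UIDs in a .npy file."""
    uids = np.load(path)
    flat = uids.reshape(-1)
    return {
        "count": int(flat.shape[0]),
        "dtype": str(uids.dtype),
        # tolist() yields hashable Python values (tuples for structured
        # dtypes), so a set works regardless of the stored dtype.
        "unique": len(set(flat.tolist())),
    }


# Demo on a synthetic stand-in file with one deliberate duplicate.
with tempfile.TemporaryDirectory() as tmp:
    uid_path = os.path.join(tmp, "subset_uids.npy")  # hypothetical file name
    fake_uids = np.array([(1, 2), (3, 4), (1, 2)], dtype=np.dtype("u8,u8"))
    np.save(uid_path, fake_uids)
    uid_info = describe_uid_file(uid_path)
    print(uid_info)
```

With a real subset file, `count` divided by the pool size should roughly match the `--fraction` you passed.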
Thanks to the authors of Coverage-centric Coreset Selection for High Pruning Rates for releasing their code for evaluating CCS and training ResNet models on CIFAR10, CIFAR100, and ImageNet-1K. Much of this codebase has been adapted from their code. Thanks also to the authors of Beyond neural scaling laws: beating power law scaling via data pruning for releasing the prototypicality scores on ImageNet-1K.
Please cite our paper if you use the
title = {D2 Pruning: Message Passing for Balancing Diversity & Difficulty in Data Pruning},
author = {Adyasha Maharana and Prateek Yadav and Mohit Bansal},
year = {2023},
archivePrefix = {arXiv},
primaryClass = {cs.LG},
eprint = {2310.07931}