Merge branch 'main' of https://github.com/sangkeun00/analog
Showing 10 changed files with 493 additions and 4 deletions.
A small script comparing the saved influence scores from the different methods via Pearson correlation:

@@ -0,0 +1,15 @@
from scipy.stats import pearsonr
import torch


analog_kfac = torch.load("if_analog.pt")
analog_lora_random = torch.load("if_analog_lora.pt")
analog_lora_pca = torch.load("if_analog_pca_conv.pt")
print(
    "[KFAC (analog) vs LoRA-random (analog)] pearson:",
    pearsonr(analog_kfac, analog_lora_random),
)
print(
    "[KFAC (analog) vs LoRA-pca (analog)] pearson:",
    pearsonr(analog_kfac, analog_lora_pca),
)
The baseline script: AnaLog logs per-sample gradients and the Hessian (KFAC) over the training set, then computes the influence of every training example on a single query example and saves the scores to if_analog.pt:

@@ -0,0 +1,81 @@
import time
import torch


from analog import AnaLog
from analog.utils import DataIDGenerator
from analog.analysis import InfluenceFunction

from train import (
    get_cifar10_dataloader,
    construct_rn9,
)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DAMPING = 0.01
LISSA_ITER = 100000


def single_checkpoint_influence(data_name="cifar10", eval_idxs=(0,)):
    model = construct_rn9().to(DEVICE)

    # Get a single checkpoint (first model_id and last epoch).
    model.load_state_dict(
        torch.load(f"checkpoints/{data_name}_0_epoch_23.pt", map_location="cpu")
    )
    model.eval()

    dataloader_fn = get_cifar10_dataloader
    train_loader = dataloader_fn(
        batch_size=512, split="train", shuffle=False, subsample=True
    )
    query_loader = dataloader_fn(
        batch_size=1, split="valid", shuffle=False, indices=eval_idxs
    )

    analog = AnaLog(project="test", config="examples/cifar/config.yaml")

    # Gradient & Hessian logging
    analog.watch(model)
    analog_kwargs = {"log": ["grad"], "hessian": True, "save": True}

    id_gen = DataIDGenerator()
    for inputs, targets in train_loader:
        data_id = id_gen(inputs)
        with analog(data_id=data_id, **analog_kwargs):
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            model.zero_grad()
            outs = model(inputs)
            loss = torch.nn.functional.cross_entropy(outs, targets, reduction="sum")
            loss.backward()
    analog.finalize()

    # Influence Analysis
    log_loader = analog.build_log_dataloader()

    from analog.analysis import InfluenceFunction

    analog.add_analysis({"influence": InfluenceFunction})
    query_iter = iter(query_loader)
    with analog(log=["grad"], test=True) as al:
        test_input, test_target = next(query_iter)
        test_input, test_target = test_input.to(DEVICE), test_target.to(DEVICE)
        model.zero_grad()
        test_out = model(test_input)
        test_loss = torch.nn.functional.cross_entropy(
            test_out, test_target, reduction="sum"
        )
        test_loss.backward()
        test_log = al.get_log()
    start = time.time()
    if_scores = analog.influence.compute_influence_all(test_log, log_loader)

    # Save
    if_scores = if_scores.numpy().tolist()
    torch.save(if_scores, "examples/cifar_influence/if_analog.pt")
    print("Computation time:", time.time() - start)
    print(sorted(if_scores)[:10], sorted(if_scores)[-10:])


if __name__ == "__main__":
    single_checkpoint_influence(data_name="cifar10")
The LoRA variant: a first epoch estimates the Hessian, then analog.add_lora attaches low-rank projections (PCA-initialized, per the config below) and a second epoch saves the compressed gradient logs; the scores go to if_analog_pca.pt:

@@ -0,0 +1,85 @@
import time
import torch


from analog import AnaLog
from analog.utils import DataIDGenerator
from analog.analysis import InfluenceFunction

from train import (
    get_cifar10_dataloader,
    construct_rn9,
)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DAMPING = 0.01
LISSA_ITER = 100000


def single_checkpoint_influence(data_name="cifar10", eval_idxs=(0,)):
    model = construct_rn9().to(DEVICE)

    # Get a single checkpoint (first model_id and last epoch).
    model.load_state_dict(
        torch.load(f"checkpoints/{data_name}_0_epoch_23.pt", map_location="cpu")
    )
    model.eval()

    dataloader_fn = get_cifar10_dataloader
    train_loader = dataloader_fn(
        batch_size=512, split="train", shuffle=False, subsample=True
    )
    query_loader = dataloader_fn(
        batch_size=1, split="valid", shuffle=False, indices=eval_idxs
    )

    analog = AnaLog(project="test", config="examples/cifar_influence/config.yaml")

    # Gradient & Hessian logging
    analog.watch(model)
    analog_kwargs = {"log": ["grad"], "hessian": True, "save": False}

    id_gen = DataIDGenerator()
    for epoch in range(2):
        for inputs, targets in train_loader:
            data_id = id_gen(inputs)
            with analog(data_id=data_id, **analog_kwargs):
                inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
                model.zero_grad()
                outs = model(inputs)
                loss = torch.nn.functional.cross_entropy(outs, targets, reduction="sum")
                loss.backward()
        analog.finalize()
        if epoch == 0:
            analog_kwargs.update({"save": True})
            analog.add_lora(model, parameter_sharing=False)

    # Influence Analysis
    log_loader = analog.build_log_dataloader()

    from analog.analysis import InfluenceFunction

    analog.add_analysis({"influence": InfluenceFunction})
    query_iter = iter(query_loader)
    with analog(log=["grad"], test=True) as al:
        test_input, test_target = next(query_iter)
        test_input, test_target = test_input.to(DEVICE), test_target.to(DEVICE)
        model.zero_grad()
        test_out = model(test_input)
        test_loss = torch.nn.functional.cross_entropy(
            test_out, test_target, reduction="sum"
        )
        test_loss.backward()
        test_log = al.get_log()
    start = time.time()
    if_scores = analog.influence.compute_influence_all(test_log, log_loader)

    # Save
    if_scores = if_scores.numpy().tolist()
    torch.save(if_scores, "examples/cifar_influence/if_analog_pca.pt")
    print("Computation time:", time.time() - start)
    print(sorted(if_scores)[:10], sorted(if_scores)[-10:])


if __name__ == "__main__":
    single_checkpoint_influence(data_name="cifar10")
@@ -0,0 +1,5 @@
storage:
  type: default
  log_dir: "./analog/log/lora"
lora:
  init: pca
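For reference, this YAML is the config the LoRA script above points AnaLog at; a minimal sketch of how it is consumed, reusing the project name and config path from that script:

from analog import AnaLog

# The storage.log_dir setting above controls where the gradient logs are written,
# and lora.init: pca requests PCA-initialized LoRA projections.
analog = AnaLog(project="test", config="examples/cifar_influence/config.yaml")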
The CIFAR-10 training script (adapted from the TRAK example) that produces the ResNet-9 checkpoints loaded by the influence scripts:

@@ -0,0 +1,104 @@
# We use an example from the TRAK repository:
# https://github.com/MadryLab/trak/blob/main/examples/cifar_quickstart.ipynb.


import os
from pathlib import Path
from tqdm.auto import tqdm
import numpy as np
import torch
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, lr_scheduler

from utils import get_cifar10_dataloader, construct_rn9
from utils import set_seed

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def train(
    model,
    loader,
    lr=0.4,
    epochs=24,
    momentum=0.9,
    weight_decay=5e-4,
    lr_peak_epoch=5,
    save_name="default",
    model_id=0,
    save=True,
):
    opt = SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    iters_per_epoch = len(loader)
    # Cyclic LR with single triangle
    lr_schedule = np.interp(
        np.arange((epochs + 1) * iters_per_epoch),
        [0, lr_peak_epoch * iters_per_epoch, epochs * iters_per_epoch],
        [0, 1, 0],
    )
    scheduler = lr_scheduler.LambdaLR(opt, lr_schedule.__getitem__)
    scaler = GradScaler()
    loss_fn = CrossEntropyLoss()

    for epoch in tqdm(range(epochs)):
        set_seed(model_id * 10_061 + epoch + 1)
        for images, labels in loader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            opt.zero_grad(set_to_none=True)
            with autocast():
                out = model(images)
                loss = loss_fn(out, labels)

            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
            scheduler.step()

    if save:
        torch.save(
            model.state_dict(),
            f"checkpoints/{save_name}_{model_id}_epoch_{epochs-1}.pt",
        )

    return model


def main(dataset="cifar10"):
    os.makedirs("checkpoints", exist_ok=True)

    if dataset == "cifar10":
        train_loader = get_cifar10_dataloader(
            batch_size=512, split="train", shuffle=True, subsample=True
        )
        valid_loader = get_cifar10_dataloader(
            batch_size=512, split="val", shuffle=False
        )
    else:
        raise NotImplementedError

    # you can modify the for loop below to train more models
    for model_id in tqdm(range(1), desc="Training models.."):
        model = construct_rn9().to(memory_format=torch.channels_last).to(DEVICE)

        model = train(model, train_loader, save_name=dataset, model_id=model_id)

        model = model.eval()

        model.eval()
        with torch.no_grad():
            total_correct, total_num = 0.0, 0.0
            for images, labels in tqdm(valid_loader):
                images = images.to(DEVICE)
                labels = labels.to(DEVICE)
                with autocast():
                    out = model(images)
                    total_correct += out.argmax(1).eq(labels).sum().cpu().item()
                    total_num += images.shape[0]

            print(f"Accuracy: {total_correct / total_num * 100:.1f}%")


if __name__ == "__main__":
    main()