Commit

Merge branch 'main' of https://github.com/sangkeun00/analog
sangkeun00 committed Nov 30, 2023
2 parents 4c8e681 + 710e91a commit e3e5e1b
Showing 10 changed files with 493 additions and 4 deletions.
2 changes: 1 addition & 1 deletion analog/logging.py
@@ -27,7 +27,7 @@ def compute_per_sample_gradient(fwd, bwd, module):
fwd_unfold = fwd_unfold.reshape(bsz, fwd_unfold.shape[1], -1)
bwd = bwd.reshape(bsz, -1, fwd_unfold.shape[-1])
grad = torch.einsum("ijk,ilk->ijl", bwd, fwd_unfold)
- shape = [bsz] + list(module.weight.shape)
+ shape = [bsz, module.weight.shape[0], -1]
return grad.reshape(shape)
else:
raise ValueError(f"Unsupported module type: {type(module)}")
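
The new reshape keeps the per-sample Conv2d gradient flattened as (batch, out_channels, in_channels * kh * kw) rather than restoring the full 4-D kernel shape, presumably so downstream code can treat each sample's gradient as a matrix matching the forward/backward Hessian factors. A minimal sanity check of the unfold-based computation sketched above (standalone and illustrative, not part of this commit):

import torch
import torch.nn.functional as F

bsz, cin, cout, k = 4, 3, 8, 3
conv = torch.nn.Conv2d(cin, cout, k, padding=1, bias=False)
x = torch.randn(bsz, cin, 16, 16)

out = conv(x)
bwd = torch.randn_like(out)  # stand-in for the backpropagated signal
out.backward(bwd)

# Mirror compute_per_sample_gradient: unfold the input, flatten the output grad,
# and contract over the spatial positions.
fwd_unfold = F.unfold(x, k, padding=1)                      # (bsz, cin*k*k, H*W)
bwd_flat = bwd.reshape(bsz, cout, -1)                       # (bsz, cout, H*W)
grad = torch.einsum("ijk,ilk->ijl", bwd_flat, fwd_unfold)   # (bsz, cout, cin*k*k)

# Summing per-sample gradients over the batch recovers the aggregate weight
# gradient, flattened to (out_channels, -1) as in the new shape.
assert torch.allclose(grad.sum(0), conv.weight.grad.reshape(cout, -1), atol=1e-4)
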
4 changes: 2 additions & 2 deletions analog/lora/lora.py
@@ -2,7 +2,7 @@

import torch.nn as nn

- from analog.lora.modules import LoraLinear
+ from analog.lora.modules import LoraLinear, LoraConv2d


def find_parameter_sharing_group(
@@ -85,7 +85,7 @@ def add_lora(
elif isinstance(module, nn.Conv1d):
raise NotImplementedError
elif isinstance(module, nn.Conv2d):
- raise NotImplementedError
+ lora_cls = LoraConv2d

psg = find_parameter_sharing_group(name, parameter_sharing_groups)
if parameter_sharing and psg not in shared_modules:
69 changes: 69 additions & 0 deletions analog/lora/modules.py
@@ -58,3 +58,72 @@ def init_weight(self, init_strategy: str = "random", hessian=None):
) = compute_top_k_singular_vectors(hessian[BACKWARD], self.rank)
self.analog_lora_A.weight.data.copy_(top_r_singular_vector_forward.T)
self.analog_lora_C.weight.data.copy_(top_r_singular_vector_backward)


class LoraConv2d(nn.Conv2d):
def __init__(self, rank: int, conv: nn.Conv2d, shared_module: nn.Conv2d = None):
"""Transforms a conv2d layer into a LoraConv2d layer.
Args:
rank (int): The rank of the LoRA decomposition
conv (nn.Conv2d): The conv2d layer to transform
shared_module (nn.Conv2d, optional): A LoRA bottleneck module shared across layers
"""
in_channels = conv.in_channels
out_channels = conv.out_channels
kernel_size = conv.kernel_size
stride = conv.stride
padding = conv.padding

super().__init__(
in_channels, out_channels, kernel_size, stride, padding, bias=False
)

self.rank = min(rank, self.in_channels, self.out_channels)

self.analog_lora_A = nn.Conv2d(
self.in_channels, self.rank, kernel_size, stride, padding, bias=False
)
self.analog_lora_B = shared_module or nn.Conv2d(
self.rank, self.rank, (1, 1), (1, 1), bias=False
)
self.analog_lora_C = nn.Conv2d(
self.rank, self.out_channels, (1, 1), (1, 1), bias=False
)

nn.init.zeros_(self.analog_lora_B.weight)

self._conv = conv

def forward(self, input) -> torch.Tensor:
result = self._conv(input)
result += self.analog_lora_C(self.analog_lora_B(self.analog_lora_A(input)))

return result

def init_weight(self, projection_type, hessian):
"""Initialize the weight of the LoraLinear layer.
Args:
projection_type (str): The type of projection to use
hessian (dict): The forward and backward hessian of the layer
"""
if projection_type == "random":
nn.init.kaiming_uniform_(self.analog_lora_A.weight, a=math.sqrt(5))
nn.init.kaiming_uniform_(self.analog_lora_C.weight, a=math.sqrt(5))
elif projection_type == "pca":
(
top_r_singular_vector_forward,
top_r_singular_value_forward,
) = compute_top_k_singular_vectors(hessian[FORWARD], self.rank)
(
top_r_singular_vector_backward,
top_r_singular_value_backward,
) = compute_top_k_singular_vectors(hessian[BACKWARD], self.rank)
shape_A = self.analog_lora_A.weight.shape
shape_C = self.analog_lora_C.weight.shape
self.analog_lora_A.weight.data.copy_(
top_r_singular_vector_forward.T.view(shape_A)
)
self.analog_lora_C.weight.data.copy_(
top_r_singular_vector_backward.view(shape_C)
)
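
Because analog_lora_B is zero-initialized, wrapping an existing layer leaves the model's outputs unchanged until init_weight or training updates the low-rank path. A quick illustration (example values only; the import path is the one added in analog/lora/lora.py above):

import torch
from analog.lora.modules import LoraConv2d

conv = torch.nn.Conv2d(16, 32, kernel_size=3, padding=1, bias=False)
lora_conv = LoraConv2d(rank=8, conv=conv)

x = torch.randn(2, 16, 8, 8)
# The zero-initialized bottleneck makes the LoRA branch contribute exactly zero.
assert torch.allclose(lora_conv(x), conv(x))
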
15 changes: 15 additions & 0 deletions examples/cifar_influence/compare.py
@@ -0,0 +1,15 @@
from scipy.stats import pearsonr
import torch
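# These score files are produced by the compute_influence*.py scripts in this
# directory (which save under examples/cifar_influence/); if_analog_lora.pt
# presumably comes from a LoRA run with random init.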


analog_kfac = torch.load("if_analog.pt")
analog_lora_random = torch.load("if_analog_lora.pt")
analog_lora_pca = torch.load("if_analog_pca.pt")
print(
"[KFAC (analog) vs LoRA-random (analog)] pearson:",
pearsonr(analog_kfac, analog_lora_random),
)
print(
"[KFAC (analog) vs LoRA-pca (analog)] pearson:",
pearsonr(analog_kfac, analog_lora_pca),
)
81 changes: 81 additions & 0 deletions examples/cifar_influence/compute_influence.py
@@ -0,0 +1,81 @@
import time
import torch


from analog import AnaLog
from analog.utils import DataIDGenerator
from analog.analysis import InfluenceFunction

from train import (
get_cifar10_dataloader,
construct_rn9,
)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DAMPING = 0.01
LISSA_ITER = 100000


def single_checkpoint_influence(data_name="cifar10", eval_idxs=(0,)):
model = construct_rn9().to(DEVICE)

# Get a single checkpoint (first model_id and last epoch).
model.load_state_dict(
torch.load(f"checkpoints/{data_name}_0_epoch_23.pt", map_location="cpu")
)
model.eval()

dataloader_fn = get_cifar10_dataloader
train_loader = dataloader_fn(
batch_size=512, split="train", shuffle=False, subsample=True
)
query_loader = dataloader_fn(
batch_size=1, split="valid", shuffle=False, indices=eval_idxs
)

analog = AnaLog(project="test", config="examples/cifar_influence/config.yaml")

# Gradient & Hessian logging
analog.watch(model)
analog_kwargs = {"log": ["grad"], "hessian": True, "save": True}

id_gen = DataIDGenerator()
for inputs, targets in train_loader:
data_id = id_gen(inputs)
with analog(data_id=data_id, **analog_kwargs):
inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
model.zero_grad()
outs = model(inputs)
loss = torch.nn.functional.cross_entropy(outs, targets, reduction="sum")
loss.backward()
analog.finalize()

# Influence Analysis
log_loader = analog.build_log_dataloader()

analog.add_analysis({"influence": InfluenceFunction})
query_iter = iter(query_loader)
with analog(log=["grad"], test=True) as al:
test_input, test_target = next(query_iter)
test_input, test_target = test_input.to(DEVICE), test_target.to(DEVICE)
model.zero_grad()
test_out = model(test_input)
test_loss = torch.nn.functional.cross_entropy(
test_out, test_target, reduction="sum"
)
test_loss.backward()
test_log = al.get_log()
start = time.time()
if_scores = analog.influence.compute_influence_all(test_log, log_loader)

# Save
if_scores = if_scores.numpy().tolist()
torch.save(if_scores, "examples/cifar_influence/if_analog.pt")
print("Computation time:", time.time() - start)
print(sorted(if_scores)[:10], sorted(if_scores)[-10:])


if __name__ == "__main__":
single_checkpoint_influence(data_name="cifar10")
85 changes: 85 additions & 0 deletions examples/cifar_influence/compute_influences_pca.py
@@ -0,0 +1,85 @@
import time
import torch


from analog import AnaLog
from analog.utils import DataIDGenerator
from analog.analysis import InfluenceFunction

from train import (
get_cifar10_dataloader,
construct_rn9,
)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DAMPING = 0.01
LISSA_ITER = 100000


def single_checkpoint_influence(data_name="cifar10", eval_idxs=(0,)):
model = construct_rn9().to(DEVICE)

# Get a single checkpoint (first model_id and last epoch).
model.load_state_dict(
torch.load(f"checkpoints/{data_name}_0_epoch_23.pt", map_location="cpu")
)
model.eval()

dataloader_fn = get_cifar10_dataloader
train_loader = dataloader_fn(
batch_size=512, split="train", shuffle=False, subsample=True
)
query_loader = dataloader_fn(
batch_size=1, split="valid", shuffle=False, indices=eval_idxs
)

analog = AnaLog(project="test", config="examples/cifar_influence/config.yaml")

# Gradient & Hessian logging
analog.watch(model)
analog_kwargs = {"log": ["grad"], "hessian": True, "save": False}

id_gen = DataIDGenerator()
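# Two passes over the training data: epoch 0 only accumulates the Hessian
# statistics (save=False); add_lora then injects the PCA-initialized low-rank
# projections, and epoch 1 re-logs the projected gradients with save=True.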
for epoch in range(2):
for inputs, targets in train_loader:
data_id = id_gen(inputs)
with analog(data_id=data_id, **analog_kwargs):
inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
model.zero_grad()
outs = model(inputs)
loss = torch.nn.functional.cross_entropy(outs, targets, reduction="sum")
loss.backward()
analog.finalize()
if epoch == 0:
analog_kwargs.update({"save": True})
analog.add_lora(model, parameter_sharing=False)

# Influence Analysis
log_loader = analog.build_log_dataloader()

analog.add_analysis({"influence": InfluenceFunction})
query_iter = iter(query_loader)
with analog(log=["grad"], test=True) as al:
test_input, test_target = next(query_iter)
test_input, test_target = test_input.to(DEVICE), test_target.to(DEVICE)
model.zero_grad()
test_out = model(test_input)
test_loss = torch.nn.functional.cross_entropy(
test_out, test_target, reduction="sum"
)
test_loss.backward()
test_log = al.get_log()
start = time.time()
if_scores = analog.influence.compute_influence_all(test_log, log_loader)

# Save
if_scores = if_scores.numpy().tolist()
torch.save(if_scores, "examples/cifar_influence/if_analog_pca.pt")
print("Computation time:", time.time() - start)
print(sorted(if_scores)[:10], sorted(if_scores)[-10:])


if __name__ == "__main__":
single_checkpoint_influence(data_name="cifar10")
5 changes: 5 additions & 0 deletions examples/cifar_influence/config.yaml
@@ -0,0 +1,5 @@
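# storage.log_dir: where AnaLog writes the logged statistics for this example.
# lora.init: "pca" initializes the LoRA projections from the top singular vectors
# of the logged Hessian factors (see LoraConv2d.init_weight above).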
storage:
type: default
log_dir: "./analog/log/lora"
lora:
init: pca
104 changes: 104 additions & 0 deletions examples/cifar_influence/train.py
@@ -0,0 +1,104 @@
# We use an example from the TRAK repository:
# https://github.com/MadryLab/trak/blob/main/examples/cifar_quickstart.ipynb.


import os
from pathlib import Path
from tqdm.auto import tqdm
import numpy as np
import torch
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, lr_scheduler

from utils import get_cifar10_dataloader, construct_rn9
from utils import set_seed

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def train(
model,
loader,
lr=0.4,
epochs=24,
momentum=0.9,
weight_decay=5e-4,
lr_peak_epoch=5,
save_name="default",
model_id=0,
save=True,
):
opt = SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
iters_per_epoch = len(loader)
# Cyclic LR with single triangle
lr_schedule = np.interp(
np.arange((epochs + 1) * iters_per_epoch),
[0, lr_peak_epoch * iters_per_epoch, epochs * iters_per_epoch],
[0, 1, 0],
)
scheduler = lr_scheduler.LambdaLR(opt, lr_schedule.__getitem__)
scaler = GradScaler()
loss_fn = CrossEntropyLoss()

for epoch in tqdm(range(epochs)):
set_seed(model_id * 10_061 + epoch + 1)
for images, labels in loader:
images = images.to(DEVICE)
labels = labels.to(DEVICE)
opt.zero_grad(set_to_none=True)
with autocast():
out = model(images)
loss = loss_fn(out, labels)

scaler.scale(loss).backward()
scaler.step(opt)
scaler.update()
scheduler.step()

if save:
torch.save(
model.state_dict(),
f"checkpoints/{save_name}_{model_id}_epoch_{epochs-1}.pt",
)

return model


def main(dataset="cifar10"):
os.makedirs("checkpoints", exist_ok=True)

if dataset == "cifar10":
train_loader = get_cifar10_dataloader(
batch_size=512, split="train", shuffle=True, subsample=True
)
valid_loader = get_cifar10_dataloader(
batch_size=512, split="val", shuffle=False
)
else:
raise NotImplementedError

# you can modify the for loop below to train more models
for model_id in tqdm(range(1), desc="Training models.."):
model = construct_rn9().to(memory_format=torch.channels_last).to(DEVICE)

model = train(model, train_loader, save_name=dataset, model_id=model_id)

model.eval()
with torch.no_grad():
total_correct, total_num = 0.0, 0.0
for images, labels in tqdm(valid_loader):
images = images.to(DEVICE)
labels = labels.to(DEVICE)
with autocast():
out = model(images)
total_correct += out.argmax(1).eq(labels).sum().cpu().item()
total_num += images.shape[0]

print(f"Accuracy: {total_correct / total_num * 100:.1f}%")


if __name__ == "__main__":
main()