diff --git a/analog/logging.py b/analog/logging.py index 7e6f2d6b..ee58eb72 100644 --- a/analog/logging.py +++ b/analog/logging.py @@ -27,7 +27,7 @@ def compute_per_sample_gradient(fwd, bwd, module): fwd_unfold = fwd_unfold.reshape(bsz, fwd_unfold.shape[1], -1) bwd = bwd.reshape(bsz, -1, fwd_unfold.shape[-1]) grad = torch.einsum("ijk,ilk->ijl", bwd, fwd_unfold) - shape = [bsz] + list(module.weight.shape) + shape = [bsz, module.weight.shape[0], -1] return grad.reshape(shape) else: raise ValueError(f"Unsupported module type: {type(module)}") diff --git a/analog/lora/lora.py b/analog/lora/lora.py index 831619a2..2189c378 100644 --- a/analog/lora/lora.py +++ b/analog/lora/lora.py @@ -2,7 +2,7 @@ import torch.nn as nn -from analog.lora.modules import LoraLinear +from analog.lora.modules import LoraLinear, LoraConv2d def find_parameter_sharing_group( @@ -85,7 +85,7 @@ def add_lora( elif isinstance(module, nn.Conv1d): raise NotImplementedError elif isinstance(module, nn.Conv2d): - raise NotImplementedError + lora_cls = LoraConv2d psg = find_parameter_sharing_group(name, parameter_sharing_groups) if parameter_sharing and psg not in shared_modules: diff --git a/analog/lora/modules.py b/analog/lora/modules.py index 494bcae2..4ad66042 100644 --- a/analog/lora/modules.py +++ b/analog/lora/modules.py @@ -58,3 +58,72 @@ def init_weight(self, init_strategy: str = "random", hessian=None): ) = compute_top_k_singular_vectors(hessian[BACKWARD], self.rank) self.analog_lora_A.weight.data.copy_(top_r_singular_vector_forward.T) self.analog_lora_C.weight.data.copy_(top_r_singular_vector_backward) + + +class LoraConv2d(nn.Conv2d): + def __init__(self, rank: int, conv: nn.Conv2d, shared_module: nn.Conv2d = None): + """Transforms a conv2d layer into a LoraConv2d layer. 
+ + Args: + rank (int): The rank of lora + conv (nn.Conv2d): The conv2d layer to transform + """ + in_channels = conv.in_channels + out_channels = conv.out_channels + kernel_size = conv.kernel_size + stride = conv.stride + padding = conv.padding + + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, bias=False + ) + + self.rank = min(rank, self.in_channels, self.out_channels) + + self.analog_lora_A = nn.Conv2d( + self.in_channels, self.rank, kernel_size, stride, padding, bias=False + ) + self.analog_lora_B = shared_module or nn.Conv2d( + self.rank, self.rank, (1, 1), (1, 1), bias=False + ) + self.analog_lora_C = nn.Conv2d( + self.rank, self.out_channels, (1, 1), (1, 1), bias=False + ) + + nn.init.zeros_(self.analog_lora_B.weight) + + self._conv = conv + + def forward(self, input) -> torch.Tensor: + result = self._conv(input) + result += self.analog_lora_C(self.analog_lora_B(self.analog_lora_A(input))) + + return result + + def init_weight(self, projection_type, hessian): + """Initialize the weight of the LoraConv2d layer. 
+ + Args: + projection_type (str): The type of projection to use + hessian (dict): The forward and backward hessian of the layer + """ + if projection_type == "random": + nn.init.kaiming_uniform_(self.analog_lora_A.weight, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.analog_lora_C.weight, a=math.sqrt(5)) + elif projection_type == "pca": + ( + top_r_singular_vector_forward, + top_r_singular_value_forward, + ) = compute_top_k_singular_vectors(hessian[FORWARD], self.rank) + ( + top_r_singular_vector_backward, + top_r_singular_value_backward, + ) = compute_top_k_singular_vectors(hessian[BACKWARD], self.rank) + shape_A = self.analog_lora_A.weight.shape + shape_C = self.analog_lora_C.weight.shape + self.analog_lora_A.weight.data.copy_( + top_r_singular_vector_forward.T.view(shape_A) + ) + self.analog_lora_C.weight.data.copy_( + top_r_singular_vector_backward.view(shape_C) + ) diff --git a/examples/cifar_influence/compare.py b/examples/cifar_influence/compare.py new file mode 100644 index 00000000..56664b19 --- /dev/null +++ b/examples/cifar_influence/compare.py @@ -0,0 +1,15 @@ +from scipy.stats import pearsonr +import torch + + +analog_kfac = torch.load("if_analog.pt") +analog_lora_random = torch.load("if_analog_lora.pt") +analog_lora_pca = torch.load("if_analog_pca_conv.pt") +print( + "[KFAC (analog) vs LoRA-random (analog)] pearson:", + pearsonr(analog_kfac, analog_lora_random), +) +print( + "[KFAC (analog) vs LoRA-pca (analog)] pearson:", + pearsonr(analog_kfac, analog_lora_pca), +) diff --git a/examples/cifar_influence/compute_influence.py b/examples/cifar_influence/compute_influence.py new file mode 100644 index 00000000..b7785318 --- /dev/null +++ b/examples/cifar_influence/compute_influence.py @@ -0,0 +1,81 @@ +import time +import torch + + +from analog import AnaLog +from analog.utils import DataIDGenerator +from analog.analysis import InfluenceFunction + +from train import ( + get_cifar10_dataloader, + construct_rn9, +) + +DEVICE = torch.device("cuda" if 
torch.cuda.is_available() else "cpu") +DAMPING = 0.01 +LISSA_ITER = 100000 + + +def single_checkpoint_influence(data_name="cifar10", eval_idxs=(0,)): + model = construct_rn9().to(DEVICE) + + # Get a single checkpoint (first model_id and last epoch). + model.load_state_dict( + torch.load(f"checkpoints/{data_name}_0_epoch_23.pt", map_location="cpu") + ) + model.eval() + + dataloader_fn = get_cifar10_dataloader + train_loader = dataloader_fn( + batch_size=512, split="train", shuffle=False, subsample=True + ) + query_loader = dataloader_fn( + batch_size=1, split="valid", shuffle=False, indices=eval_idxs + ) + + analog = AnaLog(project="test", config="examples/cifar/config.yaml") + + # Gradient & Hessian logging + analog.watch(model) + analog_kwargs = {"log": ["grad"], "hessian": True, "save": True} + + id_gen = DataIDGenerator() + for inputs, targets in train_loader: + data_id = id_gen(inputs) + with analog(data_id=data_id, **analog_kwargs): + inputs, targets = inputs.to(DEVICE), targets.to(DEVICE) + model.zero_grad() + outs = model(inputs) + loss = torch.nn.functional.cross_entropy(outs, targets, reduction="sum") + loss.backward() + analog.finalize() + + # Influence Analysis + log_loader = analog.build_log_dataloader() + + from analog.analysis import InfluenceFunction + + analog.add_analysis({"influence": InfluenceFunction}) + query_iter = iter(query_loader) + with analog(log=["grad"], test=True) as al: + test_input, test_target = next(query_iter) + test_input, test_target = test_input.to(DEVICE), test_target.to(DEVICE) + model.zero_grad() + test_out = model(test_input) + test_loss = torch.nn.functional.cross_entropy( + test_out, test_target, reduction="sum" + ) + test_loss.backward() + test_log = al.get_log() + start = time.time() + if_scores = analog.influence.compute_influence_all(test_log, log_loader) + + # Save + if_scores = if_scores.numpy().tolist() + torch.save(if_scores, "examples/cifar_influence/if_analog.pt") + print("Computation time:", time.time() - 
start) + print(sorted(if_scores)[:10], sorted(if_scores)[-10:]) + + +if __name__ == "__main__": + single_checkpoint_influence(data_name="cifar10") diff --git a/examples/cifar_influence/compute_influences_pca.py b/examples/cifar_influence/compute_influences_pca.py new file mode 100644 index 00000000..ebe92e10 --- /dev/null +++ b/examples/cifar_influence/compute_influences_pca.py @@ -0,0 +1,85 @@ +import time +import torch + + +from analog import AnaLog +from analog.utils import DataIDGenerator +from analog.analysis import InfluenceFunction + +from train import ( + get_cifar10_dataloader, + construct_rn9, +) + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +DAMPING = 0.01 +LISSA_ITER = 100000 + + +def single_checkpoint_influence(data_name="cifar10", eval_idxs=(0,)): + model = construct_rn9().to(DEVICE) + + # Get a single checkpoint (first model_id and last epoch). + model.load_state_dict( + torch.load(f"checkpoints/{data_name}_0_epoch_23.pt", map_location="cpu") + ) + model.eval() + + dataloader_fn = get_cifar10_dataloader + train_loader = dataloader_fn( + batch_size=512, split="train", shuffle=False, subsample=True + ) + query_loader = dataloader_fn( + batch_size=1, split="valid", shuffle=False, indices=eval_idxs + ) + + analog = AnaLog(project="test", config="examples/cifar_influence/config.yaml") + + # Gradient & Hessian logging + analog.watch(model) + analog_kwargs = {"log": ["grad"], "hessian": True, "save": False} + + id_gen = DataIDGenerator() + for epoch in range(2): + for inputs, targets in train_loader: + data_id = id_gen(inputs) + with analog(data_id=data_id, **analog_kwargs): + inputs, targets = inputs.to(DEVICE), targets.to(DEVICE) + model.zero_grad() + outs = model(inputs) + loss = torch.nn.functional.cross_entropy(outs, targets, reduction="sum") + loss.backward() + analog.finalize() + if epoch == 0: + analog_kwargs.update({"save": True}) + analog.add_lora(model, parameter_sharing=False) + + # Influence Analysis + log_loader = 
analog.build_log_dataloader() + + from analog.analysis import InfluenceFunction + + analog.add_analysis({"influence": InfluenceFunction}) + query_iter = iter(query_loader) + with analog(log=["grad"], test=True) as al: + test_input, test_target = next(query_iter) + test_input, test_target = test_input.to(DEVICE), test_target.to(DEVICE) + model.zero_grad() + test_out = model(test_input) + test_loss = torch.nn.functional.cross_entropy( + test_out, test_target, reduction="sum" + ) + test_loss.backward() + test_log = al.get_log() + start = time.time() + if_scores = analog.influence.compute_influence_all(test_log, log_loader) + + # Save + if_scores = if_scores.numpy().tolist() + torch.save(if_scores, "examples/cifar_influence/if_analog_pca.pt") + print("Computation time:", time.time() - start) + print(sorted(if_scores)[:10], sorted(if_scores)[-10:]) + + +if __name__ == "__main__": + single_checkpoint_influence(data_name="cifar10") diff --git a/examples/cifar_influence/config.yaml b/examples/cifar_influence/config.yaml new file mode 100644 index 00000000..663be1c4 --- /dev/null +++ b/examples/cifar_influence/config.yaml @@ -0,0 +1,5 @@ +storage: + type: default + log_dir: "./analog/log/lora" +lora: + init: pca diff --git a/examples/cifar_influence/train.py b/examples/cifar_influence/train.py new file mode 100644 index 00000000..ab6a6064 --- /dev/null +++ b/examples/cifar_influence/train.py @@ -0,0 +1,104 @@ +# We use an example from the TRAK repository: +# https://github.com/MadryLab/trak/blob/main/examples/cifar_quickstart.ipynb. 
+ + +import os +from pathlib import Path +from tqdm.auto import tqdm +import numpy as np +import torch +from torch.cuda.amp import GradScaler, autocast +from torch.nn import CrossEntropyLoss +from torch.optim import SGD, lr_scheduler + +from utils import get_cifar10_dataloader, construct_rn9 +from utils import set_seed + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def train( + model, + loader, + lr=0.4, + epochs=24, + momentum=0.9, + weight_decay=5e-4, + lr_peak_epoch=5, + save_name="default", + model_id=0, + save=True, +): + opt = SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay) + iters_per_epoch = len(loader) + # Cyclic LR with single triangle + lr_schedule = np.interp( + np.arange((epochs + 1) * iters_per_epoch), + [0, lr_peak_epoch * iters_per_epoch, epochs * iters_per_epoch], + [0, 1, 0], + ) + scheduler = lr_scheduler.LambdaLR(opt, lr_schedule.__getitem__) + scaler = GradScaler() + loss_fn = CrossEntropyLoss() + + for epoch in tqdm(range(epochs)): + set_seed(model_id * 10_061 + epoch + 1) + for images, labels in loader: + images = images.to(DEVICE) + labels = labels.to(DEVICE) + opt.zero_grad(set_to_none=True) + with autocast(): + out = model(images) + loss = loss_fn(out, labels) + + scaler.scale(loss).backward() + scaler.step(opt) + scaler.update() + scheduler.step() + + if save: + torch.save( + model.state_dict(), + f"checkpoints/{save_name}_{model_id}_epoch_{epochs-1}.pt", + ) + + return model + + +def main(dataset="cifar10"): + os.makedirs("checkpoints", exist_ok=True) + + if dataset == "cifar10": + train_loader = get_cifar10_dataloader( + batch_size=512, split="train", shuffle=True, subsample=True + ) + valid_loader = get_cifar10_dataloader( + batch_size=512, split="val", shuffle=False + ) + else: + raise NotImplementedError + + # you can modify the for loop below to train more models + for model_id in tqdm(range(1), desc="Training models.."): + model = 
construct_rn9().to(memory_format=torch.channels_last).to(DEVICE) + + model = train(model, train_loader, save_name=dataset, model_id=model_id) + + model = model.eval() + + model.eval() + with torch.no_grad(): + total_correct, total_num = 0.0, 0.0 + for images, labels in tqdm(valid_loader): + images = images.to(DEVICE) + labels = labels.to(DEVICE) + with autocast(): + out = model(images) + total_correct += out.argmax(1).eq(labels).sum().cpu().item() + total_num += images.shape[0] + + print(f"Accuracy: {total_correct / total_num * 100:.1f}%") + + +if __name__ == "__main__": + main() diff --git a/examples/cifar_influence/utils.py b/examples/cifar_influence/utils.py new file mode 100644 index 00000000..6f225f4f --- /dev/null +++ b/examples/cifar_influence/utils.py @@ -0,0 +1,130 @@ +# We use an example from the TRAK repository: +# https://github.com/MadryLab/trak/blob/main/examples/cifar_quickstart.ipynb. + + +import torch +import torchvision +import numpy as np +import random + + +def set_seed(seed): + seed = int(seed) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + +class Mul(torch.nn.Module): + def __init__(self, weight): + super(Mul, self).__init__() + self.weight = weight + + def forward(self, x): + return x * self.weight + + +class Flatten(torch.nn.Module): + def forward(self, x): + return x.view(x.size(0), -1) + + +class Residual(torch.nn.Module): + def __init__(self, module): + super(Residual, self).__init__() + self.module = module + + def forward(self, x): + return x + self.module(x) + + +def construct_rn9(num_classes=10, seed=0): + set_seed(seed) + + def conv_bn( + channels_in, channels_out, kernel_size=3, stride=1, padding=1, groups=1 + ): + return torch.nn.Sequential( + torch.nn.Conv2d( + channels_in, + channels_out, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + bias=False, + ), + torch.nn.BatchNorm2d(channels_out), + torch.nn.ReLU(inplace=True), + ) + + model = 
torch.nn.Sequential( + conv_bn(3, 64, kernel_size=3, stride=1, padding=1), + conv_bn(64, 128, kernel_size=5, stride=2, padding=2), + Residual(torch.nn.Sequential(conv_bn(128, 128), conv_bn(128, 128))), + conv_bn(128, 256, kernel_size=3, stride=1, padding=1), + torch.nn.MaxPool2d(2), + Residual(torch.nn.Sequential(conv_bn(256, 256), conv_bn(256, 256))), + conv_bn(256, 128, kernel_size=3, stride=1, padding=0), + torch.nn.AdaptiveMaxPool2d((1, 1)), + Flatten(), + torch.nn.Linear(128, num_classes, bias=False), + Mul(0.2), + ) + return model + + +def get_cifar10_dataloader( + batch_size=256, + num_workers=8, + split="train", + shuffle=False, + augment=True, + drop_last=False, + subsample=False, + indices=None, +): + if augment: + transforms = torchvision.transforms.Compose( + [ + torchvision.transforms.RandomHorizontalFlip(), + torchvision.transforms.RandomAffine(0), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize( + (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.201) + ), + ] + ) + else: + transforms = torchvision.transforms.Compose( + [ + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize( + (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.201) + ), + ] + ) + + is_train = split == "train" + dataset = torchvision.datasets.CIFAR10( + root="/tmp/cifar/", download=True, train=is_train, transform=transforms + ) + + if subsample and split == "train" and indices is None: + dataset = torch.utils.data.Subset(dataset, np.arange(6_000)) + + if indices is not None: + if subsample and split == "train": + print("Overriding `subsample` argument as `indices` was provided.") + dataset = torch.utils.data.Subset(dataset, indices) + + loader = torch.utils.data.DataLoader( + dataset=dataset, + shuffle=shuffle, + batch_size=batch_size, + num_workers=num_workers, + drop_last=drop_last, + ) + + return loader diff --git a/tests/logger/test_conv2d.py b/tests/logger/test_conv2d.py index 20e2d24c..3a58cd15 100644 --- a/tests/logger/test_conv2d.py +++ 
b/tests/logger/test_conv2d.py @@ -85,7 +85,7 @@ def compute_loss_func(_params, _buffers, _batch): for module_name in analog_grads_dict: analog_grad = analog_grads_dict[module_name] - func_grad = grads_dict[module_name + ".weight"] + func_grad = grads_dict[module_name + ".weight"].view(analog_grad.shape) self.assertTrue(torch.allclose(analog_grad, func_grad, atol=1e-6))