diff --git a/analog/analog.py b/analog/analog.py
index 05b295e6..a4ac511f 100644
--- a/analog/analog.py
+++ b/analog/analog.py
@@ -6,7 +6,7 @@
 from analog.config import Config
 from analog.constants import FORWARD, BACKWARD, GRAD, LOG_TYPES
 from analog.logging import LoggingHandler
-from analog.storage import DefaultStorageHandler, MongoDBStorageHandler
+from analog.storage import DefaultStorageHandler
 from analog.hessian import RawHessianHandler, KFACHessianHandler
 from analog.analysis import AnalysisBase
 from analog.lora import LoRAHandler
diff --git a/examples/bert_influence/compute_influence.py b/examples/bert_influence/compute_influence.py
index 735f2f89..a5ac7841 100644
--- a/examples/bert_influence/compute_influence.py
+++ b/examples/bert_influence/compute_influence.py
@@ -1,4 +1,5 @@
 import time
+import argparse
 
 import torch
 import torch.nn.functional as F
@@ -10,46 +11,43 @@
 from pipeline import construct_model, get_loaders
 from utils import set_seed
 
+parser = argparse.ArgumentParser("GLUE Influence Analysis")
+parser.add_argument("--data_name", type=str, default="sst2")
+parser.add_argument("--eval-idxs", type=int, nargs="+", default=[0])
+parser.add_argument("--damping", type=float, default=1e-5)
+args = parser.parse_args()
 
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 set_seed(0)
 
-def single_checkpoint_influence(
-    data_name: str,
-    model_name: str,
-    ckpt_path: str,
-    save_name: str,
-    train_batch_size=4,
-    test_batch_size=4,
-    train_indices=None,
-    test_indices=None,
-):
-    # model
-    model = construct_model(model_name, ckpt_path)
-    model.to(DEVICE)
-    model.eval()
+# model
+model = construct_model(args.data_name)
+model.load_state_dict(
+    torch.load(f"files/checkpoints/0/{args.data_name}_epoch_3.pt", map_location="cpu")
+)
+model.to(DEVICE)
+model.eval()
 
-    # data
-    _, eval_train_loader, test_loader = get_loaders(data_name=data_name)
+# data
+_, eval_train_loader, test_loader = get_loaders(data_name=args.data_name)
 
-    # Set-up
-    analog = AnaLog(
-        project="test",
-        config="/data/tir/projects/tir6/general/hahn2/analog/examples/bert_influence/config.yaml",
-    )
+# Set-up
+analog = AnaLog(project="test")
 
-    # Hessian logging
-    analog.watch(model, type_filter=[torch.nn.Linear], lora=False)
-    id_gen = DataIDGenerator()
+# Hessian logging
+analog.watch(model)
+analog_kwargs = {"log": [], "hessian": True, "save": False}
+id_gen = DataIDGenerator()
+for epoch in range(2):
     for batch in tqdm(eval_train_loader, desc="Hessian logging"):
         data_id = id_gen(batch["input_ids"])
-        with analog(data_id=data_id, log=[], save=False):
-            inputs = (
-                batch["input_ids"].to(DEVICE),
-                batch["token_type_ids"].to(DEVICE),
-                batch["attention_mask"].to(DEVICE),
-            )
+        inputs = (
+            batch["input_ids"].to(DEVICE),
+            batch["token_type_ids"].to(DEVICE),
+            batch["attention_mask"].to(DEVICE),
+        )
+        with analog(data_id=data_id, mask=inputs[-1], **analog_kwargs):
             model.zero_grad()
             outputs = model(*inputs)
@@ -58,75 +56,40 @@ def single_checkpoint_influence(
             loss = F.cross_entropy(logits, labels, reduction="sum", ignore_index=-100)
             loss.backward()
     analog.finalize()
+    if epoch == 0:
+        analog_kwargs.update({"save": True, "log": ["grad"]})
+        analog.add_lora(model, parameter_sharing=False)
+
+# Compute influence
+log_loader = analog.build_log_dataloader()
+analog.add_analysis({"influence": InfluenceFunction})
+test_iter = iter(test_loader)
+with analog(log=["grad"], test=True) as al:
+    test_batch = next(test_iter)
+    test_inputs = (
+        test_batch["input_ids"].to(DEVICE),
+        test_batch["token_type_ids"].to(DEVICE),
+        test_batch["attention_mask"].to(DEVICE),
+    )
+    test_target = test_batch["labels"].to(DEVICE)
+    model.zero_grad()
+    test_outputs = model(*test_inputs)
+
+    test_logits = test_outputs.view(-1, test_outputs.shape[-1])
+    test_labels = test_batch["labels"].view(-1).to(DEVICE)
+    test_loss = F.cross_entropy(
+        test_logits,
+        test_labels,
+        reduction="sum",
+        ignore_index=-100,
+    )
+    test_loss.backward()
-    # Compressed gradient logging
-    analog.add_lora(model, parameter_sharing=False)
-    for batch in tqdm(eval_train_loader, desc="Compressed gradient logging"):
-        data_id = id_gen(batch["input_ids"])
-        with analog(data_id=data_id, log=["grad"], save=True):
-            inputs = (
-                batch["input_ids"].to(DEVICE),
-                batch["token_type_ids"].to(DEVICE),
-                batch["attention_mask"].to(DEVICE),
-            )
-            model.zero_grad()
-            outputs = model(*inputs)
-
-            logits = outputs.view(-1, outputs.shape[-1])
-            labels = batch["labels"].view(-1).to(DEVICE)
-            loss = F.cross_entropy(
-                logits,
-                labels,
-                reduction="sum",
-                ignore_index=-100,
-            )
-            loss.backward()
-        analog.finalize()
-
-    # Compute influence
-    log_loader = analog.build_log_dataloader()
-    analog.add_analysis({"influence": InfluenceFunction})
-    test_iter = iter(test_loader)
-    with analog(log=["grad"], test=True) as al:
-        test_batch = next(test_iter)
-        test_inputs = (
-            test_batch["input_ids"].to(DEVICE),
-            test_batch["token_type_ids"].to(DEVICE),
-            test_batch["attention_mask"].to(DEVICE),
-        )
-        test_target = test_batch["labels"].to(DEVICE)
-        model.zero_grad()
-        test_outputs = model(*test_inputs)
-
-        test_logits = test_outputs.view(-1, outputs.shape[-1])
-        test_labels = test_batch["labels"].view(-1).to(DEVICE)
-        test_loss = F.cross_entropy(
-            test_logits,
-            test_labels,
-            reduction="sum",
-            ignore_index=-100,
-        )
-        test_loss.backward()
-
-        test_log = al.get_log()
-
-    start = time.time()
-    if_scores = analog.influence.compute_influence_all(test_log, log_loader)
-    print("Computation time:", time.time() - start)
-
-    # Save
-    torch.save(if_scores, "if_analog.pt")
-
+    test_log = al.get_log()
-
-def main():
-    data_name = "sst2"
-    model_name = "bert-base-uncased"
-    ckpt_path = "/data/tir/projects/tir6/general/hahn2/analog/examples/bert_influence/files/checkpoints/0/sst2_epoch_3.pt"
-    save_name = "sst2_score_if.pt"
+
+start = time.time()
+if_scores = analog.influence.compute_influence_all(test_log, log_loader)
+print("Computation time:", time.time() - start)
-
-    single_checkpoint_influence(
-        data_name=data_name,
-        model_name=model_name,
-        ckpt_path=ckpt_path,
-        save_name=save_name,
-    )
+# Save
+torch.save(if_scores, "if_analog.pt")
diff --git a/examples/mnist_influence/compute_influences_pca.py b/examples/mnist_influence/compute_influences_pca.py
index 70598afb..a555a2a7 100644
--- a/examples/mnist_influence/compute_influences_pca.py
+++ b/examples/mnist_influence/compute_influences_pca.py
@@ -13,14 +13,14 @@
 )
 
 parser = argparse.ArgumentParser("MNIST Influence Analysis")
-parser.add_argument("--data", type=str, default="mnist", help="mnist or fmnist")
+parser.add_argument("--data_name", type=str, default="mnist", help="mnist or fmnist")
 parser.add_argument("--eval-idxs", type=int, nargs="+", default=[0])
 parser.add_argument("--damping", type=float, default=1e-5)
 args = parser.parse_args()
 
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-model = construct_mlp().to(DEVICE)
+model = construct_mlp(data_name=args.data_name).to(DEVICE)
 model.load_state_dict(
-    torch.load(f"checkpoints/{args.data}_0_epoch_9.pt", map_location="cpu")
+    torch.load(f"checkpoints/{args.data_name}_0_epoch_9.pt", map_location="cpu")
 )