Skip to content

Commit

Permalink
remove unused storage handler
Browse files Browse the repository at this point in the history
  • Loading branch information
sangkeun00 committed Nov 25, 2023
1 parent 4da6210 commit bfddd80
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 102 deletions.
2 changes: 1 addition & 1 deletion analog/analog.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from analog.config import Config
from analog.constants import FORWARD, BACKWARD, GRAD, LOG_TYPES
from analog.logging import LoggingHandler
from analog.storage import DefaultStorageHandler, MongoDBStorageHandler
from analog.storage import DefaultStorageHandler
from analog.hessian import RawHessianHandler, KFACHessianHandler
from analog.analysis import AnalysisBase
from analog.lora import LoRAHandler
Expand Down
161 changes: 62 additions & 99 deletions examples/bert_influence/compute_influence.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import time
import argparse

import torch
import torch.nn.functional as F
Expand All @@ -10,46 +11,43 @@
from pipeline import construct_model, get_loaders
from utils import set_seed

parser = argparse.ArgumentParser("GLUE Influence Analysis")
parser.add_argument("--data_name", type=str, default="sst2")
parser.add_argument("--eval-idxs", type=int, nargs="+", default=[0])
parser.add_argument("--damping", type=float, default=1e-5)
args = parser.parse_args()

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
set_seed(0)


def single_checkpoint_influence(
data_name: str,
model_name: str,
ckpt_path: str,
save_name: str,
train_batch_size=4,
test_batch_size=4,
train_indices=None,
test_indices=None,
):
# model
model = construct_model(model_name, ckpt_path)
model.to(DEVICE)
model.eval()
# model
model = construct_model(args.data_name, ckpt_path)
model.load_state_dict(
torch.load(f"files/checkpoints/0/{args.data_name}_epoch_3.pt", map_location="cpu")
)
model.to(DEVICE)
model.eval()

# data
_, eval_train_loader, test_loader = get_loaders(data_name=data_name)
# data
_, eval_train_loader, test_loader = get_loaders(data_name=args.data_name)

# Set-up
analog = AnaLog(
project="test",
config="/data/tir/projects/tir6/general/hahn2/analog/examples/bert_influence/config.yaml",
)
# Set-up
analog = AnaLog(project="test")

# Hessian logging
analog.watch(model, type_filter=[torch.nn.Linear], lora=False)
id_gen = DataIDGenerator()
# Hessian logging
analog.watch(model)
analog_kwargs = {"log": [], "hessian": True, "save": False}
id_gen = DataIDGenerator()
for epoch in range(2):
for batch in tqdm(eval_train_loader, desc="Hessian logging"):
data_id = id_gen(batch["input_ids"])
with analog(data_id=data_id, log=[], save=False):
inputs = (
batch["input_ids"].to(DEVICE),
batch["token_type_ids"].to(DEVICE),
batch["attention_mask"].to(DEVICE),
)
inputs = (
batch["input_ids"].to(DEVICE),
batch["token_type_ids"].to(DEVICE),
batch["attention_mask"].to(DEVICE),
)
with analog(data_id=data_id, mask=inputs[-1], **analog_kwargs):
model.zero_grad()
outputs = model(*inputs)

Expand All @@ -58,75 +56,40 @@ def single_checkpoint_influence(
loss = F.cross_entropy(logits, labels, reduction="sum", ignore_index=-100)
loss.backward()
analog.finalize()
if epoch == 0:
analog_kwargs.update({"save": True, "log": ["grad"]})
analog.add_lora(model, parameter_sharing=False)

# Compute influence
log_loader = analog.build_log_dataloader()
analog.add_analysis({"influence": InfluenceFunction})
test_iter = iter(test_loader)
with analog(log=["grad"], test=True) as al:
test_batch = next(test_iter)
test_inputs = (
test_batch["input_ids"].to(DEVICE),
test_batch["token_type_ids"].to(DEVICE),
test_batch["attention_mask"].to(DEVICE),
)
test_target = test_batch["labels"].to(DEVICE)
model.zero_grad()
test_outputs = model(*test_inputs)

test_logits = test_outputs.view(-1, outputs.shape[-1])
test_labels = test_batch["labels"].view(-1).to(DEVICE)
test_loss = F.cross_entropy(
test_logits,
test_labels,
reduction="sum",
ignore_index=-100,
)
test_loss.backward()

# Compressed gradient logging
analog.add_lora(model, parameter_sharing=False)
for batch in tqdm(eval_train_loader, desc="Compressed gradient logging"):
data_id = id_gen(batch["input_ids"])
with analog(data_id=data_id, log=["grad"], save=True):
inputs = (
batch["input_ids"].to(DEVICE),
batch["token_type_ids"].to(DEVICE),
batch["attention_mask"].to(DEVICE),
)
model.zero_grad()
outputs = model(*inputs)

logits = outputs.view(-1, outputs.shape[-1])
labels = batch["labels"].view(-1).to(DEVICE)
loss = F.cross_entropy(
logits,
labels,
reduction="sum",
ignore_index=-100,
)
loss.backward()
analog.finalize()

# Compute influence
log_loader = analog.build_log_dataloader()
analog.add_analysis({"influence": InfluenceFunction})
test_iter = iter(test_loader)
with analog(log=["grad"], test=True) as al:
test_batch = next(test_iter)
test_inputs = (
test_batch["input_ids"].to(DEVICE),
test_batch["token_type_ids"].to(DEVICE),
test_batch["attention_mask"].to(DEVICE),
)
test_target = test_batch["labels"].to(DEVICE)
model.zero_grad()
test_outputs = model(*test_inputs)

test_logits = test_outputs.view(-1, outputs.shape[-1])
test_labels = test_batch["labels"].view(-1).to(DEVICE)
test_loss = F.cross_entropy(
test_logits,
test_labels,
reduction="sum",
ignore_index=-100,
)
test_loss.backward()

test_log = al.get_log()

start = time.time()
if_scores = analog.influence.compute_influence_all(test_log, log_loader)
print("Computation time:", time.time() - start)

# Save
torch.save(if_scores, "if_analog.pt")

test_log = al.get_log()

def main():
data_name = "sst2"
model_name = "bert-base-uncased"
ckpt_path = "/data/tir/projects/tir6/general/hahn2/analog/examples/bert_influence/files/checkpoints/0/sst2_epoch_3.pt"
save_name = "sst2_score_if.pt"
start = time.time()
if_scores = analog.influence.compute_influence_all(test_log, log_loader)
print("Computation time:", time.time() - start)

single_checkpoint_influence(
data_name=data_name,
model_name=model_name,
ckpt_path=ckpt_path,
save_name=save_name,
)
# Save
torch.save(if_scores, "if_analog.pt")
4 changes: 2 additions & 2 deletions examples/mnist_influence/compute_influences_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
)

parser = argparse.ArgumentParser("MNIST Influence Analysis")
parser.add_argument("--data", type=str, default="mnist", help="mnist or fmnist")
parser.add_argument("--data_name", type=str, default="mnist", help="mnist or fmnist")
parser.add_argument("--eval-idxs", type=int, nargs="+", default=[0])
parser.add_argument("--damping", type=float, default=1e-5)
args = parser.parse_args()

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = construct_mlp().to(DEVICE)
model = construct_mlp(data_name=args.data_name).to(DEVICE)
model.load_state_dict(
    torch.load(f"checkpoints/{args.data_name}_0_epoch_9.pt", map_location="cpu")
)
Expand Down

0 comments on commit bfddd80

Please sign in to comment.