Skip to content

Commit

Permalink
working BERT GLUE train
Browse files Browse the repository at this point in the history
  • Loading branch information
sangkeun00 committed Nov 25, 2023
1 parent 8e84706 commit 7dbddc0
Show file tree
Hide file tree
Showing 4 changed files with 280 additions and 4 deletions.
125 changes: 125 additions & 0 deletions examples/bert_influence/compute_influence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import time

import torch
import torch.nn.functional as F
from analog import AnaLog
from analog.analysis import InfluenceFunction
from analog.utils import DataIDGenerator
from tqdm import tqdm

from pipeline import construct_model, get_loaders
from utils import set_seed


# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Fix the seed at import time (via the project's set_seed helper) so that
# everything executed below starts from the same RNG state on every run.
set_seed(0)


def _backward_summed_cross_entropy(model, batch, device):
    """Run one forward/backward pass with a summed cross-entropy loss.

    The batch's ``input_ids``, ``token_type_ids`` and ``attention_mask`` are
    moved to ``device`` and passed positionally to ``model``. Gradients are
    zeroed before the forward pass, and the loss is summed (not averaged)
    over all rows, skipping labels equal to ``-100``.
    """
    inputs = (
        batch["input_ids"].to(device),
        batch["token_type_ids"].to(device),
        batch["attention_mask"].to(device),
    )
    model.zero_grad()
    outputs = model(*inputs)

    # Flatten so that every row of logits is paired with exactly one label.
    logits = outputs.view(-1, outputs.shape[-1])
    labels = batch["labels"].view(-1).to(device)
    loss = F.cross_entropy(logits, labels, reduction="sum", ignore_index=-100)
    loss.backward()


def single_checkpoint_influence(
    data_name: str,
    model_name: str,
    ckpt_path: str,
    save_name: str,
    train_batch_size=4,
    test_batch_size=4,
    train_indices=None,
    test_indices=None,
    config_path: str = "/data/tir/projects/tir6/general/hahn2/analog/examples/bert_influence/config.yaml",
):
    """Compute influence scores of the training set on one test batch.

    The function runs three stages with AnaLog on a single checkpoint:

    1. a pass over the evaluation-mode training loader with ``log=[]`` and
       ``save=False`` (labelled "Hessian logging"),
    2. after ``add_lora``, a second pass with ``log=["grad"]`` and
       ``save=True`` (labelled "Compressed gradient logging"),
    3. a gradient log for the *first* test batch only, which is scored
       against the saved training logs with ``compute_influence_all``.

    Args:
        data_name: Dataset name forwarded to ``get_loaders``.
        model_name: First positional argument given to ``construct_model``.
        ckpt_path: Checkpoint path forwarded to ``construct_model``.
        save_name: File the influence scores are written to with
            ``torch.save``.
        train_batch_size: Currently unused by this function.
        test_batch_size: Currently unused by this function.
        train_indices: Currently unused by this function.
        test_indices: Currently unused by this function.
        config_path: AnaLog configuration file. Defaults to the path that
            used to be hard-coded here.
    """
    # model
    # NOTE(review): the construct_model shown in utils.py takes a *data* name
    # as its first argument; confirm pipeline.construct_model really expects
    # a model name here.
    model = construct_model(model_name, ckpt_path)
    model.to(DEVICE)
    model.eval()

    # data
    _, eval_train_loader, test_loader = get_loaders(data_name=data_name)

    # Set-up
    analog = AnaLog(project="test", config=config_path)

    # Hessian logging
    analog.watch(model, type_filter=[torch.nn.Linear], lora=False)
    id_gen = DataIDGenerator()
    for batch in tqdm(eval_train_loader, desc="Hessian logging"):
        data_id = id_gen(batch["input_ids"])
        with analog(data_id=data_id, log=[], save=False):
            _backward_summed_cross_entropy(model, batch, DEVICE)
    analog.finalize()

    # Compressed gradient logging
    analog.add_lora(model, parameter_sharing=False)
    for batch in tqdm(eval_train_loader, desc="Compressed gradient logging"):
        data_id = id_gen(batch["input_ids"])
        with analog(data_id=data_id, log=["grad"], save=True):
            _backward_summed_cross_entropy(model, batch, DEVICE)
    analog.finalize()

    # Compute influence
    log_loader = analog.build_log_dataloader()
    analog.add_analysis({"influence": InfluenceFunction})
    test_iter = iter(test_loader)
    with analog(log=["grad"], test=True) as al:
        # Only the first test batch is scored.
        test_batch = next(test_iter)
        # The helper reshapes the logits using the test output's own class
        # dimension; the previous code reused the stale ``outputs`` variable
        # left over from the training loop.
        _backward_summed_cross_entropy(model, test_batch, DEVICE)

        test_log = al.get_log()

    start = time.time()
    if_scores = analog.influence.compute_influence_all(test_log, log_loader)
    print("Computation time:", time.time() - start)

    # Save (honour the caller-supplied file name instead of a fixed one).
    torch.save(if_scores, save_name)


def main():
    """Compute influence scores for the epoch-3 SST-2 BERT checkpoint.

    All settings are fixed in this function; the checkpoint path points at
    the output location used by the accompanying training script.
    """
    data_name = "sst2"
    model_name = "bert-base-uncased"
    ckpt_path = "/data/tir/projects/tir6/general/hahn2/analog/examples/bert_influence/files/checkpoints/0/sst2_epoch_3.pt"
    save_name = "sst2_score_if.pt"

    single_checkpoint_influence(
        data_name=data_name,
        model_name=model_name,
        ckpt_path=ckpt_path,
        save_name=save_name,
    )


# Without this guard the script defined main() but never ran it.
if __name__ == "__main__":
    main()
Empty file.
116 changes: 116 additions & 0 deletions examples/bert_influence/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import os
import time
import argparse
from typing import Optional, Tuple

from tqdm import tqdm
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import evaluate
from torch.nn import CrossEntropyLoss

from utils import clear_gpu_cache, set_seed, construct_model, get_loaders


# Command-line interface. The text is passed as ``description``: the first
# positional argument of ArgumentParser is ``prog`` (the program name shown
# in usage output), which is not what a human-readable title should set.
parser = argparse.ArgumentParser(
    description="BERT GLUE training for influence analysis"
)
parser.add_argument("--data_name", type=str, default="sst2")
parser.add_argument("--num_train", type=int, default=1)
args = parser.parse_args()

# makedirs creates the intermediate "files/" directory as well.
os.makedirs("files/checkpoints", exist_ok=True)

# Prefer the GPU when CUDA is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The second loader returned by get_loaders is not needed for training.
train_loader, _, valid_loader = get_loaders(data_name=args.data_name)
model = construct_model(data_name=args.data_name).to(device)

def train(
    model: nn.Module,
    loader: torch.utils.data.DataLoader,
    model_id: int = 0,
    lr: float = 2e-5,
    weight_decay: float = 0.0,
    save_name: Optional[str] = None,
    epochs: int = 3,
) -> nn.Module:
    """Fine-tune ``model`` in place with AdamW and a cross-entropy loss.

    Args:
        model: Network called as ``model(input_ids, token_type_ids,
            attention_mask)``; its output is fed to ``CrossEntropyLoss``
            together with ``batch["labels"]``.
        loader: Yields dict batches containing at least ``input_ids``,
            ``token_type_ids``, ``attention_mask`` and ``labels``.
        model_id: Sub-directory name under ``files/checkpoints/`` used for
            this run's checkpoints.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        save_name: Checkpoint file-name prefix. When ``None`` nothing is
            written to disk.
        epochs: Number of passes over ``loader`` (previously fixed at 3).

    Returns:
        The same ``model`` object, left in training mode.
    """
    save = save_name is not None
    ckpt_dir = f"files/checkpoints/{model_id}"
    if save:
        os.makedirs(ckpt_dir, exist_ok=True)
        # Epoch 0 is the state before any optimisation step.
        torch.save(model.state_dict(), f"{ckpt_dir}/{save_name}_epoch_0.pt")

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=lr, weight_decay=weight_decay
    )
    loss_fn = CrossEntropyLoss()

    model.train()
    for epoch in range(epochs):
        for batch in tqdm(loader):
            # Every value in the batch is moved to the module-level device.
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()
            outputs = model(
                batch["input_ids"],
                batch["token_type_ids"],
                batch["attention_mask"],
            )
            loss = loss_fn(outputs, batch["labels"])
            loss.backward()
            optimizer.step()

        if save:
            torch.save(
                model.state_dict(),
                f"{ckpt_dir}/{save_name}_epoch_{epoch + 1}.pt",
            )
    return model


def model_evaluate(
    model: nn.Module, loader: torch.utils.data.DataLoader
) -> Tuple[float, float]:
    """Return ``(mean cross-entropy loss, accuracy)`` of ``model`` on ``loader``.

    The model is switched to eval mode. The loss is summed per batch and
    divided by the number of examples seen; accuracy comes from the
    ``evaluate`` GLUE metric.
    """
    model.eval()
    # Per the original author, the GLUE task name passed here does not really
    # matter; only the "accuracy" entry of the result is read below.
    glue_metric = evaluate.load("glue", "qnli")

    loss_sum = 0.0
    num_examples = 0.0
    for raw_batch in loader:
        on_device = {name: value.to(device) for name, value in raw_batch.items()}
        targets = on_device["labels"]

        with torch.no_grad():
            logits = model(
                on_device["input_ids"],
                on_device["token_type_ids"],
                on_device["attention_mask"],
            )

        batch_loss = F.cross_entropy(logits, targets, reduction="sum")
        loss_sum += batch_loss.cpu().item()
        num_examples += on_device["input_ids"].shape[0]

        glue_metric.add_batch(
            predictions=logits.argmax(dim=-1),
            references=targets,
        )

    accuracy = glue_metric.compute()["accuracy"]
    return loss_sum / num_examples, accuracy


for i in range(args.num_train):
    print(f"Training {i}th model ...")
    start_time = time.time()

    set_seed(i)

    # Build a fresh model for every run, *after* seeding. The previous code
    # reused a single module-level model and deleted it at the end of the
    # first iteration, so any run with --num_train > 1 failed with a
    # NameError, and the per-run seed never affected initialisation.
    # Rebinding the name also releases the instance created at module load.
    model = construct_model(data_name=args.data_name).to(device)

    train(
        model=model,
        loader=train_loader,
        model_id=i,
        save_name=args.data_name,
    )

    _, valid_acc = model_evaluate(model=model, loader=valid_loader)
    print(f"Validation Accuracy: {valid_acc}")

    # Drop this run's model before the next one is allocated.
    del model
    clear_gpu_cache()
    print(f"Took {time.time() - start_time} seconds.")
43 changes: 39 additions & 4 deletions examples/bert_influence/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

import gc
import os
import random
import struct

import numpy as np
import torch
import torch.nn as nn
from datasets import load_dataset
Expand All @@ -24,6 +30,29 @@
}


def set_seed(seed: int) -> None:
    """Seed Python's, NumPy's and PyTorch's (CPU and CUDA) random generators."""
    seed = int(seed)
    # Same seeders, same order as before; just applied in a loop.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def reset_seed() -> None:
    """Re-seed via ``set_seed`` with 4 bytes of OS entropy, so runs differ."""
    entropy = os.urandom(4)
    # "I" decodes the four bytes as one native unsigned int.
    (fresh_seed,) = struct.unpack("I", entropy)
    set_seed(fresh_seed)


def clear_gpu_cache() -> None:
    """Collect garbage, empty PyTorch's CUDA cache and reset peak-memory stats.

    Does nothing when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return
    # Collect first so that unreachable tensors are gone before the cache
    # is emptied.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()


class SequenceClassificationModel(nn.Module):
def __init__(self, data_name: str) -> None:
super().__init__()
Expand Down Expand Up @@ -55,8 +84,14 @@ def forward(
).logits


def construct_model(data_name: str) -> nn.Module:
    """Build a new ``SequenceClassificationModel`` for ``data_name``."""
    model = SequenceClassificationModel(data_name)
    return model
def construct_model(
    data_name: str, ckpt_path: Union[None, str] = None
) -> nn.Module:
    """Build a ``SequenceClassificationModel``, optionally loading weights.

    Args:
        data_name: Passed to the ``SequenceClassificationModel`` constructor.
        ckpt_path: Path to a saved ``state_dict``. When ``None`` the model is
            returned exactly as constructed.

    Returns:
        The constructed model; weights are loaded onto the CPU.
    """
    net = SequenceClassificationModel(data_name)
    if ckpt_path is None:
        return net

    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from trusted sources.
    state_dict = torch.load(ckpt_path, map_location="cpu")
    net.load_state_dict(state_dict)
    print(f"Loaded model from {ckpt_path}.")
    return net


def get_loaders(
Expand Down Expand Up @@ -155,4 +190,4 @@ def preprocess_function(examples):
batch_size=batch_size,
shuffle=split == "train",
collate_fn=default_data_collator,
)
)

0 comments on commit 7dbddc0

Please sign in to comment.