-
Notifications
You must be signed in to change notification settings - Fork 198
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add sensitivity analysis tool for layer-wise FIT and Hessian trace #592
Merged
Merged
Changes from 1 commit
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
105 changes: 105 additions & 0 deletions
105
torchao/quantization/prototype/mixed_precision/scripts/FIT.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import torch | ||
import numpy as np | ||
import os | ||
from tqdm import tqdm | ||
import transformers | ||
from datasets import load_dataset | ||
import random | ||
from torch.nn.attention import SDPBackend, sdpa_kernel | ||
|
||
def get_wikitext2(nsamples, seed, seqlen, tokenizer): | ||
traindata = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="train") | ||
testdata = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test") | ||
|
||
trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt") | ||
testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt") | ||
|
||
random.seed(seed) | ||
trainloader = [] | ||
for _ in range(nsamples): | ||
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) | ||
j = i + seqlen | ||
inp = trainenc.input_ids[:, i:j] | ||
tar = inp.clone() | ||
tar[:, :-1] = -100 | ||
trainloader.append((inp, tar)) | ||
return trainloader, testenc | ||
|
||
def cal_FIT(device, data, nsamples, model, maxIter, max_seqlen, criterion, num_layers): | ||
|
||
# store the history of trace for each layer | ||
estimated_history=[] | ||
|
||
# store the history of mean trace for each layer | ||
estimated_mean = [[] for _ in range(num_layers)] | ||
trace = [0.] * num_layers | ||
|
||
|
||
for iteration in range(maxIter): | ||
print("iteration: ",iteration) | ||
trace_tmp = [0.] * num_layers | ||
|
||
for i in tqdm(range(nsamples)): | ||
inputs, targets = data[i] | ||
inputs = inputs.to(device) | ||
targets = targets.to(device) | ||
model.zero_grad() | ||
outputs = model(inputs) | ||
logits = outputs.logits | ||
loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1)) | ||
|
||
grads = torch.autograd.grad(loss, model.parameters()) | ||
|
||
# Trace(Fisher Information Matrix) is calculated by the sum of the square of the gradient | ||
for layerid in range(num_layers): | ||
for (name, _), grad in zip(model.named_parameters(), grads): | ||
if "."+str(layerid)+"." in name and ("self_attn" in name or "mlp" in name): | ||
trace_tmp[layerid] += torch.sum(grad * grad).item() | ||
|
||
# clean cache | ||
model.zero_grad() | ||
del grads | ||
torch.cuda.empty_cache() | ||
|
||
# calculate the mean of the trace on the calibration dataset | ||
for t in range(num_layers): | ||
trace[t] = trace_tmp[t] / float(nsamples) | ||
estimated_mean[t].append(trace[t]) | ||
|
||
print("trace:",trace) | ||
estimated_history.append(trace) | ||
|
||
F_average = np.array([np.mean(i) for i in estimated_mean]) | ||
return F_average, estimated_mean, estimated_history | ||
|
||
def main(max_seqlen, checkpoint, nsamples, maxIter, num_layers): | ||
device = 'cuda' if torch.cuda.is_available() else 'cpu' | ||
|
||
# have been tested models Llama-3-8B, Llama-2-7B, Mistral-7B, and stories110M | ||
model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16) | ||
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint) | ||
model = model.to(device) | ||
model.eval() | ||
|
||
criterion = torch.nn.CrossEntropyLoss() | ||
|
||
# load calibration dataset | ||
seed = 0 | ||
trainloader, testloader = get_wikitext2(nsamples, seed, max_seqlen, tokenizer) | ||
|
||
F_average, estimated_mean, estimated_history = cal_FIT(device=device, data=trainloader, nsamples=nsamples, model=model, maxIter=maxIter, max_seqlen=max_seqlen, criterion=criterion, num_layers=num_layers) | ||
print("Iteration Done") | ||
print("avg_trace:", F_average) | ||
print("estimated_mean:", estimated_mean) | ||
print("estimated_history:", estimated_history) | ||
|
||
if __name__ == '__main__': | ||
import argparse | ||
parser = argparse.ArgumentParser(description='Your CLI description.') | ||
parser.add_argument('--checkpoint', type=str, default="/home/hanxianhuang/ao/torchao/quantization/prototype/mixed_precision/checkpoints/meta-llama/Meta-Llama-3-8B", help='Path to load model') | ||
parser.add_argument('--max_seqlen', type=int, default=2048, help='Max sequence length') | ||
parser.add_argument('--maxIter', type=int, default=100, help='The number of iterations to calculate FIT') | ||
parser.add_argument('--num_layers', type=int, default=32, help='The number of layers to calculate FIT.') | ||
parser.add_argument('--nsamples', type=int, default=128, help='The number of samples in calibration dataset') | ||
args = parser.parse_args() | ||
main(args.max_seqlen, args.checkpoint, args.nsamples, args.maxIter, args.num_layers) |
151 changes: 151 additions & 0 deletions
151
torchao/quantization/prototype/mixed_precision/scripts/Hessian_grad.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
import torch | ||
import numpy as np | ||
import os | ||
from tqdm import tqdm | ||
import transformers | ||
from datasets import load_dataset | ||
import random | ||
from torch.nn.attention import SDPBackend, sdpa_kernel | ||
from torch.autograd.functional import hvp | ||
|
||
def group_product(xs, ys): | ||
return [torch.sum(x * y) for (x, y) in zip(xs, ys)] | ||
|
||
def get_wikitext2(nsamples, seed, seqlen, tokenizer): | ||
traindata = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="train") | ||
testdata = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test") | ||
|
||
trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt") | ||
testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt") | ||
|
||
random.seed(seed) | ||
trainloader = [] | ||
for _ in range(nsamples): | ||
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) | ||
j = i + seqlen | ||
inp = trainenc.input_ids[:, i:j] | ||
tar = inp.clone() | ||
tar[:, :-1] = -100 | ||
trainloader.append((inp, tar)) | ||
return trainloader, testenc.input_ids | ||
|
||
def dataloader_hv_product(layerid, params, device, v, data, nsamples, model, max_seqlen, criterion): | ||
model.zero_grad() | ||
THv = [torch.zeros(p.size()).to(device) for p in params] # accumulate result | ||
|
||
# Freeze all the parameters in the model | ||
for param in model.parameters(): | ||
param.requires_grad = False | ||
|
||
# Unfreeze the parameters of attention and MLP layers in layer 0 | ||
layer_ = model.model.layers[layerid] | ||
for param in layer_.self_attn.parameters(): | ||
param.requires_grad = True | ||
for param in layer_.mlp.parameters(): | ||
param.requires_grad = True | ||
|
||
for i in tqdm(range(nsamples)): | ||
torch.cuda.empty_cache() | ||
inputs, labels = data[i] | ||
inputs = inputs.to(device) | ||
labels = labels.to(device) | ||
# if use testloader: | ||
# inputs = data[:, (i * max_seqlen) : ((i + 1) * max_seqlen)].to(device) | ||
# labels = data[:, (i * max_seqlen) : ((i + 1) * max_seqlen)].to(device) | ||
model.zero_grad() | ||
outputs = model(inputs) | ||
logits = outputs.logits | ||
loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1)) | ||
|
||
# get the first order gradients | ||
grads = torch.autograd.grad(loss, params, create_graph=True, only_inputs=True) | ||
|
||
# calculate Hessian vector product via Jac-vector product | ||
Hv = torch.autograd.grad(grads, params, grad_outputs=v, only_inputs=True, retain_graph=False) | ||
|
||
THv = [THv1 + Hv1 + 0.0 for THv1, Hv1 in zip(THv, Hv)] | ||
|
||
# clean cache | ||
model.zero_grad() | ||
del Hv | ||
del grads | ||
torch.cuda.empty_cache() | ||
|
||
THv = [THv1 / float(nsamples) for THv1 in THv] | ||
return THv | ||
|
||
def cal_trace(layerid, params, device, data, nsamples, model, maxIter, max_seqlen, criterion): | ||
vhv_c_history = [] | ||
trace_history = [] | ||
trace = 0. | ||
|
||
for i in range(maxIter): | ||
print("iteration: ",i) | ||
|
||
# generate Rademacher random variables | ||
v = [ | ||
torch.randint_like(p, high=2, device=device) | ||
for p in params | ||
] | ||
|
||
for v_i in v: | ||
v_i[v_i == 0] = -1 | ||
|
||
# calculate Hessian vector product | ||
Hv = dataloader_hv_product(layerid, params, device, v, data, nsamples, model, max_seqlen, criterion) | ||
|
||
vHv = group_product(Hv, v) | ||
|
||
vHv_c = np.array([i.cpu().numpy() for i in vHv]) | ||
|
||
vhv_c_history.append(vHv_c) | ||
|
||
trace = np.sum(vHv_c) | ||
|
||
trace_history.append(trace) | ||
print("trace,", trace) | ||
print("trace_history,", trace_history) | ||
print("vhv_c_history,", vhv_c_history) | ||
|
||
return np.mean(trace_history) | ||
|
||
|
||
def main(layer_id, checkpoint, max_seqlen, maxIter, nsamples): | ||
device = 'cuda' if torch.cuda.is_available() else 'cpu' | ||
|
||
# to avoid aten::_scaled_dot_product_flash_attention_backward not implemented error | ||
with sdpa_kernel(SDPBackend.MATH): | ||
|
||
# have been tested models Llama-3-8B, Llama-2-7B, Mistral-7B, and stories110M | ||
model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16) | ||
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint) | ||
model = model.cuda() | ||
model.eval() | ||
|
||
criterion = torch.nn.CrossEntropyLoss() | ||
|
||
# load calibration dataset | ||
seed = 0 | ||
trainloader, testloader = get_wikitext2(128, seed, 2048, tokenizer) | ||
|
||
# calculate Hessian for only one layer each time | ||
params=[] | ||
layer_ = model.model.layers[layer_id] | ||
for param in layer_.self_attn.parameters(): | ||
params.append(param) | ||
for param in layer_.mlp.parameters(): | ||
params.append(param) | ||
|
||
trace = cal_trace(layerid=layer_id, params=params, device=device, data=trainloader, nsamples=nsamples, model=model, maxIter=maxIter, max_seqlen=max_seqlen, criterion=criterion) | ||
print("The trace of layer " + str(layer_id) + " is", trace) | ||
|
||
if __name__ == '__main__': | ||
import argparse | ||
parser = argparse.ArgumentParser(description='Your CLI description.') | ||
parser.add_argument('--layer_id', type=int, default=0, help='Which layer to compute the trace and hessian') | ||
parser.add_argument('--checkpoint', type=str, default="/home/hanxianhuang/ao/torchao/quantization/prototype/mixed_precision/checkpoints/meta-llama/Meta-Llama-3-8B", help='Path to load model') | ||
parser.add_argument('--max_seqlen', type=int, default=2048, help='Max sequence length') | ||
parser.add_argument('--maxIter', type=int, default=100, help='The number of iterations to calculate Hessian trace') | ||
parser.add_argument('--nsamples', type=int, default=128, help='The number of samples in calibration dataset') | ||
args = parser.parse_args() | ||
main(args.layer_id, args.checkpoint, args.max_seqlen, args.maxIter, args.nsamples) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: can you use lowercase names for these 3 files?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have renamed them