From e9c79523e8f229afc54d3ef0819f49bb7cff247b Mon Sep 17 00:00:00 2001 From: Hanxian97 Date: Fri, 2 Aug 2024 11:23:32 -0700 Subject: [PATCH 1/2] add FIT and Hessian --- .../prototype/mixed_precision/scripts/FIT.py | 105 +++++++++++ .../mixed_precision/scripts/Hessian_grad.py | 151 ++++++++++++++++ .../mixed_precision/scripts/Hessian_vhp.py | 164 ++++++++++++++++++ 3 files changed, 420 insertions(+) create mode 100644 torchao/quantization/prototype/mixed_precision/scripts/FIT.py create mode 100644 torchao/quantization/prototype/mixed_precision/scripts/Hessian_grad.py create mode 100644 torchao/quantization/prototype/mixed_precision/scripts/Hessian_vhp.py diff --git a/torchao/quantization/prototype/mixed_precision/scripts/FIT.py b/torchao/quantization/prototype/mixed_precision/scripts/FIT.py new file mode 100644 index 0000000000..19366a01c9 --- /dev/null +++ b/torchao/quantization/prototype/mixed_precision/scripts/FIT.py @@ -0,0 +1,105 @@ +import torch +import numpy as np +import os +from tqdm import tqdm +import transformers +from datasets import load_dataset +import random +from torch.nn.attention import SDPBackend, sdpa_kernel + +def get_wikitext2(nsamples, seed, seqlen, tokenizer): + traindata = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="train") + testdata = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test") + + trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt") + testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt") + + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + +def cal_FIT(device, data, nsamples, model, maxIter, max_seqlen, criterion, num_layers): + + # store the history of trace for each layer + estimated_history=[] + + # store the history of mean trace for each layer + estimated_mean = [[] for _ in range(num_layers)] + trace = [0.] * num_layers + + + for iteration in range(maxIter): + print("iteration: ",iteration) + trace_tmp = [0.] * num_layers + + for i in tqdm(range(nsamples)): + inputs, targets = data[i] + inputs = inputs.to(device) + targets = targets.to(device) + model.zero_grad() + outputs = model(inputs) + logits = outputs.logits + loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1)) + + grads = torch.autograd.grad(loss, model.parameters()) + + # Trace(Fisher Information Matrix) is calculated by the sum of the square of the gradient + for layerid in range(num_layers): + for (name, _), grad in zip(model.named_parameters(), grads): + if "."+str(layerid)+"." in name and ("self_attn" in name or "mlp" in name): + trace_tmp[layerid] += torch.sum(grad * grad).item() + + # clean cache + model.zero_grad() + del grads + torch.cuda.empty_cache() + + # calculate the mean of the trace on the calibration dataset + for t in range(num_layers): + trace[t] = trace_tmp[t] / float(nsamples) + estimated_mean[t].append(trace[t]) + + print("trace:",trace) + estimated_history.append(trace) + + F_average = np.array([np.mean(i) for i in estimated_mean]) + return F_average, estimated_mean, estimated_history + +def main(max_seqlen, checkpoint, nsamples, maxIter, num_layers): + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + # have been tested models Llama-3-8B, Llama-2-7B, Mistral-7B, and stories110M + model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16) + tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint) + model = model.to(device) + model.eval() + + criterion = torch.nn.CrossEntropyLoss() + + # load calibration dataset + seed = 0 + trainloader, testloader = get_wikitext2(nsamples, seed, max_seqlen, tokenizer) + + F_average, estimated_mean, estimated_history = cal_FIT(device=device, data=trainloader, nsamples=nsamples, model=model, maxIter=maxIter, max_seqlen=max_seqlen, criterion=criterion, num_layers=num_layers) + print("Iteration Done") + print("avg_trace:", F_average) + print("estimated_mean:", estimated_mean) + print("estimated_history:", estimated_history) + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Your CLI description.') + parser.add_argument('--checkpoint', type=str, default="/home/hanxianhuang/ao/torchao/quantization/prototype/mixed_precision/checkpoints/meta-llama/Meta-Llama-3-8B", help='Path to load model') + parser.add_argument('--max_seqlen', type=int, default=2048, help='Max sequence length') + parser.add_argument('--maxIter', type=int, default=100, help='The number of iterations to calculate FIT') + parser.add_argument('--num_layers', type=int, default=32, help='The number of layers to calculate FIT.') + parser.add_argument('--nsamples', type=int, default=128, help='The number of samples in calibration dataset') + args = parser.parse_args() + main(args.max_seqlen, args.checkpoint, args.nsamples, args.maxIter, args.num_layers) diff --git a/torchao/quantization/prototype/mixed_precision/scripts/Hessian_grad.py b/torchao/quantization/prototype/mixed_precision/scripts/Hessian_grad.py new file mode 100644 index 0000000000..bb33cff39f --- /dev/null +++ b/torchao/quantization/prototype/mixed_precision/scripts/Hessian_grad.py @@ -0,0 +1,151 @@ +import torch +import numpy as np +import os +from tqdm import tqdm +import transformers +from datasets import load_dataset +import random +from torch.nn.attention import SDPBackend, sdpa_kernel +from torch.autograd.functional import hvp + +def group_product(xs, ys): + return [torch.sum(x * y) for (x, y) in zip(xs, ys)] + +def get_wikitext2(nsamples, seed, seqlen, tokenizer): + traindata = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="train") + testdata = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test") + + trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt") + testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt") + + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc.input_ids + +def dataloader_hv_product(layerid, params, device, v, data, nsamples, model, max_seqlen, criterion): + model.zero_grad() + THv = [torch.zeros(p.size()).to(device) for p in params] # accumulate result + + # Freeze all the parameters in the model + for param in model.parameters(): + param.requires_grad = False + + # Unfreeze the parameters of attention and MLP layers in layer 0 + layer_ = model.model.layers[layerid] + for param in layer_.self_attn.parameters(): + param.requires_grad = True + for param in layer_.mlp.parameters(): + param.requires_grad = True + + for i in tqdm(range(nsamples)): + torch.cuda.empty_cache() + inputs, labels = data[i] + inputs = inputs.to(device) + labels = labels.to(device) + # if use testloader: + # inputs = data[:, (i * max_seqlen) : ((i + 1) * max_seqlen)].to(device) + # labels = data[:, (i * max_seqlen) : ((i + 1) * max_seqlen)].to(device) + model.zero_grad() + outputs = model(inputs) + logits = outputs.logits + loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1)) + + # get the first order gradients + grads = torch.autograd.grad(loss, params, create_graph=True, only_inputs=True) + + # calculate Hessian vector product via Jac-vector product + Hv = torch.autograd.grad(grads, params, grad_outputs=v, only_inputs=True, retain_graph=False) + + THv = [THv1 + Hv1 + 0.0 for THv1, Hv1 in zip(THv, Hv)] + + # clean cache + model.zero_grad() + del Hv + del grads + torch.cuda.empty_cache() + + THv = [THv1 / float(nsamples) for THv1 in THv] + return THv + +def cal_trace(layerid, params, device, data, nsamples, model, maxIter, max_seqlen, criterion): + vhv_c_history = [] + trace_history = [] + trace = 0. + + for i in range(maxIter): + print("iteration: ",i) + + # generate Rademacher random variables + v = [ + torch.randint_like(p, high=2, device=device) + for p in params + ] + + for v_i in v: + v_i[v_i == 0] = -1 + + # calculate Hessian vector product + Hv = dataloader_hv_product(layerid, params, device, v, data, nsamples, model, max_seqlen, criterion) + + vHv = group_product(Hv, v) + + vHv_c = np.array([i.cpu().numpy() for i in vHv]) + + vhv_c_history.append(vHv_c) + + trace = np.sum(vHv_c) + + trace_history.append(trace) + print("trace,", trace) + print("trace_history,", trace_history) + print("vhv_c_history,", vhv_c_history) + + return np.mean(trace_history) + + +def main(layer_id, checkpoint, max_seqlen, maxIter, nsamples): + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + # to avoid aten::_scaled_dot_product_flash_attention_backward not implemented error + with sdpa_kernel(SDPBackend.MATH): + + # have been tested models Llama-3-8B, Llama-2-7B, Mistral-7B, and stories110M + model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16) + tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint) + model = model.cuda() + model.eval() + + criterion = torch.nn.CrossEntropyLoss() + + # load calibration dataset + seed = 0 + trainloader, testloader = get_wikitext2(128, seed, 2048, tokenizer) + + # calculate Hessian for only one layer each time + params=[] + layer_ = model.model.layers[layer_id] + for param in layer_.self_attn.parameters(): + params.append(param) + for param in layer_.mlp.parameters(): + params.append(param) + + trace = cal_trace(layerid=layer_id, params=params, device=device, data=trainloader, nsamples=nsamples, model=model, maxIter=maxIter, max_seqlen=max_seqlen, criterion=criterion) + print("The trace of layer " + str(layer_id) + " is", trace) + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Your CLI description.') + parser.add_argument('--layer_id', type=int, default=0, help='Which layer to compute the trace and hessian') + parser.add_argument('--checkpoint', type=str, default="/home/hanxianhuang/ao/torchao/quantization/prototype/mixed_precision/checkpoints/meta-llama/Meta-Llama-3-8B", help='Path to load model') + parser.add_argument('--max_seqlen', type=int, default=2048, help='Max sequence length') + parser.add_argument('--maxIter', type=int, default=100, help='The number of iterations to calculate Hessian trace') + parser.add_argument('--nsamples', type=int, default=128, help='The number of samples in calibration dataset') + args = parser.parse_args() + main(args.layer_id, args.checkpoint, args.max_seqlen, args.maxIter, args.nsamples) diff --git a/torchao/quantization/prototype/mixed_precision/scripts/Hessian_vhp.py b/torchao/quantization/prototype/mixed_precision/scripts/Hessian_vhp.py new file mode 100644 index 0000000000..d0eeba01cb --- /dev/null +++ b/torchao/quantization/prototype/mixed_precision/scripts/Hessian_vhp.py @@ -0,0 +1,164 @@ +import torch +import torchvision.models as models +import numpy as np +import os +from tqdm import tqdm +import transformers +from datasets import load_dataset +import random +from torch.nn.attention import SDPBackend, sdpa_kernel +from torch.autograd.functional import hvp, vhp + + +def group_product(xs, ys): + return [torch.sum(x * y) for (x, y) in zip(xs, ys)] + +def get_wikitext2(nsamples, seed, seqlen, tokenizer): + traindata = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="train") + testdata = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test") + + trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt") + testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt") + + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc.input_ids + +# utilities to make nn.Module functional +def del_attr(obj, names): + if len(names) == 1: + delattr(obj, names[0]) + else: + del_attr(getattr(obj, names[0]), names[1:]) + +def set_attr(obj, names, val): + if len(names) == 1: + setattr(obj, names[0], val) + else: + set_attr(getattr(obj, names[0]), names[1:], val) + +def make_functional(mod, layer_id): + orig_params = tuple(mod.parameters()) + # remove all the parameters in the model + selected_params=[] + selected_params_names=[] + + names = [] + for name, p in list(mod.named_parameters()): + if name.startswith("model.layers."+str(layer_id)+".self_attn.") or name.startswith("model.layers."+str(layer_id)+".mlp."): + selected_params.append(p) + selected_params_names.append(name) + del_attr(mod, name.split(".")) + names.append(name) + return orig_params, names, selected_params, selected_params_names + + + +def main(layer_id, checkpoint, max_seqlen, maxIter, nsamples): + + # use the functional model to load the weights back + def load_weights(mod, names, params, selected_params, selected_params_names): + for name, p in zip(names, params): + if name.startswith("model.layers."+str(layer_id)+".self_attn.") or name.startswith("model.layers."+str(layer_id)+".mlp."): + idx=selected_params_names.index(name) + set_attr(mod, name.split("."), selected_params[idx]) + else: + set_attr(mod, name.split("."), p) + for name, param in mod.named_parameters(): + if param.requires_grad: + print(f"Parameter {name} requires gradients.") + + # define the function to calculate the vhp + def f(*new_params): + load_weights(model, names, params, new_params, selected_params_names) + model.zero_grad() + outputs = model(inputs) + logits = outputs.logits + loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1)) + return loss + + device = "cuda" if torch.cuda.is_available() else "cpu" + + # to avoid aten::_scaled_dot_product_flash_attention_backward not implemented error + with sdpa_kernel(SDPBackend.MATH): + + # have been tested models Llama-3-8B, Llama-2-7B, Mistral-7B, and stories110M + model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16) + tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint) + model = model.to(device) + model.eval() + + criterion = torch.nn.CrossEntropyLoss() + + # load calibration dataset + seed = 0 + trainloader, testloader = get_wikitext2(128, 0, 2048, tokenizer) + + # make the model functional + params, names, selected_params, selected_params_names = make_functional(model, layer_id) + + # make params regular Tensors instead of nn.Parameter + params = tuple(p.detach() for p in params) + + # set requires_grad to True for the selected parameters + selected_params_tuple = tuple(p.detach().requires_grad_() for p in selected_params) + + trace_history = [] + vhv_c_history=[] + + for iteration in range(maxIter): + + print("iteration: ",iteration) + + # generate Rademacher random variables + v = [torch.randint_like(p, high=2) for p in selected_params_tuple] + for v_i in v: + v_i[v_i == 0] = -1 + + for i in tqdm(range(nsamples)): + inputs, labels = trainloader[i] + inputs = inputs.to(device) + labels = labels.to(device) + # if use testloader: + # inputs = testloader[:, (i * max_seqlen) : ((i + 1) * max_seqlen)].to(device) + # labels = testloader[:, (i * max_seqlen) : ((i + 1) * max_seqlen)].to(device) + + # get vector-Hessian product + _, vH = vhp(f, selected_params_tuple, tuple(v)) + + if i==0: + TvH = [torch.zeros(p.size()).to(device) for p in selected_params_tuple] + TvH = [TvH1 + vH1 + 0.0 for TvH1, vH1 in zip(TvH, vH)] + + + TvH = [TvH1 / float(nsamples) for TvH1 in TvH] + # get vHv + vHv = group_product(TvH, v) + vHv_c = np.array([i.to(torch.float32).cpu().numpy() for i in vHv]) + vhv_c_history.append(vHv_c) + trace = np.sum(vHv_c) + print("trace", trace) + trace_history.append(trace) + + print("Iteration Done") + print("avg: trace,", np.mean(trace_history)) + print("trace_history,", trace_history) + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Your CLI description.') + parser.add_argument('--layer_id', type=int, default=0, help='Which layer to compute the Hessian trace') + parser.add_argument('--checkpoint', type=str, default="/home/hanxianhuang/ao/torchao/quantization/prototype/mixed_precision/checkpoints/meta-llama/Meta-Llama-3-8B", help='Path to load model') + parser.add_argument('--max_seqlen', type=int, default=2048, help='Max sequence length') + parser.add_argument('--maxIter', type=int, default=100, help='The number of iterations to calculate Hessian trace') + parser.add_argument('--nsamples', type=int, default=128, help='The number of samples in calibration dataset') + args = parser.parse_args() + main(args.layer_id, args.checkpoint, args.max_seqlen, args.maxIter, args.nsamples) From 88adf14e0e600db5635e25403e3bcef832c3227a Mon Sep 17 00:00:00 2001 From: Hanxian97 Date: Mon, 5 Aug 2024 14:01:00 -0700 Subject: [PATCH 2/2] renamed files --- .../prototype/mixed_precision/scripts/{FIT.py => fit.py} | 0 .../scripts/{Hessian_grad.py => hessian_grad.py} | 0 .../mixed_precision/scripts/{Hessian_vhp.py => hessian_vhp.py} | 2 +- 3 files changed, 1 insertion(+), 1 deletion(-) rename torchao/quantization/prototype/mixed_precision/scripts/{FIT.py => fit.py} (100%) rename torchao/quantization/prototype/mixed_precision/scripts/{Hessian_grad.py => hessian_grad.py} (100%) rename torchao/quantization/prototype/mixed_precision/scripts/{Hessian_vhp.py => hessian_vhp.py} (99%) diff --git a/torchao/quantization/prototype/mixed_precision/scripts/FIT.py b/torchao/quantization/prototype/mixed_precision/scripts/fit.py similarity index 100% rename from torchao/quantization/prototype/mixed_precision/scripts/FIT.py rename to torchao/quantization/prototype/mixed_precision/scripts/fit.py diff --git a/torchao/quantization/prototype/mixed_precision/scripts/Hessian_grad.py b/torchao/quantization/prototype/mixed_precision/scripts/hessian_grad.py similarity index 100% rename from torchao/quantization/prototype/mixed_precision/scripts/Hessian_grad.py rename to torchao/quantization/prototype/mixed_precision/scripts/hessian_grad.py diff --git a/torchao/quantization/prototype/mixed_precision/scripts/Hessian_vhp.py b/torchao/quantization/prototype/mixed_precision/scripts/hessian_vhp.py similarity index 99% rename from torchao/quantization/prototype/mixed_precision/scripts/Hessian_vhp.py rename to torchao/quantization/prototype/mixed_precision/scripts/hessian_vhp.py index d0eeba01cb..480365c66d 100644 --- a/torchao/quantization/prototype/mixed_precision/scripts/Hessian_vhp.py +++ b/torchao/quantization/prototype/mixed_precision/scripts/hessian_vhp.py @@ -143,7 +143,7 @@ def f(*new_params): vHv = group_product(TvH, v) vHv_c = np.array([i.to(torch.float32).cpu().numpy() for i in vHv]) vhv_c_history.append(vHv_c) - trace = np.sum(vHv_c) + trace = np.sum(np.abs(vHv_c)) print("trace", trace) trace_history.append(trace)