Skip to content

Commit

Permalink
fixing scripts (#395)
Browse files Browse the repository at this point in the history
Summary:

a few bugfixes for scripts

1) convert_hf_checkpoint.py had a gpt-fast dependency that wasn't caught
   due to it being in the path

2) eval.py had a bug due to the switch to aqt apis
3) generate.py had a bug due to the deprecation of the old quant apis

Test Plan:

python eval.py
python eval.py -q int8dq --compile --limit 2
python eval.py -q int8wo --compile --limit 2
python eval.py -q int4wo-64 --compile --limit 2
python eval.py -q int4wo-64-gptq --compile
sh benchmarks.sh

(going to add the output results once they finish

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
HDCharles authored Jun 18, 2024
1 parent d0af941 commit 6b0ca2d
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 15 deletions.
7 changes: 1 addition & 6 deletions scripts/convert_hf_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,12 @@
import json
import re
import shutil
import sys
from pathlib import Path
from typing import Optional

import torch

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from model import ModelArgs
from torchao._models.llama.model import ModelArgs


@torch.inference_mode()
Expand Down
8 changes: 5 additions & 3 deletions torchao/_models/llama/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

)
from torchao.quantization.quant_api import (
quantize, int4wo, int8wo, int8da_int8w
quantize, int4wo, int8wo, int8da_int8w, unwrap_tensor_subclass

)
from torchao._models._eval import TransformerEvalWrapper, InputRecorder
Expand Down Expand Up @@ -70,7 +70,7 @@ def run_evaluation(
if "int4wo" in quantization and "gptq" in quantization:
groupsize=int(quantization.split("-")[-2])
assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"

assert precision==torch.bfloat16, f"{quantization} requires precision or bfloat16 but got {precision}"
inputs = InputRecorder(
tokenizer,
calibration_seq_length,
Expand All @@ -83,9 +83,11 @@ def run_evaluation(
calibration_limit,
).get_inputs()

quantizer = Int4WeightOnlyGPTQQuantizer(groupsize=groupsize, precision=precision)
quantizer = Int4WeightOnlyGPTQQuantizer(groupsize=groupsize)
model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
model = quantizer.quantize(model, inputs).to(device)
else:
unwrap_tensor_subclass(model)

if compile:
model = torch.compile(model, mode="max-autotune", fullgraph=True)
Expand Down
17 changes: 11 additions & 6 deletions torchao/_models/llama/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,20 +189,22 @@ def main(

if quantization:
from torchao.quantization.quant_api import (
change_linear_weights_to_int4_woqtensors,
change_linear_weights_to_int8_woqtensors,
change_linear_weights_to_int8_dqtensors,
quantize,
int8wo,
int8da_int8w,
int4wo,
autoquant,
unwrap_tensor_subclass
)

if "int8wo" in quantization:
change_linear_weights_to_int8_woqtensors(model)
quantize(model, int8wo())
if "int8dq" in quantization:
change_linear_weights_to_int8_dqtensors(model)
quantize(model, int8da_int8w())
if "int4wo" in quantization:
groupsize=int(quantization.split("-")[-1])
assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
change_linear_weights_to_int4_woqtensors(model, groupsize=groupsize)
quantize(model, int4wo(groupsize=groupsize))
if "autoquant" == quantization:
model = autoquant(model)
generate(
Expand All @@ -211,6 +213,9 @@ def main(
2,
interactive=False
)
else:
unwrap_tensor_subclass(model)


model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9

Expand Down

0 comments on commit 6b0ca2d

Please sign in to comment.