Skip to content

Commit

Permalink
Add uintx quant to generate and eval (#811)
Browse files Browse the repository at this point in the history
Summary:
att

Also rerun the benchmarks/eval for llama2/llama3 to get most recent perf/acc data

Test Plan:
torchao/_models/llama/generate.py
torchao/_models/llama/eval.py

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 authored Sep 5, 2024
1 parent 9d5527a commit e05635e
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 3 deletions.
12 changes: 11 additions & 1 deletion torchao/_models/llama/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
int8_weight_only,
int8_dynamic_activation_int8_weight,
fpx_weight_only,
uintx_weight_only,
unwrap_tensor_subclass,
)
from torchao._models._eval import TransformerEvalWrapper, InputRecorder
Expand Down Expand Up @@ -79,6 +80,15 @@ def run_evaluation(
groupsize=int(quantization.split("-")[-1])
assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
quantize_(model.to(device), int4_weight_only(group_size=groupsize, use_hqq=use_hqq))
if "uintx" in quantization:
# uintx-nbits-group_size
# "uintx-2-64"
_quant_args = quantization.split("-")
nbits = int(_quant_args[1])
_NBITS_TO_DTYPE = {1: torch.uint1, 2: torch.uint2, 3: torch.uint3, 4: torch.uint4, 5: torch.uint5, 6: torch.uint6, 7: torch.uint7, 8: torch.uint8}
dtype = _NBITS_TO_DTYPE[nbits]
group_size = int(_quant_args[2])
quantize_(model, uintx_weight_only(dtype, group_size))
if "int4wo" in quantization and "gptq" in quantization:
groupsize=int(quantization.split("-")[-2])
assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
Expand Down Expand Up @@ -125,7 +135,7 @@ def run_evaluation(
parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, int4wo-<groupsize>-gptq, int4wo-<groupsize>-hqq")
parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, int4wo-<groupsize>-gptq, int4wo-<groupsize>-hqq, uintx-<nbits>-<group_size>")
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
Expand Down
15 changes: 13 additions & 2 deletions torchao/_models/llama/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def generate(
# execute token generation
input_pos = torch.tensor([T], device=device, dtype=torch.int)
generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs)

seq = torch.cat((seq[:T+1], *generated_tokens))

return seq
Expand Down Expand Up @@ -208,6 +208,7 @@ def main(
int8_dynamic_activation_int8_weight,
int4_weight_only,
fpx_weight_only,
uintx_weight_only,
autoquant,
unwrap_tensor_subclass
)
Expand Down Expand Up @@ -271,6 +272,16 @@ def main(
model.reset_caches()
if "fp6" in quantization:
quantize_(model, fpx_weight_only(3, 2))
if "uintx" in quantization:
# uintx-nbits-group_size
# "uintx-2-64"
_quant_args = quantization.split("-")
nbits = int(_quant_args[1])
assert nbits >= 1 and nbits <= 8, "nbits must be 1 to 8"
_NBITS_TO_DTYPE = {1: torch.uint1, 2: torch.uint2, 3: torch.uint3, 4: torch.uint4, 5: torch.uint5, 6: torch.uint6, 7: torch.uint7, 8: torch.uint8}
dtype = _NBITS_TO_DTYPE[nbits]
group_size = int(_quant_args[2])
quantize_(model, uintx_weight_only(dtype, group_size))
if "autoquant" in quantization:
if "autoquant-int4" == quantization:
model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
Expand Down Expand Up @@ -440,7 +451,7 @@ def callback(x):
parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant, autoquant-int4, int4wo-<groupsize>-hqq, autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>')
parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant, autoquant-int4, int4wo-<groupsize>-hqq, autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>, uintx-<nbits>-<group_size>')
parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size')
parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)')
Expand Down
4 changes: 4 additions & 0 deletions torchao/quantization/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@ Using the lm_eval. The models used were meta-llama/Llama-2-7b-chat-hf and meta-l
| | int8wo | 12.204 | 170.83 | 1131.18 | 8.95 | 6.62 |
| | int4wo-64 | 12.843 | 201.14 | 751.42 | 4.87 | 3.74 |
| | int4wo-64-GPTQ | 12.527 | 201.14 | 751.42 | 4.87 | 3.74 |
| | uintx-4-64 | 12.891 | 48.25 | 189.32 | 6.29 | 3.92 |
| | uintx-2-8 | 28.766 | 36.11 | 238.58 | 9.26 | 6.61 |
| | autoquant-int4hqq | 12.825 | 209.19 | 804.32 | 4.89 | 3.84 |
| Llama-3-8B | Base (bfloat16) | 7.441 | 95.64 | 1435.54 | 16.43 | 15.01 |
| | int8dq | 7.581 | 8.61 | 64.75 | 9.24 | 7.52 |
| | int8wo | 7.447 | 153.03 | 1150.80 | 10.42 | 7.52 |
| | int4wo-64 | 8.316 | 180.80 | 763.33 | 6.88 | 4.22 |
| | int4wo-64-GPTQ | 7.921 | 180.80 | 763.33 | 6.88 | 4.22 |
| | uintx-4-64 | 8.113 | 47.77 | 212.90 | 11.85 | 4.46 |
| | uintx-2-8 | 39.368 | 33.21 | 249.22 | 15.04 | 7.51 |
| | autoquant-int4hqq | 8.110 | 188.41 | 800.58 | 7.14 | 4.25 |

note: Int8 dynamic quantization works best on compute bound models like [SAM](https://github.com/pytorch-labs/segment-anything-fast) whereas Llama with batchsize=1 tends to be memory bound, thus the rather low performance.
Expand Down

0 comments on commit e05635e

Please sign in to comment.