diff --git a/torchao/_models/llama/benchmarks.sh b/torchao/_models/llama/benchmarks.sh
index 6582832f6b..4ba04d7c7a 100644
--- a/torchao/_models/llama/benchmarks.sh
+++ b/torchao/_models/llama/benchmarks.sh
@@ -9,8 +9,6 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-int4 --write_result benchmark_results.txt
-
-
 export MODEL_REPO=meta-llama/Meta-Llama-3-8B
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
@@ -19,6 +17,14 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-int4 --write_result benchmark_results.txt
+export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
+# Runs on H100, float8 is not supported on CUDA arch < 8.9
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization float8wo --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization float8dq-tensor --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization float8dq-row --write_result benchmark_results.txt
 
 # OTHER BENCHMARKS
@@ -58,4 +64,4 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt
\ No newline at end of file
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt
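The three float8 benchmark entries above only run on GPUs with CUDA compute capability 8.9 or newer, as the comment in the script notes. Below is a minimal, hypothetical sketch (not part of this patch; the helper name and the list of flags are illustrative) of how a wrapper script could skip those configurations on older hardware:

```python
import torch

def supports_float8() -> bool:
    # float8 kernels need CUDA compute capability >= 8.9 (e.g. H100); A100 is 8.0.
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() >= (8, 9)

# Illustrative subset of the --quantization values used in benchmarks.sh.
quantizations = ["int8wo", "int4wo-64", "float8wo", "float8dq-tensor", "float8dq-row"]
if not supports_float8():
    quantizations = [q for q in quantizations if not q.startswith("float8")]
print(quantizations)
```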
diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py
index d495c2065b..3f55a22394 100644
--- a/torchao/_models/llama/eval.py
+++ b/torchao/_models/llama/eval.py
@@ -20,13 +20,16 @@
     fpx_weight_only,
     uintx_weight_only,
     unwrap_tensor_subclass,
+    float8_weight_only,
+    float8_dynamic_activation_float8_weight,
+    float8_static_activation_float8_weight,
 )
+from torchao.quantization.observer import PerRow, PerTensor
 from torchao._models._eval import TransformerEvalWrapper, InputRecorder
+from torchao._models.llama.model import prepare_inputs_for_model
 from tokenizer import get_tokenizer
 import time
-from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
-from torchao._models.llama.model import prepare_inputs_for_model, TransformerBlock
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
 def run_evaluation(
@@ -55,19 +58,16 @@
     tokenizer_path = checkpoint_path.parent / "tokenizer.model"
     assert tokenizer_path.is_file(), str(tokenizer_path)
     # Load Model and Tokenizer
-
     print("Loading model ...")
     t0 = time.time()
     model = _load_model(checkpoint_path, "cpu", precision)
     if max_length is None:
         max_length = model.config.block_size
-
     device_sync(device=device) # MKG
     print(f"Time to load model: {time.time() - t0:.02f} seconds")
     tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
-
     if quantization:
         if "int8wo" in quantization:
             quantize_(model, int8_weight_only())
@@ -100,6 +100,9 @@ def run_evaluation(
             from torchao.dtypes import MarlinSparseLayoutType
             quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
         if "int4wo" in quantization and "gptq" in quantization:
+            # avoid circular imports
+            from torchao._models._eval import InputRecorder
+            from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
            groupsize=int(quantization.split("-")[-2])
            assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
            assert precision==torch.bfloat16, f"{quantization} requires precision or bfloat16 but got {precision}"
@@ -122,9 +125,24 @@ def run_evaluation(
         else:
             if not TORCH_VERSION_AT_LEAST_2_5:
                 unwrap_tensor_subclass(model)
+        if "float8wo" in quantization:
+            quantize_(model, float8_weight_only())
+        if "float8dq" in quantization:
+            granularity = str(quantization.split("-")[-1])
+            if granularity=="tensor":
+                granularity = PerTensor()
+            elif granularity=="row":
+                granularity = PerRow()
+            else:
+                if granularity=="float8dq":
+                    granularity = PerTensor()
+                else:
+                    raise ValueError(f"Unknown granularity {granularity}")
+            quantize_(model, float8_dynamic_activation_float8_weight(granularity=granularity))
         if "autoround" in quantization:
             from torchao.prototype.autoround.autoround_llm import quantize_model_with_autoround_
             from transformers import AutoTokenizer
+            from torchao._models.llama.model import TransformerBlock
 
             _tokenizer = AutoTokenizer.from_pretrained(checkpoint_path.parent)
             # parse args from quantization string:
@@ -182,6 +200,9 @@ def run_evaluation(
     if compile:
         model = torch.compile(model, mode="max-autotune", fullgraph=True)
     with torch.no_grad():
+        print("Running evaluation ...")
+        # avoid circular imports
+        from torchao._models._eval import TransformerEvalWrapper
         TransformerEvalWrapper(
             model=model.to(device),
             tokenizer=tokenizer,
@@ -209,7 +230,8 @@
             "Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, "
             "int4wo-<groupsize>-gptq, autoquant, autoquant-int4, int4wo-<groupsize>-hqq, "
             "uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, "
-            "autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>-<grad_acc_steps>-<c>"
+            "autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>-<grad_acc_steps>-<c>, "
+            "float8wo, float8dq, float8saq"
         ),
     )
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
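For readability, the float8dq granularity handling this hunk adds to eval.py can be read in isolation. The sketch below paraphrases that logic (the helper name is made up; the PerTensor/PerRow imports are the ones the hunk itself adds): a bare `float8dq` and `float8dq-tensor` select per-tensor scales, `float8dq-row` selects per-row scales, and anything else raises.

```python
from torchao.quantization.observer import PerRow, PerTensor

def parse_float8dq_granularity(quantization: str):
    # "float8dq".split("-")[-1] is "float8dq" itself, so the bare flag
    # defaults to per-tensor scales, mirroring the eval.py branch above.
    suffix = quantization.split("-")[-1]
    if suffix in ("tensor", "float8dq"):
        return PerTensor()
    if suffix == "row":
        return PerRow()
    raise ValueError(f"Unknown granularity {suffix}")

assert isinstance(parse_float8dq_granularity("float8dq"), PerTensor)
assert isinstance(parse_float8dq_granularity("float8dq-tensor"), PerTensor)
assert isinstance(parse_float8dq_granularity("float8dq-row"), PerRow)
```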
diff --git a/torchao/_models/llama/evals.sh b/torchao/_models/llama/evals.sh
index 2210faa7a2..ee57c422af 100644
--- a/torchao/_models/llama/evals.sh
+++ b/torchao/_models/llama/evals.sh
@@ -11,5 +11,13 @@ python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quanti
 export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
 python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cpu # auto-round w/o quant_lm_head
 python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cuda-1 # auto-round w/ quant_lm_head
-python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoquant --tasks 'mmlu' 'truthfulqa_mc2'
-python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoquant --tasks 'winogrande' 'arc_challenge'
+python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth
+python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8wo
+python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64
+python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization float8wo
+python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization float8dq-tensor
+python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization float8dq-row
+
+# Testing on additional tasks
+python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8wo --tasks 'winogrande' 'arc_challenge'
+python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8wo --tasks 'mmlu' 'truthfulqa_mc2'
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index 5fb905dbf9..4aac1b216e 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -210,8 +210,11 @@ def main(
             fpx_weight_only,
             uintx_weight_only,
             autoquant,
-            unwrap_tensor_subclass
+            unwrap_tensor_subclass,
+            float8_weight_only,
+            float8_dynamic_activation_float8_weight,
         )
+        from torchao.quantization.observer import PerTensor, PerRow
         if "int8wo" in quantization:
             quantize_(model, int8_weight_only())
         if "int8dq" in quantization:
@@ -243,6 +246,17 @@ def main(
            dtype = _NBITS_TO_DTYPE[nbits]
            group_size = int(_quant_args[2])
            quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq))
+        if "float8wo" in quantization:
+            quantize_(model, float8_weight_only())
+        if "float8dq" in quantization:
+            granularity = str(quantization.split("-")[-1])
+            if granularity=="tensor":
+                granularity = PerTensor()
+            elif granularity=="row":
+                granularity = PerRow()
+            else:
+                granularity = PerTensor()
+            quantize_(model, float8_dynamic_activation_float8_weight(granularity=granularity))
         if "autoquant" in quantization:
             if "autoquant-int4" == quantization:
                 model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
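The generate.py hunks above wire the same `--quantization` strings into `quantize_` calls, with one difference from eval.py: an unrecognized `float8dq-<x>` suffix silently falls back to per-tensor scales instead of raising. A rough usage sketch of the weight-only path, assuming torch 2.4+, a float8-capable GPU (compute capability >= 8.9), a toy module standing in for the Llama model, and that `float8_weight_only` is exported from `torchao.quantization` alongside the names the README snippet below imports:

```python
import torch
from torchao.quantization import quantize_, float8_weight_only

# Toy stand-in for the Llama checkpoint that generate.py loads;
# any module containing nn.Linear layers is handled the same way.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()

# Equivalent of `--quantization float8wo`: float8 weight-only quantization.
quantize_(model, float8_weight_only())

with torch.no_grad():
    out = model(torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda"))
print(out.shape)
```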
diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md
index d428d694a1..c936b7ef83 100644
--- a/torchao/quantization/README.md
+++ b/torchao/quantization/README.md
@@ -21,6 +21,17 @@ Benchmarks and evaluation are run on a machine with a single NVIDIA-A100-80GB GP
 |             | int4wo-64-GPTQ       | 7.921               | 180.80        | 763.33                  | 6.88             | 4.22            |
 |             | autoquant-int4hqq    | 8.110               | 188.41        | 800.58                  | 7.14             | 4.25            |
 
+Benchmarks and evaluation for model meta-llama/Meta-Llama-3.1-8B are run on a machine with a single NVIDIA-H100 GPU using the scripts for [generation](../_models/llama/generate.py) and [eval](../_models/llama/eval.py). Evaluation was done using the lm_eval library for tasks/data.
+
+| Model        | Technique            | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) |
+| ------------ | -------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- |
+| Llama-3.1-8B | Base (bfloat16)      | 7.54                | 126.90        | 1904.75                 | 16.75            | 15.01           |
+|              | int8wo               | 7.56                | 198.85        | 1495.41                 | 11.05            | 7.52            |
+|              | int4wo-64            | 8.44                | 241.39        | 1019.14                 | 7.08             | 4.22            |
+|              | float8wo             | 7.60                | 178.46        | 1339.93                 | 12.09            | 7.51            |
+|              | float8dq (PerTensor) | 7.62                | 116.40        | 873.58                  | 11.14            | 7.51            |
+|              | float8dq (PerRow)    | 7.61                | 154.63        | 1161.47                 | 11.14            | 7.51            |
+
 note: Int8 dynamic quantization works best on compute bound models like [SAM](https://github.com/pytorch-labs/segment-anything-fast) whereas Llama with batchsize=1 tends to be memory bound, thus the rather low performance.
 
 For int4 we make heavy use of [tinygemm](https://github.com/pytorch/ao/blob/cb3bd8c674f2123af232a0231b5e38ddafa756a8/torchao/dtypes/aqt.py#L526) of `torch.ops.aten._weight_int4pack_mm` to bitpack into a layout optimized for tensor cores
@@ -121,6 +132,15 @@ from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtenso
 change_linear_weights_to_int8_dqtensors(model)
 ```
 
+#### A8W8 Float8 Dynamic Quantization
+
+```python
+# for torch 2.4+
+from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
+from torchao.quantization.observer import PerTensor
+quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor()))
+```
+
 #### A16W6 Floating Point WeightOnly Quantization
 
 ```python
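The README snippet added above shows the per-tensor form. For the `float8dq (PerRow)` row in the new table, the per-row variant differs only in the granularity argument. A self-contained sketch, assuming torch 2.4+ and an H100-class GPU (compute capability >= 8.9), with a toy module in place of a real model:

```python
import torch
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from torchao.quantization.observer import PerRow

# Any module with nn.Linear layers; bfloat16 matches the precision the llama scripts use.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()

# Per-row scales for activations and weights ("float8dq-row" in the scripts above).
quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))
```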