Dynamic Float8 benchmarking llama #1017

Merged · 5 commits · Oct 7, 2024
12 changes: 9 additions & 3 deletions torchao/_models/llama/benchmarks.sh
@@ -9,8 +9,6 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-int4 --write_result benchmark_results.txt



export MODEL_REPO=meta-llama/Meta-Llama-3-8B
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
@@ -19,6 +17,14 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-int4 --write_result benchmark_results.txt

export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
# Run these on H100; float8 is not supported on CUDA arch < 8.9
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization float8wo --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization float8dq-tensor --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization float8dq-row --write_result benchmark_results.txt
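# Optional guard (a sketch, not part of the original benchmarks.sh): run a float8 command only
# when the GPU reports CUDA compute capability >= 8.9, since float8 kernels are unsupported on
# older architectures.
if python -c "import torch, sys; sys.exit(0 if torch.cuda.get_device_capability() >= (8, 9) else 1)"; then
    python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization float8wo --write_result benchmark_results.txt
fi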

# OTHER BENCHMARKS

@@ -58,4 +64,4 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt
34 changes: 28 additions & 6 deletions torchao/_models/llama/eval.py
@@ -20,13 +20,16 @@
fpx_weight_only,
uintx_weight_only,
unwrap_tensor_subclass,
float8_weight_only,
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
)
from torchao.quantization.observer import PerRow, PerTensor
from torchao._models._eval import TransformerEvalWrapper, InputRecorder
from torchao._models.llama.model import prepare_inputs_for_model

from tokenizer import get_tokenizer
import time
from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
from torchao._models.llama.model import prepare_inputs_for_model, TransformerBlock
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

def run_evaluation(
@@ -55,19 +58,16 @@ def run_evaluation(
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), str(tokenizer_path)
# Load Model and Tokenizer

print("Loading model ...")
t0 = time.time()
model = _load_model(checkpoint_path, "cpu", precision)

if max_length is None:
max_length = model.config.block_size

device_sync(device=device) # MKG
print(f"Time to load model: {time.time() - t0:.02f} seconds")
tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)


if quantization:
if "int8wo" in quantization:
quantize_(model, int8_weight_only())
@@ -100,6 +100,9 @@ def run_evaluation(
from torchao.dtypes import MarlinSparseLayoutType
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
if "int4wo" in quantization and "gptq" in quantization:
# avoid circular imports
from torchao._models._eval import InputRecorder
from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
groupsize=int(quantization.split("-")[-2])
assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
assert precision==torch.bfloat16, f"{quantization} requires precision of bfloat16 but got {precision}"
@@ -122,9 +125,24 @@
else:
if not TORCH_VERSION_AT_LEAST_2_5:
unwrap_tensor_subclass(model)
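# float8 options: "float8wo" applies float8 weight-only quantization; "float8dq-tensor" /
# "float8dq-row" apply dynamic float8 activation + float8 weight quantization with
# per-tensor or per-row scaling (a bare "float8dq" falls back to per-tensor)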
if "float8wo" in quantization:
quantize_(model, float8_weight_only())
if "float8dq" in quantization:
granularity = str(quantization.split("-")[-1])
if granularity=="tensor":
granularity = PerTensor()
elif granularity=="row":
granularity = PerRow()
else:
if granularity=="float8dq":
            granularity = PerTensor()
        else:
            raise ValueError(f"Unknown granularity {granularity}")
    quantize_(model, float8_dynamic_activation_float8_weight(granularity=granularity))
if "autoround" in quantization:
from torchao.prototype.autoround.autoround_llm import quantize_model_with_autoround_
from transformers import AutoTokenizer
from torchao._models.llama.model import TransformerBlock

_tokenizer = AutoTokenizer.from_pretrained(checkpoint_path.parent)
# parse args from quantization string:
@@ -182,6 +200,9 @@ def run_evaluation(
if compile:
model = torch.compile(model, mode="max-autotune", fullgraph=True)
with torch.no_grad():
print("Running evaluation ...")
# avoid circular imports
from torchao._models._eval import TransformerEvalWrapper
TransformerEvalWrapper(
model=model.to(device),
tokenizer=tokenizer,
@@ -209,7 +230,8 @@ def run_evaluation(
"Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, "
"int4wo-<groupsize>-gptq, autoquant, autoquant-int4, int4wo-<groupsize>-hqq, "
"uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, "
"autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>-<grad_acc_steps>-<c>"
"autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>-<grad_acc_steps>-<c>, "
"float8wo, float8dq, float8saq"
),
)
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
12 changes: 10 additions & 2 deletions torchao/_models/llama/evals.sh
@@ -11,5 +11,13 @@ python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quanti
export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cpu # auto-round w/o quant_lm_head
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cuda-1 # auto-round w/ quant_lm_head
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoquant --tasks 'mmlu' 'truthfulqa_mc2'
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoquant --tasks 'winogrande' 'arc_challenge'
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8wo
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization float8wo
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization float8dq-tensor
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization float8dq-row

# Testing on additional tasks
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8wo --tasks 'winogrande' 'arc_challenge'
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8wo --tasks 'mmlu' 'truthfulqa_mc2'
16 changes: 15 additions & 1 deletion torchao/_models/llama/generate.py
@@ -210,8 +210,11 @@ def main(
fpx_weight_only,
uintx_weight_only,
autoquant,
unwrap_tensor_subclass
unwrap_tensor_subclass,
float8_weight_only,
float8_dynamic_activation_float8_weight,
)
from torchao.quantization.observer import PerTensor, PerRow
if "int8wo" in quantization:
quantize_(model, int8_weight_only())
if "int8dq" in quantization:
@@ -243,6 +246,17 @@
dtype = _NBITS_TO_DTYPE[nbits]
group_size = int(_quant_args[2])
quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq))
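# float8 options mirror eval.py: "float8wo" is weight-only, "float8dq-<granularity>" is
# dynamic activation + weight quantization; unrecognized suffixes default to per-tensor scaling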
if "float8wo" in quantization:
quantize_(model, float8_weight_only())
if "float8dq" in quantization:
granularity = str(quantization.split("-")[-1])
if granularity=="tensor":
granularity = PerTensor()
elif granularity=="row":
granularity = PerRow()
else:
granularity = PerTensor()
quantize_(model, float8_dynamic_activation_float8_weight(granularity=granularity))
if "autoquant" in quantization:
if "autoquant-int4" == quantization:
model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
20 changes: 20 additions & 0 deletions torchao/quantization/README.md
@@ -21,6 +21,17 @@ Benchmarks and evaluation are run on a machine with a single NVIDIA-A100-80GB GPU
| | int4wo-64-GPTQ | 7.921 | 180.80 | 763.33 | 6.88 | 4.22 |
| | autoquant-int4hqq | 8.110 | 188.41 | 800.58 | 7.14 | 4.25 |

Benchmarks and evaluation for the meta-llama/Meta-Llama-3.1-8B model are run on a machine with a single NVIDIA-H100 GPU using the scripts for [generation](../_models/llama/generate.py) and [eval](../_models/llama/eval.py). Evaluation was done using the lm_eval library for tasks/data.

| Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) |
| ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- |
| Llama-3.1-8B | Base (bfloat16) | 7.54 | 126.90 | 1904.75 | 16.75 | 15.01 |
| | int8wo | 7.56 | 198.85 | 1495.41 | 11.05 | 7.52 |
| | int4wo-64 | 8.44 | 241.39 | 1019.14 | 7.08 | 4.22 |
| | float8wo | 7.60 | 178.46 | 1339.93 | 12.09 | 7.51 |
| | float8dq (PerTensor) | 7.62 | 116.40 | 873.58 | 11.14 | 7.51 |
| | float8dq (PerRow) | 7.61 | 154.63 | 1161.47 | 11.14 | 7.51 |
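
Each row above can be reproduced with the generation script (and the perplexity column with eval.py); for example, a sketch following the commands added to `benchmarks.sh` in this PR, where the `-tensor`/`-row` suffix selects the float8 scaling granularity:

```sh
export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization float8dq-row --write_result benchmark_results.txt
```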

Note: int8 dynamic quantization works best on compute-bound models like [SAM](https://github.com/pytorch-labs/segment-anything-fast), whereas Llama with batch size 1 tends to be memory bound, hence the relatively low performance.

For int4 we make heavy use of [tinygemm](https://github.com/pytorch/ao/blob/cb3bd8c674f2123af232a0231b5e38ddafa756a8/torchao/dtypes/aqt.py#L526), i.e. `torch.ops.aten._weight_int4pack_mm`, to bitpack into a layout optimized for tensor cores.
@@ -121,6 +132,15 @@ from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
change_linear_weights_to_int8_dqtensors(model)
```

#### A8W8 Float8 Dynamic Quantization

```python
# for torch 2.4+
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from torchao.quantization.observer import PerTensor
quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor()))
```
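
Per-row scaling can be selected the same way; a minimal sketch using the `PerRow` granularity (assumes a GPU with float8 support, i.e. CUDA compute capability >= 8.9):

```python
# for torch 2.4+
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from torchao.quantization.observer import PerRow
quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))
```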

#### A16W6 Floating Point WeightOnly Quantization

```python