Autoquantization work for benchmarks
Summary:

Autoquant wasn't working for the Llama benchmarks for a few reasons, the
main one being that we were logging on prefill rather than on
decode_one_token. We also weren't torch.compiling prefill, which defeated
the point of having autoquant benchmark torch.compiled prefill shapes.

To fix this, autoquant needed new functionality: an option to not
automatically end logging after a single call to model.forward. The new
manual_do_autoquant flag controls whether you have to manually call
model.do_autoquant() once logging is done, or whether autoquantization
happens automatically after a single forward run.
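
For illustration, a minimal sketch of the manual flow this flag enables,
using the autoquant API from this commit (the calibration_inputs loop is a
placeholder, not part of the change):

    import torch
    import torchao

    # step 1: swap Linear weights for AutoQuantizableLinearWeight;
    # with manual_do_autoquant=True nothing is benchmarked yet
    model = torchao.autoquant(torch.compile(model), manual_do_autoquant=True)

    # step 2: run the model on one or more inputs so the activation
    # shapes/dtypes seen by each layer get logged
    for example_input in calibration_inputs:
        model(example_input)

    # step 3: benchmark the logged shapes and keep the fastest option per layer
    model.do_autoquant()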

A few other small fixes were also made:
1) updated where generate.py resets CUDA memory stats so they aren't
   confounded with torch.compile memory usage
2) updated the README with new numbers
3) improved the autoquant docstring
4) generalized the model size code in generate.py and moved it to
   torchao.utils (see the sketch after this list)
5) reordered the benchmarks so they match what's in the README
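
For reference, a rough sketch of what the generalized model size helper
could look like; the exact name and signature in torchao.utils are not
shown in this diff, so treat them as assumptions (generate.py divides the
result by 1e9 to report model_size in GB):

    import torch

    def get_model_size_in_bytes(model: torch.nn.Module) -> int:
        # sum the storage of all parameters and buffers
        size = 0
        for p in model.parameters():
            size += p.numel() * p.element_size()
        for b in model.buffers():
            size += b.numel() * b.element_size()
        return size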

Test Plan: sh benchmarks.sh

python test_integration.py -k "test_autoquant_manual"

Reviewers:

Subscribers:

Tasks:

Tags:
HDCharles committed Jun 14, 2024
1 parent f7620fe commit 39e2205
Showing 4 changed files with 107 additions and 68 deletions.
4 changes: 4 additions & 0 deletions torchao/_models/llama/benchmark_results.txt
@@ -14,3 +14,7 @@
20240611230440, tok/s=149.32, mem/s= 988.73 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240611231027, tok/s= 9.35, mem/s= 61.94 GB/s, peak_mem= 8.61 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240611231759, tok/s= 9.56, mem/s= 126.32 GB/s, peak_mem= 8.53 GB, model_size=13.22 GB quant: autoquant, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8

20240612192927, tok/s=129.98, mem/s=1731.86 GB/s, peak_mem= 9.63 GB, model_size=13.32 GB quant: autoquant, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240612193544, tok/s=120.03, mem/s=1599.21 GB/s, peak_mem=11.21 GB, model_size=13.32 GB quant: autoquant, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240612200857, tok/s=149.50, mem/s=1991.85 GB/s, peak_mem= 9.15 GB, model_size=13.32 GB quant: autoquant, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
32 changes: 16 additions & 16 deletions torchao/_models/llama/benchmarks.sh
@@ -1,21 +1,21 @@
export CHECKPOINT_PATH=../../../../gpt-fast/checkpoints # path to checkpoints folder

export MODEL_REPO=meta-llama/Meta-Llama-3-8B
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt
# export MODEL_REPO=meta-llama/Meta-Llama-3-8B
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt

export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt
18 changes: 14 additions & 4 deletions torchao/_models/llama/generate.py
@@ -109,7 +109,11 @@ def generate(
seq[:T] = prompt.view(-1)

# setup model cache
max_seq_length = min(T_new, model.config.block_size) if not interactive else 350
try:
max_seq_length = min(T_new, model.config.block_size) if not interactive else 350
except:
max_seq_length = T_new

with torch.device(device):
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

@@ -202,6 +206,7 @@ def main(


if quantization:
from torchao.quantization.autoquant import change_autoquantizable_to_quantized
from torchao.quantization.quant_api import (
change_linear_weights_to_int4_woqtensors,
change_linear_weights_to_int8_woqtensors,
@@ -218,14 +223,19 @@
assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
change_linear_weights_to_int4_woqtensors(model, groupsize=groupsize)
if "autoquant" == quantization:
model = autoquant(model)
model = autoquant(model, manual_do_autoquant=True)

generate(
model,
encode_tokens(tokenizer, prompt, bos=True, device=device),
2,
interactive=False
min(max_new_tokens, model.config.block_size),
interactive=False,
temperature=temperature,
top_k=top_k,
)

change_autoquantizable_to_quantized(model)

model_size = _get_model_size(model) / 1e9

if compile:
121 changes: 73 additions & 48 deletions torchao/quantization/autoquant.py
@@ -97,11 +97,10 @@ def to_quantized(self, error_on_unseen, **kwargs):
return self


# only want to do shape+final print a single time if multiple layers
# see/have same shapes so we gate on check_cache being empty for
# at least one of the class/shape combinations.
do_final_print = False
print_once = True
# only want to print shape (at start) and final result (at end)
# once per shape+quantization subclass combination.
ran_new_benchmarks = False
print_shape_once = True

def count_shapes(self, do_print=True):
differe_shape_count=0
@@ -123,27 +122,25 @@ def count_shapes(self, do_print=True):
shape_count = count_shapes(self, do_print=False)
for shapes_and_dtype, times_seen in self.logged_data.items():
if check_cache(q_cls, shapes_and_dtype) is None:
# only do final print if we have to autotune at least one cls/shape pair
do_final_print=True

# only print shapes once
if print_once == True:
print_once = False
if print_shape_once == True:
print_shape_once = False
count_shapes(self, do_print=True)

time_for_best_shape = check_cache(best_cls, shapes_and_dtype)
time_for_best_shape = torch.inf if time_for_best_shape is None else time_for_best_shape
self.tune_autoquant(q_cls, shapes_and_dtype, time_for_best_shape)
ran_new_benchmarks=True
torch._dynamo.reset()
cur_time += check_cache(q_cls, shapes_and_dtype) * times_seen
if shape_count is not None and shape_count > 1:
# print aggregated time if there were multiple shapes to aggregate and some new benchmarking was done
if shape_count is not None and shape_count > 1 and ran_new_benchmarks:
print(f">time (all shapes): {cur_time:0.3f}ms for {q_cls}, prev_best: {best_time:0.3f}ms")
if best_time >= cur_time:
best_time = cur_time
best_cls = q_cls
# only print if this is the first time seeing some cls+shape combo,
# otherwise we will print the same thing for every layer.
if do_final_print:
# if no new benchmarking was done, don't print the final result, it will be the same as for another layer
if ran_new_benchmarks:
print(f"best_cls={best_cls}\n")
# TODO handle random cls args/kwargs? or should they be curried?
self = best_cls.from_float(self.weight)
@@ -433,11 +430,27 @@ def change_autoquantizable_to_quantized(model, **kwargs):
# TODO: Document all the modes
# TODO: Mode being a list is weird, should be a string or some object
@torch.no_grad()
def autoquant(model, example_input=None, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=None, mode=["interpolate", .85], **aq_kwargs):
def autoquant(
model,
example_input=None,
qtensor_class_list=DEFAULT_CLASS_LIST,
filter_fn=None, mode=["interpolate", .85],
manual_do_autoquant=False,
**aq_kwargs
):
"""
Wraps the given model in an AutoQuantWrapper. If `example_input` is provided, performs a forward pass on the input.
Otherwise, returns the wrapped model. The AutoQuantWrapper manages cases where the model is torch-compiled by first
performing autoquantization on the original model and then allowing the torch.compile run/tracing to occur.
Begins autoquantization. Autoquantization happens in three steps:
1) the model is searched for Linear layers whose weights are exchanged for AutoQuantizableLinearWeight
2) the user runs the model on one or more inputs, the details of the activation shape/dtype seen by
the AutoQuantizableLinearWeight are logged
3) for each AutoQuantizableLinearWeight, benchmarks are run for each member of the qtensor_class_list and
the fastest option is picked, resulting in a highly performant model
This autoquant function performs step 1. Steps 2 and 3 can be completed by simply running the model.
If `example_input` is provided, this function also runs the model. This autoquant api can handle models which have already
had torch.compile applied to them, in which case, once the model is run and quantized, the torch.compile process normally
proceeds as well.
Args:
model (torch.nn.Module): The model to be autoquantized.
@@ -447,6 +460,8 @@ def autoquant(model, example_input=None, qtensor_class_list=DEFAULT_CLASS_LIST,
filter_fn (callable, optional): A filter function to apply to the model parameters. Defaults to None.
mode (list, optional): A list containing mode settings for quantization. The first element is the mode type (e.g., "interpolate"),
and the second element is the mode value (e.g., 0.85). Defaults to ["interpolate", .85].
manual_do_autoquant (bool, optional): Whether to stop logging and run autoquantization after a single model run (False), or to wait for
the user to call model.do_autoquant() (True) so multiple inputs can be logged. Defaults to False.
**aq_kwargs: Additional keyword arguments for the autoquantization process.
Returns:
Expand All @@ -457,16 +472,6 @@ def autoquant(model, example_input=None, qtensor_class_list=DEFAULT_CLASS_LIST,
torchao.autoquant(torch.compile(model))
model(*example_input)
"""
# the hook we will use to intercept the model forward and perform
# autoquantization
def autoquant_prehook(module, args, kwargs):
module.forward_log_only(*args, **kwargs)
change_autoquantizable_to_quantized(
module,
**aq_kwargs,
)
module.clean_up_autoquant_hooks_and_attrs()
return args, kwargs

# perform initial swap from linear weights
# to AutoQuantizableLinearWeight
@@ -479,32 +484,52 @@ def autoquant_prehook(module, args, kwargs):
)

# access actual model of torch.compile wrapper if needed
if isinstance(model, torch._dynamo.eval_frame.OptimizedModule):
is_compiled = isinstance(model, torch._dynamo.eval_frame.OptimizedModule)
if is_compiled:
real_model = model._orig_mod
else:
real_model = model

# we need a consistent way to run the model which bypasses both
# A) the torch.compile tracing (so we need to run the inner model directly)
# B) the autoquant_prehook we're about to register (so we call forward directly)
model.forward_log_only = lambda *args, **kwargs: real_model.forward(*args, **kwargs)

# the autoquant_prehook intercepts the forward call and performs autoquantization
# and then deletes the hook. if model is a torch.compile wrapper, it then
# does the tracing/compile since the prehook is naturally followed by the normal.
# model run.
handle = model.register_forward_pre_hook(autoquant_prehook, with_kwargs=True)

if manual_do_autoquant:
# we don't want model.forward to trigger
# torch.compilation
if is_compiled:
real_model.old_forward = model.forward
model.forward = real_model.forward

# note the torch.compile wrapper eval_frame moved the assignment of any assigned
# attributes to the inner model, so we have to call delattr on the inner model
def clean_up_autoquant_hooks_and_attrs():
try:
# we want to automatically do autoquant after a single model run
# and have it occur before torch.compilation if applicable
else:
# the hook we will use to intercept the model forward and perform
# autoquantization
def autoquant_prehook(module, args, kwargs):
real_model.forward(*args, **kwargs)
module.do_autoquant()
return args, kwargs

# the autoquant_prehook intercepts the forward call, performs logging then
# does autoquantization. if model is a torch.compile wrapper, it then
# does the tracing/compile since the prehook is naturally followed by the normal.
# model run.
handle = model.register_forward_pre_hook(autoquant_prehook, with_kwargs=True)

# note the torch.compile wrapper (eval_frame) moves the assignment of any assigned
# attributes to the inner model that didn't exist before, so we have to call delattr on the inner model
def do_autoquant():
change_autoquantizable_to_quantized(
real_model,
**aq_kwargs,
)
if hasattr(real_model, "old_forward"):
model.forward = real_model.old_forward
delattr(real_model, "old_forward")
if hasattr(real_model, "do_autoquant"):
delattr(real_model, "do_autoquant")
if not manual_do_autoquant:
handle.remove()
delattr(real_model, "clean_up_autoquant_hooks_and_attrs")
delattr(real_model, "forward_log_only")
except:
pass
model.clean_up_autoquant_hooks_and_attrs = clean_up_autoquant_hooks_and_attrs

real_model.do_autoquant = do_autoquant

# if example input was provided, check it and run it
if isinstance(example_input, torch.Tensor):
