From 39e2205d5daf7f8559fdcb868ef94b21914f3caa Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 12 Jun 2024 21:18:51 -0700 Subject: [PATCH] Autoquantization work for benchmarks Summary: autoquant wasn't working for llama benchmarks for a few reasons the main one being that we were doing logging on prefill not decode_one_token. We also weren't torch.compiling prefill which obviated the whole point of autoquant benchmarking torch.compiled prefill shapes. To fix this, new functionality was needed for autoquant, we needed an option to not automatically end logging upon a single instance of model.forward. The flag manual_do_autoquant now controls whether you manually have to call model.do_autoquant() after logging is done, or whether it happens automatically after a model forward run. a few small other fixes were also made: 1) updated where generate.py resets cuda memory so as to not confound with torch.compilation memory usage 2) README updated with new numbers 3) better autoquant docstring 4) generalized model size code in generate.py and moved it to torchao.utils 5) reordered benchmarks so they match whats in the README Test Plan: sh benchmarks.sh python test_integration.py -k "test_autoquant_manual" Reviewers: Subscribers: Tasks: Tags: --- torchao/_models/llama/benchmark_results.txt | 4 + torchao/_models/llama/benchmarks.sh | 32 +++--- torchao/_models/llama/generate.py | 18 ++- torchao/quantization/autoquant.py | 121 ++++++++++++-------- 4 files changed, 107 insertions(+), 68 deletions(-) diff --git a/torchao/_models/llama/benchmark_results.txt b/torchao/_models/llama/benchmark_results.txt index 960bdd4f17..421d5b0608 100644 --- a/torchao/_models/llama/benchmark_results.txt +++ b/torchao/_models/llama/benchmark_results.txt @@ -14,3 +14,7 @@ 20240611230440, tok/s=149.32, mem/s= 988.73 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240611231027, tok/s= 9.35, mem/s= 61.94 GB/s, peak_mem= 8.61 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240611231759, tok/s= 9.56, mem/s= 126.32 GB/s, peak_mem= 8.53 GB, model_size=13.22 GB quant: autoquant, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 + +20240612192927, tok/s=129.98, mem/s=1731.86 GB/s, peak_mem= 9.63 GB, model_size=13.32 GB quant: autoquant, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240612193544, tok/s=120.03, mem/s=1599.21 GB/s, peak_mem=11.21 GB, model_size=13.32 GB quant: autoquant, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240612200857, tok/s=149.50, mem/s=1991.85 GB/s, peak_mem= 9.15 GB, model_size=13.32 GB quant: autoquant, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 \ No newline at end of file diff --git a/torchao/_models/llama/benchmarks.sh b/torchao/_models/llama/benchmarks.sh index b345ed3f28..345959e58a 100644 --- a/torchao/_models/llama/benchmarks.sh +++ b/torchao/_models/llama/benchmarks.sh @@ -1,21 +1,21 @@ export CHECKPOINT_PATH=../../../../gpt-fast/checkpoints # path to checkpoints folder -export MODEL_REPO=meta-llama/Meta-Llama-3-8B -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt +# export MODEL_REPO=meta-llama/Meta-Llama-3-8B +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 1f5380a888..81dd4e7b1e 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -109,7 +109,11 @@ def generate( seq[:T] = prompt.view(-1) # setup model cache - max_seq_length = min(T_new, model.config.block_size) if not interactive else 350 + try: + max_seq_length = min(T_new, model.config.block_size) if not interactive else 350 + except: + max_seq_length = T_new + with torch.device(device): model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) @@ -202,6 +206,7 @@ def main( if quantization: + from torchao.quantization.autoquant import change_autoquantizable_to_quantized from torchao.quantization.quant_api import ( change_linear_weights_to_int4_woqtensors, change_linear_weights_to_int8_woqtensors, @@ -218,14 +223,19 @@ def main( assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" change_linear_weights_to_int4_woqtensors(model, groupsize=groupsize) if "autoquant" == quantization: - model = autoquant(model) + model = autoquant(model, do_autoquant_after_run=False) + generate( model, encode_tokens(tokenizer, prompt, bos=True, device=device), - 2, - interactive=False + min(max_new_tokens, model.config.block_size), + interactive=False, + temperature=temperature, + top_k=top_k, ) + change_autoquantizable_to_quantized(model) + model_size = _get_model_size(model) / 1e9 if compile: diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index ff0889467c..514a3dfbaa 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -97,11 +97,10 @@ def to_quantized(self, error_on_unseen, **kwargs): return self - # only want to do shape+final print a single time if multiple layers - # see/have same shapes so we gate on check_cache being empty for - # at least one of the class/shape combinations. - do_final_print = False - print_once = True + # only want to print shape (at start) and final result (at end) + # once per shape+quantization subclass combination. + ran_new_benchmarks = False + print_shape_once = True def count_shapes(self, do_print=True): differe_shape_count=0 @@ -123,27 +122,25 @@ def count_shapes(self, do_print=True): shape_count = count_shapes(self, do_print=False) for shapes_and_dtype, times_seen in self.logged_data.items(): if check_cache(q_cls, shapes_and_dtype) is None: - # only do final print if we have to autotune at least one cls/shape pair - do_final_print=True - # only print shapes once - if print_once == True: - print_once = False + if print_shape_once == True: + print_shape_once = False count_shapes(self, do_print=True) time_for_best_shape = check_cache(best_cls, shapes_and_dtype) time_for_best_shape = torch.inf if time_for_best_shape is None else time_for_best_shape self.tune_autoquant(q_cls, shapes_and_dtype, time_for_best_shape) + ran_new_benchmarks=True torch._dynamo.reset() cur_time += check_cache(q_cls, shapes_and_dtype) * times_seen - if shape_count is not None and shape_count > 1: + # print aggregated time if there were multiple shapes to aggregate and some new benchmarking was done + if shape_count is not None and shape_count > 1 and ran_new_benchmarks: print(f">time (all shapes): {cur_time:0.3f}ms for {q_cls}, prev_best: {best_time:0.3f}ms") if best_time >= cur_time: best_time = cur_time best_cls = q_cls - # only print if this is the first time seeing some cls+shape combo, - # otherwise we will print the same thing for every layer. - if do_final_print: + # if no new benchmarking was done, don't print the final result, it will be the same as for another layer + if ran_new_benchmarks: print(f"best_cls={best_cls}\n") # TODO handle random cls args/kwargs? or should they be curried? self = best_cls.from_float(self.weight) @@ -433,11 +430,27 @@ def change_autoquantizable_to_quantized(model, **kwargs): # TODO: Document all the modes # TODO: Mode being a list is weird, should be a string or some object @torch.no_grad() -def autoquant(model, example_input=None, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=None, mode=["interpolate", .85], **aq_kwargs): +def autoquant( + model, + example_input=None, + qtensor_class_list=DEFAULT_CLASS_LIST, + filter_fn=None, mode=["interpolate", .85], + manual_do_autoquant=False, + **aq_kwargs +): """ - Wraps the given model in an AutoQuantWrapper. If `example_input` is provided, performs a forward pass on the input. - Otherwise, returns the wrapped model. The AutoQuantWrapper manages cases where the model is torch-compiled by first - performing autoquantization on the original model and then allowing the torch.compile run/tracing to occur. + Begins autoquantization. Autoquantization happens in three steps: + + 1) the model is searched for Linear layers whose weights are exchanged for AutoQuantizableLinearWeight + 2) the user runs the model on one or more inputs, the details of the activation shape/dtype seen by + the AutoQuantizableLinearWeight are logged + 3) for each AutoQuantizableLinearWeight, benchmarks are run for each member of the qtensor_class_list and + the fastest option is picked, resulting in a highly performant model + + This autoquant function performs step 1. Steps 2 and 3 can be completed by simply running the model. + If `example_input` is provided, this function also runs the model. This autoquant api can handle models which have already + had torch.compile applied to them, in which case, once the model is run and quantized, the torch.compile process normally + proceeds as well. Args: model (torch.nn.Module): The model to be autoquantized. @@ -447,6 +460,8 @@ def autoquant(model, example_input=None, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn (callable, optional): A filter function to apply to the model parameters. Defaults to None. mode (list, optional): A list containing mode settings for quantization. The first element is the mode type (e.g., "interpolate"), and the second element is the mode value (e.g., 0.85). Defaults to ["interpolate", .85]. + manual_do_autoquant (bool, optional): Whether to stop logging and do the autoquant after a single run (False) or to wait for + the user to call model.do_autoquant (True) so multiple inputs can be logged **aq_kwargs: Additional keyword arguments for the autoquantization process. Returns: @@ -457,16 +472,6 @@ def autoquant(model, example_input=None, qtensor_class_list=DEFAULT_CLASS_LIST, torchao.autoquant(torch.compile(model)) model(*example_input) """ - # the hook we will use to intercept the model forward and perform - # autoquantization - def autoquant_prehook(module, args, kwargs): - module.forward_log_only(*args, **kwargs) - change_autoquantizable_to_quantized( - module, - **aq_kwargs, - ) - module.clean_up_autoquant_hooks_and_attrs() - return args, kwargs # perform initial swap from linear weights # to AutoQuantizableLinearWeight @@ -479,32 +484,52 @@ def autoquant_prehook(module, args, kwargs): ) # access actual model of torch.compile wrapper if needed - if isinstance(model, torch._dynamo.eval_frame.OptimizedModule): + is_compiled = isinstance(model, torch._dynamo.eval_frame.OptimizedModule) + if is_compiled: real_model = model._orig_mod else: real_model = model - # we need a consistent way to run the model which bypasses both - # A) the torch.compile tracing (so we need to run the inner model directly) - # B) the autoquant_prehook we're about to register (so we call forward directly) - model.forward_log_only = lambda *args, **kwargs: real_model.forward(*args, **kwargs) - - # the autoquant_prehook intercepts the forward call and performs autoquantization - # and then deletes the hook. if model is a torch.compile wrapper, it then - # does the tracing/compile since the prehook is naturally followed by the normal. - # model run. - handle = model.register_forward_pre_hook(autoquant_prehook, with_kwargs=True) + + if manual_do_autoquant: + # we don't want model.forward to trigger + # torch.compilation + if is_compiled: + real_model.old_forward = model.forward + model.forward = real_model.forward - # note the torch.compile wrapper eval_frame moved the assignment of any assigned - # attributes to the inner model, so we have to call delattr on the inner model - def clean_up_autoquant_hooks_and_attrs(): - try: + # we want to automatically do autoquant after a single model run + # and have it occur before torch.compilation if applicable + else: + # the hook we will use to intercept the model forward and perform + # autoquantization + def autoquant_prehook(module, args, kwargs): + real_model.forward(*args, **kwargs) + module.do_autoquant() + return args, kwargs + + # the autoquant_prehook intercepts the forward call, performs logging then + # does autoquantization. if model is a torch.compile wrapper, it then + # does the tracing/compile since the prehook is naturally followed by the normal. + # model run. + handle = model.register_forward_pre_hook(autoquant_prehook, with_kwargs=True) + + # note the torch.compile wrapper (eval_frame) moves the assignment of any assigned + # attributes to the inner model that didn't exist before, so we have to call delattr on the inner model + def do_autoquant(): + change_autoquantizable_to_quantized( + real_model, + **aq_kwargs, + ) + if hasattr(real_model, "old_forward"): + model.forward = real_model.old_forward + delattr(real_model, "old_forward") + if hasattr(real_model, "do_autoquant"): + delattr(real_model, "do_autoquant") + if not manual_do_autoquant: handle.remove() - delattr(real_model, "clean_up_autoquant_hooks_and_attrs") - delattr(real_model, "forward_log_only") - except: - pass - model.clean_up_autoquant_hooks_and_attrs = clean_up_autoquant_hooks_and_attrs + + real_model.do_autoquant = do_autoquant # if example input was provided, check it and run it if isinstance(example_input, torch.Tensor):