diff --git a/torchao/_models/llama/benchmark_results.txt b/torchao/_models/llama/benchmark_results.txt
index 960bdd4f17..421d5b0608 100644
--- a/torchao/_models/llama/benchmark_results.txt
+++ b/torchao/_models/llama/benchmark_results.txt
@@ -14,3 +14,7 @@
 20240611230440, tok/s=149.32, mem/s= 988.73 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
 20240611231027, tok/s= 9.35, mem/s= 61.94 GB/s, peak_mem= 8.61 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
 20240611231759, tok/s= 9.56, mem/s= 126.32 GB/s, peak_mem= 8.53 GB, model_size=13.22 GB quant: autoquant, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+
+20240612192927, tok/s=129.98, mem/s=1731.86 GB/s, peak_mem= 9.63 GB, model_size=13.32 GB quant: autoquant, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240612193544, tok/s=120.03, mem/s=1599.21 GB/s, peak_mem=11.21 GB, model_size=13.32 GB quant: autoquant, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240612200857, tok/s=149.50, mem/s=1991.85 GB/s, peak_mem= 9.15 GB, model_size=13.32 GB quant: autoquant, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
\ No newline at end of file
diff --git a/torchao/_models/llama/benchmarks.sh b/torchao/_models/llama/benchmarks.sh
index b345ed3f28..345959e58a 100644
--- a/torchao/_models/llama/benchmarks.sh
+++ b/torchao/_models/llama/benchmarks.sh
@@ -1,21 +1,21 @@
 export CHECKPOINT_PATH=../../../../gpt-fast/checkpoints # path to checkpoints folder

-export MODEL_REPO=meta-llama/Meta-Llama-3-8B
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt
+# export MODEL_REPO=meta-llama/Meta-Llama-3-8B
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt

 export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index 1f5380a888..81dd4e7b1e 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -109,7 +109,11 @@ def generate(
     seq[:T] = prompt.view(-1)

     # setup model cache
-    max_seq_length = min(T_new, model.config.block_size) if not interactive else 350
+    try:
+        max_seq_length = min(T_new, model.config.block_size) if not interactive else 350
+    except:
+        max_seq_length = T_new
+
     with torch.device(device):
         model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

@@ -202,6 +206,7 @@ def main(


     if quantization:
+        from torchao.quantization.autoquant import change_autoquantizable_to_quantized
         from torchao.quantization.quant_api import (
             change_linear_weights_to_int4_woqtensors,
             change_linear_weights_to_int8_woqtensors,
@@ -218,14 +223,19 @@ def main(
             assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
             change_linear_weights_to_int4_woqtensors(model, groupsize=groupsize)
         if "autoquant" == quantization:
-            model = autoquant(model)
+            model = autoquant(model, manual_do_autoquant=True)
+
             generate(
                 model,
                 encode_tokens(tokenizer, prompt, bos=True, device=device),
-                2,
-                interactive=False
+                min(max_new_tokens, model.config.block_size),
+                interactive=False,
+                temperature=temperature,
+                top_k=top_k,
             )
+            change_autoquantizable_to_quantized(model)
+

     model_size = _get_model_size(model) / 1e9

     if compile:
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index ff0889467c..514a3dfbaa 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -97,11 +97,10 @@ def to_quantized(self, error_on_unseen, **kwargs):
             return self


-        # only want to do shape+final print a single time if multiple layers
-        # see/have same shapes so we gate on check_cache being empty for
-        # at least one of the class/shape combinations.
-        do_final_print = False
-        print_once = True
+        # only want to print shape (at start) and final result (at end)
+        # once per shape+quantization subclass combination.
+        ran_new_benchmarks = False
+        print_shape_once = True

        def count_shapes(self, do_print=True):
            differe_shape_count=0
@@ -123,27 +122,25 @@ def count_shapes(self, do_print=True):
             shape_count = count_shapes(self, do_print=False)
             for shapes_and_dtype, times_seen in self.logged_data.items():
                 if check_cache(q_cls, shapes_and_dtype) is None:
-                    # only do final print if we have to autotune at least one cls/shape pair
-                    do_final_print=True
-                    # only print shapes once
-                    if print_once == True:
-                        print_once = False
+                    if print_shape_once == True:
+                        print_shape_once = False
                         count_shapes(self, do_print=True)

                     time_for_best_shape = check_cache(best_cls, shapes_and_dtype)
                     time_for_best_shape = torch.inf if time_for_best_shape is None else time_for_best_shape
                     self.tune_autoquant(q_cls, shapes_and_dtype, time_for_best_shape)
+                    ran_new_benchmarks=True
                     torch._dynamo.reset()
                 cur_time += check_cache(q_cls, shapes_and_dtype) * times_seen
-            if shape_count is not None and shape_count > 1:
+            # print aggregated time if there were multiple shapes to aggregate and some new benchmarking was done
+            if shape_count is not None and shape_count > 1 and ran_new_benchmarks:
                 print(f">time (all shapes): {cur_time:0.3f}ms for {q_cls}, prev_best: {best_time:0.3f}ms")
             if best_time >= cur_time:
                 best_time = cur_time
                 best_cls = q_cls

-        # only print if this is the first time seeing some cls+shape combo,
-        # otherwise we will print the same thing for every layer.
-        if do_final_print:
+        # if no new benchmarking was done, don't print the final result; it will be the same as for another layer
+        if ran_new_benchmarks:
             print(f"best_cls={best_cls}\n")
         # TODO handle random cls args/kwargs? or should they be curried?
         self = best_cls.from_float(self.weight)
@@ -433,11 +430,27 @@ def change_autoquantizable_to_quantized(model, **kwargs):
 # TODO: Document all the modes
 # TODO: Mode being a list is weird, should be a string or some object
 @torch.no_grad()
-def autoquant(model, example_input=None, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=None, mode=["interpolate", .85], **aq_kwargs):
+def autoquant(
+    model,
+    example_input=None,
+    qtensor_class_list=DEFAULT_CLASS_LIST,
+    filter_fn=None, mode=["interpolate", .85],
+    manual_do_autoquant=False,
+    **aq_kwargs
+):
     """
-    Wraps the given model in an AutoQuantWrapper. If `example_input` is provided, performs a forward pass on the input.
-    Otherwise, returns the wrapped model. The AutoQuantWrapper manages cases where the model is torch-compiled by first
-    performing autoquantization on the original model and then allowing the torch.compile run/tracing to occur.
+    Begins autoquantization. Autoquantization happens in three steps:
+
+    1) the model is searched for Linear layers whose weights are exchanged for AutoQuantizableLinearWeight
+    2) the user runs the model on one or more inputs; the details of the activation shape/dtype seen by
+       the AutoQuantizableLinearWeight are logged
+    3) for each AutoQuantizableLinearWeight, benchmarks are run for each member of the qtensor_class_list and
+       the fastest option is picked, resulting in a highly performant model
+
+    This autoquant function performs step 1. Steps 2 and 3 can be completed by simply running the model.
+    If `example_input` is provided, this function also runs the model. This autoquant API can handle models which
+    have already had torch.compile applied to them, in which case the torch.compile run/tracing proceeds as normal
+    once the model has been run and quantized.

     Args:
         model (torch.nn.Module): The model to be autoquantized.
@@ -447,6 +460,8 @@ def autoquant(model, example_input=None, qtensor_class_list=DEFAULT_CLASS_LIST,
         filter_fn (callable, optional): A filter function to apply to the model parameters. Defaults to None.
         mode (list, optional): A list containing mode settings for quantization. The first element is the mode type (e.g., "interpolate"),
             and the second element is the mode value (e.g., 0.85). Defaults to ["interpolate", .85].
+        manual_do_autoquant (bool, optional): if False, stop logging and perform autoquantization after a single model run;
+            if True, wait for the user to call model.do_autoquant() so that multiple inputs can be logged first. Defaults to False.
         **aq_kwargs: Additional keyword arguments for the autoquantization process.

     Returns:
@@ -457,16 +472,6 @@ def autoquant(model, example_input=None, qtensor_class_list=DEFAULT_CLASS_LIST,
         torchao.autoquant(torch.compile(model))
         model(*example_input)
     """
-    # the hook we will use to intercept the model forward and perform
-    # autoquantization
-    def autoquant_prehook(module, args, kwargs):
-        module.forward_log_only(*args, **kwargs)
-        change_autoquantizable_to_quantized(
-            module,
-            **aq_kwargs,
-        )
-        module.clean_up_autoquant_hooks_and_attrs()
-        return args, kwargs

     # perform initial swap from linear weights
     # to AutoQuantizableLinearWeight
@@ -479,32 +484,52 @@ def autoquant_prehook(module, args, kwargs):
     )

     # access actual model of torch.compile wrapper if needed
-    if isinstance(model, torch._dynamo.eval_frame.OptimizedModule):
+    is_compiled = isinstance(model, torch._dynamo.eval_frame.OptimizedModule)
+    if is_compiled:
         real_model = model._orig_mod
     else:
         real_model = model

-    # we need a consistent way to run the model which bypasses both
-    # A) the torch.compile tracing (so we need to run the inner model directly)
-    # B) the autoquant_prehook we're about to register (so we call forward directly)
-    model.forward_log_only = lambda *args, **kwargs: real_model.forward(*args, **kwargs)
-
-    # the autoquant_prehook intercepts the forward call and performs autoquantization
-    # and then deletes the hook. if model is a torch.compile wrapper, it then
-    # does the tracing/compile since the prehook is naturally followed by the normal.
-    # model run.
-    handle = model.register_forward_pre_hook(autoquant_prehook, with_kwargs=True)
+
+    if manual_do_autoquant:
+        # we don't want model.forward to trigger
+        # torch.compilation
+        if is_compiled:
+            real_model.old_forward = model.forward
+            model.forward = real_model.forward
-    # note the torch.compile wrapper eval_frame moved the assignment of any assigned
-    # attributes to the inner model, so we have to call delattr on the inner model
-    def clean_up_autoquant_hooks_and_attrs():
-        try:
+
+    # we want to automatically do autoquant after a single model run
+    # and have it occur before torch.compilation if applicable
+    else:
+        # the hook we will use to intercept the model forward and perform
+        # autoquantization
+        def autoquant_prehook(module, args, kwargs):
+            real_model.forward(*args, **kwargs)
+            module.do_autoquant()
+            return args, kwargs
+
+        # the autoquant_prehook intercepts the forward call, performs logging, then
+        # does autoquantization. if model is a torch.compile wrapper, it then
+        # does the tracing/compile since the prehook is naturally followed by the normal
+        # model run.
+        handle = model.register_forward_pre_hook(autoquant_prehook, with_kwargs=True)
+
+    # note: the torch.compile wrapper (eval_frame) redirects assignment of new attributes
+    # to the inner model, so we have to call delattr on the inner model
+    def do_autoquant():
+        change_autoquantizable_to_quantized(
+            real_model,
+            **aq_kwargs,
+        )
+        if hasattr(real_model, "old_forward"):
+            model.forward = real_model.old_forward
+            delattr(real_model, "old_forward")
+        if hasattr(real_model, "do_autoquant"):
+            delattr(real_model, "do_autoquant")
+        if not manual_do_autoquant:
             handle.remove()
-            delattr(real_model, "clean_up_autoquant_hooks_and_attrs")
-            delattr(real_model, "forward_log_only")
-        except:
-            pass
-    model.clean_up_autoquant_hooks_and_attrs = clean_up_autoquant_hooks_and_attrs
+
+    real_model.do_autoquant = do_autoquant

     # if example input was provided, check it and run it
     if isinstance(example_input, torch.Tensor):
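
Example of the two autoquant flows described in the docstring above (a minimal sketch; the toy model, tensor shapes, and cuda/bfloat16 settings are illustrative assumptions, not part of the patch):

import torch
import torchao

# a tiny stand-in model for illustration; any nn.Module containing Linear layers works
class ToyLinearModel(torch.nn.Module):
    def __init__(self, k=1024):
        super().__init__()
        self.linear1 = torch.nn.Linear(k, k)
        self.linear2 = torch.nn.Linear(k, k)

    def forward(self, x):
        return self.linear2(self.linear1(x))

example_input = torch.randn(1, 1024, device="cuda", dtype=torch.bfloat16)
extra_input = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)

# default flow: weights are swapped for AutoQuantizableLinearWeight immediately; the first
# run logs shapes, triggers autoquantization, and then torch.compile tracing proceeds
model = ToyLinearModel().to(device="cuda", dtype=torch.bfloat16)
m = torchao.autoquant(torch.compile(model))
m(example_input)

# manual flow (manual_do_autoquant=True): run the model on as many inputs as desired to log
# their shapes, then finalize explicitly with do_autoquant()
model = ToyLinearModel().to(device="cuda", dtype=torch.bfloat16)
m = torchao.autoquant(torch.compile(model), manual_do_autoquant=True)
m(example_input)
m(extra_input)
m.do_autoquant()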