diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py
index 6522fd9757..bf7448e023 100644
--- a/torchao/_models/llama/eval.py
+++ b/torchao/_models/llama/eval.py
@@ -18,6 +18,7 @@
     int8_weight_only,
     int8_dynamic_activation_int8_weight,
     fpx_weight_only,
+    uintx_weight_only,
     unwrap_tensor_subclass,
 )
 from torchao._models._eval import TransformerEvalWrapper, InputRecorder
@@ -79,6 +80,15 @@ def run_evaluation(
             groupsize=int(quantization.split("-")[-1])
             assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
             quantize_(model.to(device), int4_weight_only(group_size=groupsize, use_hqq=use_hqq))
+        if "uintx" in quantization:
+            # uintx-nbits-group_size
+            # "uintx-2-64"
+            _quant_args = quantization.split("-")
+            nbits = int(_quant_args[1])
+            _NBITS_TO_DTYPE = {1: torch.uint1, 2: torch.uint2, 3: torch.uint3, 4: torch.uint4, 5: torch.uint5, 6: torch.uint6, 7: torch.uint7, 8: torch.uint8}
+            dtype = _NBITS_TO_DTYPE[nbits]
+            group_size = int(_quant_args[2])
+            quantize_(model, uintx_weight_only(dtype, group_size))
         if "int4wo" in quantization and "gptq" in quantization:
             groupsize=int(quantization.split("-")[-2])
             assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
@@ -125,7 +135,7 @@ def run_evaluation(
     parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
     parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
     parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
-    parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, int4wo-<groupsize>-gptq, int4wo-<groupsize>-hqq")
+    parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, int4wo-<groupsize>-gptq, int4wo-<groupsize>-hqq, uintx-<nbits>-<group_size>")
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
     parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
     parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index 089f247656..7d12dc2757 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -128,7 +128,7 @@ def generate(
     # execute token generation
     input_pos = torch.tensor([T], device=device, dtype=torch.int)
     generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs)
-    
+
     seq = torch.cat((seq[:T+1], *generated_tokens))
     return seq
 
@@ -208,6 +208,7 @@
            int8_dynamic_activation_int8_weight,
            int4_weight_only,
            fpx_weight_only,
+           uintx_weight_only,
            autoquant,
            unwrap_tensor_subclass
         )
@@ -271,6 +272,16 @@
            model.reset_caches()
        if "fp6" in quantization:
            quantize_(model, fpx_weight_only(3, 2))
+       if "uintx" in quantization:
+           # uintx-nbits-group_size
+           # "uintx-2-64"
+           _quant_args = quantization.split("-")
+           nbits = int(_quant_args[1])
+           assert nbits >= 1 and nbits <= 8, "nbits must be 1 to 8"
+           _NBITS_TO_DTYPE = {1: torch.uint1, 2: torch.uint2, 3: torch.uint3, 4: torch.uint4, 5: torch.uint5, 6: torch.uint6, 7: torch.uint7, 8: torch.uint8}
+           dtype = _NBITS_TO_DTYPE[nbits]
+           group_size = int(_quant_args[2])
+           quantize_(model, uintx_weight_only(dtype, group_size))
        if "autoquant" in quantization:
            if "autoquant-int4" == quantization:
                model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
@@ -440,7 +451,7 @@ def callback(x):
     parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
     parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
     parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
-    parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant, autoquant-int4, int4wo-<groupsize>-hqq, autoround-------')
+    parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant, autoquant-int4, int4wo-<groupsize>-hqq, autoround-------, uintx-<nbits>-<group_size>')
     parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
     parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size')
     parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)')
diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md
index 6112abd027..e0e97bb3e9 100644
--- a/torchao/quantization/README.md
+++ b/torchao/quantization/README.md
@@ -12,6 +12,8 @@ Using the lm_eval. The models used were meta-llama/Llama-2-7b-chat-hf and meta-l
 | | int8wo | 12.204 | 170.83 | 1131.18 | 8.95 | 6.62 |
 | | int4wo-64 | 12.843 | 201.14 | 751.42 | 4.87 | 3.74 |
 | | int4wo-64-GPTQ | 12.527 | 201.14 | 751.42 | 4.87 | 3.74 |
+| | uintx-4-64 | 12.891 | 48.25 | 189.32 | 6.29 | 3.92 |
+| | uintx-2-8 | 28.766 | 36.11 | 238.58 | 9.26 | 6.61 |
 | | autoquant-int4hqq | 12.825 | 209.19 | 804.32 | 4.89 | 3.84 |
 
 | Llama-3-8B | Base (bfloat16) | 7.441 | 95.64 | 1435.54 | 16.43 | 15.01 |
@@ -19,6 +21,8 @@ Using the lm_eval. The models used were meta-llama/Llama-2-7b-chat-hf and meta-l
 | | int8wo | 7.447 | 153.03 | 1150.80 | 10.42 | 7.52 |
 | | int4wo-64 | 8.316 | 180.80 | 763.33 | 6.88 | 4.22 |
 | | int4wo-64-GPTQ | 7.921 | 180.80 | 763.33 | 6.88 | 4.22 |
+| | uintx-4-64 | 8.113 | 47.77 | 212.90 | 11.85 | 4.46 |
+| | uintx-2-8 | 39.368 | 33.21 | 249.22 | 15.04 | 7.51 |
 | | autoquant-int4hqq | 8.110 | 188.41 | 800.58 | 7.14 | 4.25 |
 
 note: Int8 dynamic quantization works best on compute bound models like [SAM](https://github.com/pytorch-labs/segment-anything-fast) whereas Llama with batchsize=1 tends to be memory bound, thus the rather low performance.
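The block below is a minimal, standalone sketch (not part of the patch) of what the new `uintx-<nbits>-<group_size>` option does: it parses the flag the same way the code added to eval.py/generate.py does and hands the result to torchao's `quantize_` with `uintx_weight_only`. The helper name `apply_uintx` and the signature around it are illustrative assumptions, not code from this PR.

```python
# Sketch only: mirrors the flag parsing added in this patch for a string like
# "uintx-2-64", i.e. uintx-<nbits>-<group_size>.
import torch
from torchao.quantization.quant_api import quantize_, uintx_weight_only

_NBITS_TO_DTYPE = {
    1: torch.uint1, 2: torch.uint2, 3: torch.uint3, 4: torch.uint4,
    5: torch.uint5, 6: torch.uint6, 7: torch.uint7, 8: torch.uint8,
}

def apply_uintx(model: torch.nn.Module, quantization: str) -> torch.nn.Module:
    # quantization has the form "uintx-<nbits>-<group_size>", e.g. "uintx-2-64"
    _, nbits, group_size = quantization.split("-")
    nbits, group_size = int(nbits), int(group_size)
    assert 1 <= nbits <= 8, "nbits must be 1 to 8"
    # replace linear-layer weights in place with group-wise quantized,
    # uintx-packed tensor subclasses
    quantize_(model, uintx_weight_only(_NBITS_TO_DTYPE[nbits], group_size))
    return model
```

End to end, the same configurations are exercised through the scripts touched here, e.g. `python torchao/_models/llama/eval.py -q uintx-2-64` or `python torchao/_models/llama/generate.py -q uintx-4-64`; the README rows added above report the uintx-4-64 and uintx-2-8 settings.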