Add uintx quant to generate and eval #811

Merged (1 commit) on Sep 5, 2024
12 changes: 11 additions & 1 deletion torchao/_models/llama/eval.py
@@ -18,6 +18,7 @@
int8_weight_only,
int8_dynamic_activation_int8_weight,
fpx_weight_only,
uintx_weight_only,
unwrap_tensor_subclass,
)
from torchao._models._eval import TransformerEvalWrapper, InputRecorder
@@ -79,6 +80,15 @@ def run_evaluation(
groupsize=int(quantization.split("-")[-1])
assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
quantize_(model.to(device), int4_weight_only(group_size=groupsize, use_hqq=use_hqq))
if "uintx" in quantization:
# quantization flag format: uintx-<nbits>-<group_size>, e.g. "uintx-2-64"
_quant_args = quantization.split("-")
nbits = int(_quant_args[1])
_NBITS_TO_DTYPE = {1: torch.uint1, 2: torch.uint2, 3: torch.uint3, 4: torch.uint4, 5: torch.uint5, 6: torch.uint6, 7: torch.uint7, 8: torch.uint8}
dtype = _NBITS_TO_DTYPE[nbits]
group_size = int(_quant_args[2])
quantize_(model, uintx_weight_only(dtype, group_size))
if "int4wo" in quantization and "gptq" in quantization:
groupsize=int(quantization.split("-")[-2])
assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
@@ -125,7 +135,7 @@ def run_evaluation(
parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, int4wo-<groupsize>-gptq, int4wo-<groupsize>-hqq")
parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, int4wo-<groupsize>-gptq, int4wo-<groupsize>-hqq, uintx-<nbits>-<group_size>")
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
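
For readers skimming the diff: the sketch below lifts the new flag-parsing logic out of eval.py into a standalone helper so the flag-to-dtype mapping is easy to see. The helper name `parse_uintx_flag` is illustrative (not part of the PR), and `torch.uint1` through `torch.uint7` assume a recent PyTorch that defines the sub-byte unsigned dtypes.

```python
import torch

# Same nbits -> dtype table used by the new uintx branch in eval.py / generate.py.
_NBITS_TO_DTYPE = {
    1: torch.uint1, 2: torch.uint2, 3: torch.uint3, 4: torch.uint4,
    5: torch.uint5, 6: torch.uint6, 7: torch.uint7, 8: torch.uint8,
}

def parse_uintx_flag(quantization: str):
    """Parse a 'uintx-<nbits>-<group_size>' flag, e.g. 'uintx-2-64'."""
    _, nbits, group_size = quantization.split("-")
    nbits = int(nbits)
    assert 1 <= nbits <= 8, "nbits must be 1 to 8"
    return _NBITS_TO_DTYPE[nbits], int(group_size)

# parse_uintx_flag("uintx-4-64") returns (torch.uint4, 64)
```
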
15 changes: 13 additions & 2 deletions torchao/_models/llama/generate.py
@@ -128,7 +128,7 @@ def generate(
# execute token generation
input_pos = torch.tensor([T], device=device, dtype=torch.int)
generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs)

seq = torch.cat((seq[:T+1], *generated_tokens))

return seq
@@ -208,6 +208,7 @@ def main(
int8_dynamic_activation_int8_weight,
int4_weight_only,
fpx_weight_only,
uintx_weight_only,
autoquant,
unwrap_tensor_subclass
)
@@ -271,6 +272,16 @@ def main(
model.reset_caches()
if "fp6" in quantization:
quantize_(model, fpx_weight_only(3, 2))
if "uintx" in quantization:
# quantization flag format: uintx-<nbits>-<group_size>, e.g. "uintx-2-64"
_quant_args = quantization.split("-")
nbits = int(_quant_args[1])
assert nbits >= 1 and nbits <= 8, "nbits must be 1 to 8"
_NBITS_TO_DTYPE = {1: torch.uint1, 2: torch.uint2, 3: torch.uint3, 4: torch.uint4, 5: torch.uint5, 6: torch.uint6, 7: torch.uint7, 8: torch.uint8}
dtype = _NBITS_TO_DTYPE[nbits]
group_size = int(_quant_args[2])
quantize_(model, uintx_weight_only(dtype, group_size))
if "autoquant" in quantization:
if "autoquant-int4" == quantization:
model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
@@ -440,7 +451,7 @@ def callback(x):
parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant, autoquant-int4, int4wo-<groupsize>-hqq, autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>')
parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant, autoquant-int4, int4wo-<groupsize>-hqq, autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>, uintx-<nbits>-<group_size>')
parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size')
parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)')
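
As a rough usage sketch of what the new generate.py branch boils down to: a single `quantize_` call with the `uintx_weight_only` config. The toy model and the `from torchao.quantization import ...` path are assumptions for illustration only (the script quantizes the loaded Llama checkpoint, and its import block is only partially visible in this hunk).

```python
import torch
from torchao.quantization import quantize_, uintx_weight_only  # import path assumed

# Toy stand-in for the transformer; the script applies this to the Llama model.
model = torch.nn.Sequential(
    torch.nn.Linear(256, 512),
    torch.nn.ReLU(),
    torch.nn.Linear(512, 256),
).to(torch.bfloat16)

# "uintx-2-64": 2-bit unsigned weights with a quantization group size of 64.
quantize_(model, uintx_weight_only(torch.uint2, 64))

# Depending on the torchao build, the quantized forward path may require CUDA;
# move the model and inputs with .to("cuda") if the CPU path is unsupported.
with torch.no_grad():
    _ = model(torch.randn(1, 256, dtype=torch.bfloat16))
```
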
4 changes: 4 additions & 0 deletions torchao/quantization/README.md
@@ -12,13 +12,17 @@ Using the lm_eval. The models used were meta-llama/Llama-2-7b-chat-hf and meta-l
| | int8wo | 12.204 | 170.83 | 1131.18 | 8.95 | 6.62 |
| | int4wo-64 | 12.843 | 201.14 | 751.42 | 4.87 | 3.74 |
| | int4wo-64-GPTQ | 12.527 | 201.14 | 751.42 | 4.87 | 3.74 |
| | uintx-4-64 | 12.891 | 48.25 | 189.32 | 6.29 | 3.92 |
| | uintx-2-8 | 28.766 | 36.11 | 238.58 | 9.26 | 6.61 |
| | autoquant-int4hqq | 12.825 | 209.19 | 804.32 | 4.89 | 3.84 |

| Llama-3-8B | Base (bfloat16) | 7.441 | 95.64 | 1435.54 | 16.43 | 15.01 |
| | int8dq | 7.581 | 8.61 | 64.75 | 9.24 | 7.52 |
| | int8wo | 7.447 | 153.03 | 1150.80 | 10.42 | 7.52 |
| | int4wo-64 | 8.316 | 180.80 | 763.33 | 6.88 | 4.22 |
| | int4wo-64-GPTQ | 7.921 | 180.80 | 763.33 | 6.88 | 4.22 |
| | uintx-4-64 | 8.113 | 47.77 | 212.90 | 11.85 | 4.46 |
| | uintx-2-8 | 39.368 | 33.21 | 249.22 | 15.04 | 7.51 |
| | autoquant-int4hqq | 8.110 | 188.41 | 800.58 | 7.14 | 4.25 |

Note: int8 dynamic quantization works best on compute-bound models like [SAM](https://github.com/pytorch-labs/segment-anything-fast), whereas Llama with batch size 1 tends to be memory bound, hence the rather low int8dq numbers above.