From 245ab4e82c82da8dab11335f1c53c8d2479f257a Mon Sep 17 00:00:00 2001 From: HDCharles <39544797+HDCharles@users.noreply.github.com> Date: Tue, 6 Aug 2024 20:36:25 -0400 Subject: [PATCH] Fixing linear_activation_tensor dynamic quant (#622) Summary: dynamic quant was broken for generate due to no repr function Test Plan: sh benchmarks.sh 20240806170037, tok/s= 9.54, mem/s= 63.14 GB/s, peak_mem= 8.61 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 Reviewers: Subscribers: Tasks: Tags: --- torchao/quantization/linear_activation_quantized_tensor.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchao/quantization/linear_activation_quantized_tensor.py b/torchao/quantization/linear_activation_quantized_tensor.py index dfe1f62de7..b8070aff69 100644 --- a/torchao/quantization/linear_activation_quantized_tensor.py +++ b/torchao/quantization/linear_activation_quantized_tensor.py @@ -39,6 +39,9 @@ def __init__( self.original_weight_tensor = original_weight_tensor self.input_quant_func = input_quant_func + def __repr__(self): + return f"LinearActivationQuantizedTensor({self.original_weight_tensor}, {self.input_quant_func})" + def __tensor_flatten__(self): return ["original_weight_tensor"], [self.input_quant_func]