From 1ec9b07af170d17442b8bc2f916bcd6ac18a6dbd Mon Sep 17 00:00:00 2001 From: HDCharles Date: Sat, 20 Jul 2024 01:39:47 -0700 Subject: [PATCH] testing kv_cache quantization [WIP] Summary: the peak memory improvement is extremely small, tried a few things to fix this but didn't have any luck. Accuracy is very poor (text is unintelligible) tried to leave most recent token not quantized (since we have full fidelity information for whatever the current token is). That didn't solve the issue and resulted in a significant memory increase, may need to try affine quantization but currently more concerned with the lack of memory improvement. (see benchmark_results.txt for the results see kv_quant: True vs kv_quant: False for comparison.) i also took a memory trace you can get with (if you're a meta employee) jf download GCqU9BqGNUybzv8CABWUzUtOiPZ5bsIXAAAz --file "mem_trace_kvq.html" Test Plan: sh benchmarks.sh Reviewers: Subscribers: Tasks: Tags: --- torchao/_models/llama/benchmark_results.txt | 12 ++++ torchao/_models/llama/benchmarks.sh | 8 ++- torchao/_models/llama/generate.py | 36 +++++++++--- torchao/_models/llama/model.py | 61 +++++++++++++++------ 4 files changed, 90 insertions(+), 27 deletions(-) diff --git a/torchao/_models/llama/benchmark_results.txt b/torchao/_models/llama/benchmark_results.txt index b02d4c2441..15abccbe69 100644 --- a/torchao/_models/llama/benchmark_results.txt +++ b/torchao/_models/llama/benchmark_results.txt @@ -17,3 +17,15 @@ 20240619123652, tok/s=139.76, mem/s=1051.02 GB/s, peak_mem=10.42 GB, model_size= 7.52 GB quant: int8wo, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240619123847, tok/s=179.44, mem/s= 757.60 GB/s, peak_mem= 6.62 GB, model_size= 4.22 GB quant: int4wo-64, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240619131959, tok/s=137.71, mem/s=1037.74 GB/s, peak_mem=11.08 GB, model_size= 7.54 GB quant: autoquant, mod: Meta-Llama-3-8B, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 + +# done with quantization of latest token +20240718131341, tok/s=108.87, mem/s=1438.62 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240718131549, tok/s=103.15, mem/s=1363.06 GB/s, peak_mem=13.86 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240718131820, tok/s=163.84, mem/s=1084.89 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240718132103, tok/s=154.76, mem/s=1024.78 GB/s, peak_mem= 8.93 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 + +# done with full accuracy for latest token +20240718150644, tok/s=109.23, mem/s=1443.43 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240718151152, tok/s=100.29, mem/s=1325.29 GB/s, peak_mem=14.14 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240718151349, tok/s=166.08, mem/s=1099.70 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240718152147, tok/s=140.85, mem/s= 932.66 GB/s, peak_mem= 9.21 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 diff --git a/torchao/_models/llama/benchmarks.sh b/torchao/_models/llama/benchmarks.sh index b03866db23..614329b1e6 100644 --- a/torchao/_models/llama/benchmarks.sh +++ b/torchao/_models/llama/benchmarks.sh @@ -6,7 +6,7 @@ export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt # in readme -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt @@ -22,3 +22,9 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt + +##### +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt --kv_cache_quantization diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index f3cbc0c8ef..8643eb646e 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -68,10 +68,11 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc next_token, next_prob = decode_one_token( model, cur_token, input_pos, **sampling_kwargs ) + next_token, next_prob = next_token.clone(), next_prob.clone() input_pos += 1 - new_tokens.append(next_token.clone()) + new_tokens.append(next_token) callback(new_tokens[-1]) - new_probs.append(next_prob.clone()) + new_probs.append(next_prob) cur_token = next_token.view(1, -1) return new_tokens, new_probs @@ -88,23 +89,32 @@ def generate( *, interactive: bool, callback = lambda x: x, + kv_cache_quantization: bool = False, **sampling_kwargs ) -> torch.Tensor: """ Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. """ - # create an empty tensor of the expected final shape and fill in the current tokens device = prompt.device T = prompt.numel() T_new = T + max_new_tokens seq = torch.empty(T_new, dtype=prompt.dtype, device=device) seq[:T] = prompt.view(-1) - # setup model cache max_seq_length = min(T_new, model.config.block_size) if not interactive else 350 with torch.device(device): model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + if kv_cache_quantization: + from model import QuantizedKVCache + # go through the model and do the swaps + from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter + _replace_with_custom_fn_if_matches_filter( + model, + QuantizedKVCache.from_float, + lambda x, y: isinstance(x, torchao._models.llama.model.KVCache), + ) + # format model input x, input_pos = prepare_inputs_for_model(prompt, max_new_tokens) @@ -147,6 +157,7 @@ def main( temperature: float = 0.8, checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), quantization: Optional[str] = None, + kv_cache_quantization: bool = False, compile: bool = True, compile_prefill: bool = False, profile: Optional[Path] = None, @@ -157,6 +168,7 @@ def main( """Generates text samples based on a pre-trained Transformer model and tokenizer. """ + # torch.cuda.memory._record_memory_history(True,trace_alloc_max_entries=1000000, trace_alloc_record_context=True) torchao.quantization.utils.recommended_inductor_config_setter() assert checkpoint_path.is_file(), checkpoint_path @@ -179,9 +191,7 @@ def main( encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) prompt_length = encoded.size(0) - torch.manual_seed(1234) - - + torch.manual_seed(1234) if quantization: from torchao.quantization.quant_api import ( quantize_, @@ -276,7 +286,14 @@ def callback(x): callback=callback, temperature=temperature, top_k=top_k, + kv_cache_quantization=kv_cache_quantization, ) + # if i==3: + # snapshot = torch.cuda.memory._snapshot() + # from pickle import dump + # with open("mem_trace_kvq_no_comp" + '.pickle', 'wb') as f: + # dump(snapshot, f) + # breakpoint() if i == -1: print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") continue @@ -305,12 +322,13 @@ def callback(x): print(f"Model Size: {model_size:.02f} GB") if write_result: result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB " - result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} " + result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} " result_txt += f"repro: python generate.py " result_txt += f"--quantization {quantization} " if quantization else "" result_txt += f"--checkpoint_path {checkpoint_path} " result_txt += f"--device {device} " result_txt += f"--precision {precision} " + result_txt += f"--kv_cache_quantization " if kv_cache_quantization else "" result_txt += f"--compile " if compile else "" result_txt += f"--compile_prefill " if compile_prefill else "" result_txt += f"--profile {profile} " if profile else "" @@ -348,5 +366,5 @@ def callback(x): args = parser.parse_args() main( args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k, - args.temperature, args.checkpoint_path, args.quantization, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result + args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result ) diff --git a/torchao/_models/llama/model.py b/torchao/_models/llama/model.py index bea2eaffcb..37727d4cf0 100644 --- a/torchao/_models/llama/model.py +++ b/torchao/_models/llama/model.py @@ -85,7 +85,6 @@ def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torc def update(self, input_pos, k_val, v_val): # input_pos: [S], k_val: [B, H, S, D] assert input_pos.shape[0] == k_val.shape[2] - if use_index_put_for_kv_cache: k_out = torch.ops.aten.index_put_(self.k_cache, [None, None, input_pos], k_val) v_out = torch.ops.aten.index_put_(self.v_cache, [None, None, input_pos], v_val) @@ -97,23 +96,51 @@ def update(self, input_pos, k_val, v_val): return k_out, v_out -# class QuantizedKVCache(nn.Module): -# def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16): -# super().__init__() -# cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) -# self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=torch.uint8)) -# self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=torch.uint8)) -# self.register_buffer('k_cache_scale', torch.ones(cache_shape, dtype=torch.bfloat16)) -# self.register_buffer('v_cache_scale', torch.ones(cache_shape, dtype=torch.bfloat16)) + +# (Pdb) p k_val.shape +# torch.Size([1, 32, 6, 128]) +# (Pdb) p self.k_cache.shape +# torch.Size([1, 32, 208, 128]) so want final size to be 1,32,208,[1] + +from torchao.quantization.quant_primitives import quantize_affine, dequantize_affine +from torchao.quantization.utils import quantize_activation_per_token_absmax + +class QuantizedKVCache(nn.Module): + def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, scale_dtype=torch.bfloat16): + super().__init__() + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) + scale_shape = (max_batch_size, n_heads, max_seq_length, 1) + self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=torch.int8)) + self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=torch.int8)) + self.register_buffer('k_cache_scale', torch.ones(scale_shape, dtype=scale_dtype)) + self.register_buffer('v_cache_scale', torch.ones(scale_shape, dtype=scale_dtype)) -# def update(self, input_pos, k_val, v_val): -# k_out = self.k_cache -# v_out = self.v_cache -# k_out[:, :, input_pos] = k_val -# v_out[:, :, input_pos] = v_val - -# @classmethod -# def from_kv_cache(cls, kv_cache): + def update(self, input_pos, k_val, v_val): + # k_out = self.k_cache*self.k_cache_scale + # v_out = self.v_cache*self.v_cache_scale + # k_out[:, :, input_pos] = k_val + # v_out[:, :, input_pos] = v_val + + q_k_val, k_scale = quantize_activation_per_token_absmax(k_val) + self.k_cache[:, :, input_pos] = q_k_val + self.k_cache_scale[:, :, input_pos] = k_scale.unsqueeze(-1) + del k_val + + q_v_val, v_scale = quantize_activation_per_token_absmax(v_val) + self.k_cache[:, :, input_pos] = q_v_val + self.k_cache_scale[:, :, input_pos] = v_scale.unsqueeze(-1) + del v_val + + # return k_out, v_out + return self.k_cache*self.k_cache_scale, self.v_cache*self.v_cache_scale + + @classmethod + def from_float(cls, kv_cache): + cache_shape = kv_cache.k_cache.shape + max_batch_size, n_heads, max_seq_length, head_dim = cache_shape + scale_dtype = kv_cache.k_cache.dtype + return cls(max_batch_size, max_seq_length, n_heads, head_dim, scale_dtype) + class Transformer(nn.Module):