adding kv_cache quantization #532

Merged: 5 commits, Aug 2, 2024
10 changes: 10 additions & 0 deletions torchao/_models/llama/benchmark_results.txt
@@ -1,3 +1,4 @@
llama 2
20240619101342, tok/s= 29.85, mem/s= 788.87 GB/s, peak_mem=27.23 GB, model_size=26.43 GB quant: None, mod: Llama-2-7b-chat-hf, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240619101537, tok/s= 26.38, mem/s= 348.57 GB/s, peak_mem=13.62 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240619105331, tok/s=106.55, mem/s=1408.06 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
@@ -8,6 +9,7 @@
20240619110248, tok/s=199.86, mem/s= 746.66 GB/s, peak_mem= 4.50 GB, model_size= 3.74 GB quant: int4wo-64, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240619114518, tok/s=159.22, mem/s=1069.87 GB/s, peak_mem= 8.91 GB, model_size= 6.72 GB quant: autoquant, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8

llama 3
20240619114732, tok/s= 30.46, mem/s= 914.43 GB/s, peak_mem=32.34 GB, model_size=30.02 GB quant: None, mod: Meta-Llama-3-8B, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240619114939, tok/s= 26.56, mem/s= 398.65 GB/s, peak_mem=16.16 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240619122811, tok/s= 96.09, mem/s=1442.32 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
@@ -17,3 +19,11 @@
20240619123652, tok/s=139.76, mem/s=1051.02 GB/s, peak_mem=10.42 GB, model_size= 7.52 GB quant: int8wo, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240619123847, tok/s=179.44, mem/s= 757.60 GB/s, peak_mem= 6.62 GB, model_size= 4.22 GB quant: int4wo-64, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240619131959, tok/s=137.71, mem/s=1037.74 GB/s, peak_mem=11.08 GB, model_size= 7.54 GB quant: autoquant, mod: Meta-Llama-3-8B, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8

kv cache quantization:
20240801093317, tok/s= 95.52, mem/s=1433.80 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240801093529, tok/s= 92.36, mem/s=1386.35 GB/s, peak_mem=16.41 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240801093944, tok/s= 89.88, mem/s=1349.13 GB/s, peak_mem=17.26 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8
20240801094415, tok/s= 87.20, mem/s=1308.88 GB/s, peak_mem=17.22 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8
20240801095615, tok/s= 80.87, mem/s=1213.82 GB/s, peak_mem=19.77 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8
20240801100912, tok/s= 74.65, mem/s=1120.41 GB/s, peak_mem=19.29 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8
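
Read together, the kv cache quantization rows above show the expected trade-off on Meta-Llama-3-8B: enabling --kv_cache_quantization costs a few percent of decode throughput (95.52 -> 92.36 tok/s at 200 new tokens, 80.87 -> 74.65 tok/s at 8192) in exchange for a lower peak memory footprint, and the memory savings grow with sequence length (16.43 -> 16.41 GB at 200 new tokens, 17.26 -> 17.22 GB at 2048, 19.77 -> 19.29 GB at 8192).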
8 changes: 8 additions & 0 deletions torchao/_models/llama/benchmarks.sh
@@ -22,3 +22,11 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt

export MODEL_REPO=meta-llama/Meta-Llama-3-8B
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 2048
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 2048
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 8192
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 8192
43 changes: 33 additions & 10 deletions torchao/_models/llama/generate.py
@@ -68,10 +68,11 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc
next_token, next_prob = decode_one_token(
model, cur_token, input_pos, **sampling_kwargs
)
next_token, next_prob = next_token.clone(), next_prob.clone()
input_pos += 1
new_tokens.append(next_token.clone())
new_tokens.append(next_token)
callback(new_tokens[-1])
new_probs.append(next_prob.clone())
new_probs.append(next_prob)
cur_token = next_token.view(1, -1)

return new_tokens, new_probs
@@ -88,6 +89,7 @@ def generate(
*,
interactive: bool,
callback = lambda x: x,
kv_cache_quantization: bool = False,
**sampling_kwargs
) -> torch.Tensor:
"""
@@ -97,14 +99,27 @@
# create an empty tensor of the expected final shape and fill in the current tokens
device = prompt.device
T = prompt.numel()
T_new = T + max_new_tokens
seq = torch.empty(T_new, dtype=prompt.dtype, device=device)

# calculate how many tokens to generate based on max_new_tokens and model's upper bound (block_size)
max_seq_length = min(T + max_new_tokens, model.config.block_size) if not interactive else 350
new_tokens = max_seq_length - T

# full prompt+output will be stored in seq
seq = torch.empty(max_seq_length, dtype=prompt.dtype, device=device)
seq[:T] = prompt.view(-1)

# setup model cache
max_seq_length = min(T_new, model.config.block_size) if not interactive else 350
# setup model caches
with torch.device(device):
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
if kv_cache_quantization:
from model import AffineQuantizedKVCache
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
_replace_with_custom_fn_if_matches_filter(
model,
AffineQuantizedKVCache.from_float,
lambda x, y: isinstance(x, torchao._models.llama.model.KVCache),
)


# format model input
x, input_pos = prepare_inputs_for_model(prompt, max_new_tokens)
@@ -113,8 +128,9 @@
next_token = prefill(model, x, input_pos, **sampling_kwargs).clone()
seq[T] = next_token

# execute token generation
input_pos = torch.tensor([T], device=device, dtype=torch.int)
generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs)
seq[T + 1:] = torch.cat(generated_tokens)

return seq
@@ -147,6 +163,7 @@ def main(
temperature: float = 0.8,
checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
quantization: Optional[str] = None,
kv_cache_quantization: bool = False,
compile: bool = True,
compile_prefill: bool = False,
profile: Optional[Path] = None,
@@ -276,6 +293,7 @@ def callback(x):
callback=callback,
temperature=temperature,
top_k=top_k,
kv_cache_quantization=kv_cache_quantization,
)
if i == -1:
print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
@@ -286,7 +304,10 @@ def callback(x):
t = time.perf_counter() - t0

if not interactive:
print(tokenizer.decode(y.tolist()))
tok_list = y.tolist()
# truncate text after end of string token
tokens = tok_list if not tokenizer.eos_id() in y else tok_list[:tok_list.index(tokenizer.eos_id())]
print(tokenizer.decode(tokens))
else:
print()
tokens_generated = y.size(0) - prompt_length
@@ -305,12 +326,13 @@ def callback(x):
print(f"Model Size: {model_size:.02f} GB")
if write_result:
result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB "
result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} "
result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} "
result_txt += f"repro: python generate.py "
result_txt += f"--quantization {quantization} " if quantization else ""
result_txt += f"--checkpoint_path {checkpoint_path} "
result_txt += f"--device {device} "
result_txt += f"--precision {precision} "
result_txt += f"--kv_cache_quantization " if kv_cache_quantization else ""
result_txt += f"--compile " if compile else ""
result_txt += f"--compile_prefill " if compile_prefill else ""
result_txt += f"--profile {profile} " if profile else ""
@@ -337,6 +359,7 @@ def callback(x):
parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant')
parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
@@ -347,5 +370,5 @@ def callback(x):
args = parser.parse_args()
main(
args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
args.temperature, args.checkpoint_path, args.quantization, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result
args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result
)
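
The kv_cache_quantization branch in generate() above relies on torchao's _replace_with_custom_fn_if_matches_filter helper to swap every KVCache module in the model for an AffineQuantizedKVCache built by from_float. As a minimal sketch of what such a filter-based module swap does (the helper name swap_modules and its recursion are illustrative, not torchao's actual implementation):

def swap_modules(module, replacement_fn, filter_fn, fqn=""):
    # Walk the module tree; wherever filter_fn(child, child_fqn) is True, replace
    # the child with replacement_fn(child); otherwise recurse into the child.
    for child_name, child in list(module.named_children()):
        child_fqn = f"{fqn}.{child_name}" if fqn else child_name
        if filter_fn(child, child_fqn):
            setattr(module, child_name, replacement_fn(child))
        else:
            swap_modules(child, replacement_fn, filter_fn, child_fqn)
    return module

# Used the way generate() uses the torchao helper:
#   swap_modules(model, AffineQuantizedKVCache.from_float,
#                lambda m, fqn: isinstance(m, KVCache))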
38 changes: 38 additions & 0 deletions torchao/_models/llama/model.py
@@ -12,6 +12,7 @@
from torch.nn import functional as F
from torchao.utils import find_multiple

# TODO remove superfluous arg
def prepare_inputs_for_model(inps, max_new_tokens=1):
# this is because input from lm-eval is 2d
if inps.dim() > 2:
@@ -97,6 +98,43 @@ def update(self, input_pos, k_val, v_val):

return k_out, v_out


from torchao.quantization.quant_primitives import quantize_affine, dequantize_affine
from torchao.quantization.utils import quantize_activation_per_token_absmax

class AffineQuantizedKVCache(nn.Module):
def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, scale_dtype=torch.bfloat16):
super().__init__()
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
scale_shape = (max_batch_size, n_heads, max_seq_length, 1)
self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=torch.int8))
self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=torch.int8))
self.register_buffer('k_cache_scale', torch.ones(scale_shape, dtype=scale_dtype))
self.register_buffer('v_cache_scale', torch.ones(scale_shape, dtype=scale_dtype))

def update(self, input_pos, k_val, v_val):
# quantize current k_val and store it in the cache
q_k_val, k_scale = quantize_activation_per_token_absmax(k_val)
self.k_cache[:, :, input_pos] = q_k_val
self.k_cache_scale[:, :, input_pos] = k_scale.unsqueeze(-1)
k_out = self.k_cache*self.k_cache_scale
k_out[:, :, input_pos] = k_val

q_v_val, v_scale = quantize_activation_per_token_absmax(v_val)
self.v_cache[:, :, input_pos] = q_v_val
self.v_cache_scale[:, :, input_pos] = v_scale.unsqueeze(-1)
v_out = self.v_cache*self.v_cache_scale
v_out[:, :, input_pos] = v_val

return k_out, v_out

@classmethod
def from_float(cls, kv_cache):
cache_shape = kv_cache.k_cache.shape
max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
scale_dtype = kv_cache.k_cache.dtype
return cls(max_batch_size, max_seq_length, n_heads, head_dim, scale_dtype)

class Transformer(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
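
For intuition about what AffineQuantizedKVCache stores, the per-token absmax scheme its update() leans on (via quantize_activation_per_token_absmax) boils down to the round trip below. This is a simplified standalone sketch assuming symmetric int8 quantization with one scale per token and a small eps to avoid division by zero; the function names and the eps value are illustrative, not the torchao implementation:

import torch

def per_token_absmax_quantize(x, eps=1e-5):
    # One scale per token: max(|x|) over the last dim, mapped onto the int8 range.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=eps) / 127.0
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return q, scale

def per_token_dequantize(q, scale):
    # Recover an approximation of the original values.
    return q.to(scale.dtype) * scale

# Round trip on a fake (batch, heads, seq, head_dim) key slice:
k = torch.randn(1, 8, 4, 64, dtype=torch.bfloat16)
q_k, k_scale = per_token_absmax_quantize(k)
k_approx = per_token_dequantize(q_k, k_scale)
# k_approx carries a small quantization error; update() above masks that error for
# the current position by writing the exact k_val/v_val back into k_out/v_out.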