From b8aaa960605a725ab17c50a02482081d4f5bc83f Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Thu, 1 Aug 2024 10:51:33 -0700
Subject: [PATCH] fixes and improvements

Summary:

- model.py: fix the quantized KVCache update, which was writing the
  quantized v values and scales into the k_cache / k_cache_scale buffers
  (so the value cache kept its initial contents), and make update return
  dequantized k_out / v_out with the freshly written positions restored
  to full precision.
- generate.py: cap the generated sequence at model.config.block_size,
  since max_new_tokens can overflow it; truncate decoded output at the
  first end-of-sequence token; drop the max_new_tokens argument from
  prepare_inputs_for_model.
- benchmarks.sh: switch the active MODEL_REPO to Meta-Llama-3-8B and add
  kv cache quantization runs at 2048 and 8192 max_new_tokens.
- benchmark_results.txt: record the results of the new runs.

Standalone sketches of the three code changes are appended after the diff.

Test Plan: torchao/_models/llama/benchmarks.sh; the repro command for each
row is recorded in benchmark_results.txt.

Reviewers:

Subscribers:

Tasks:

Tags:
---
 torchao/_models/llama/benchmark_results.txt | 47 +++++++++++++++++++++
 torchao/_models/llama/benchmarks.sh         | 12 ++++--
 torchao/_models/llama/generate.py           | 23 +++++++---
 torchao/_models/llama/model.py              | 22 ++++------
 4 files changed, 80 insertions(+), 24 deletions(-)

diff --git a/torchao/_models/llama/benchmark_results.txt b/torchao/_models/llama/benchmark_results.txt
index 15abccbe69..c327db9a95 100644
--- a/torchao/_models/llama/benchmark_results.txt
+++ b/torchao/_models/llama/benchmark_results.txt
@@ -29,3 +29,50 @@
 20240718151152, tok/s=100.29, mem/s=1325.29 GB/s, peak_mem=14.14 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
 20240718151349, tok/s=166.08, mem/s=1099.70 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
 20240718152147, tok/s=140.85, mem/s= 932.66 GB/s, peak_mem= 9.21 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+
+20240731133002, tok/s=109.42, mem/s=1446.01 GB/s, peak_mem=13.90 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240731135838, tok/s=102.85, mem/s=1359.17 GB/s, peak_mem=15.00 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8
+20240731140259, tok/s=102.91, mem/s=1359.87 GB/s, peak_mem=15.00 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8
+20240731140646, tok/s=101.19, mem/s=1337.23 GB/s, peak_mem=14.52 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8
+20240731194813, tok/s=102.84, mem/s=1358.94 GB/s, peak_mem=15.00 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8
+20240731195225, tok/s=103.14, mem/s=1362.92 GB/s, peak_mem=14.52 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8
+20240731200747, tok/s=102.79, mem/s=1358.40 GB/s, peak_mem=15.00 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8
+20240731201145, tok/s=103.09, mem/s=1362.33 GB/s, peak_mem=14.52 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8
+
+20240731201438, tok/s=109.42, mem/s=1446.00 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240731201739, tok/s=102.58, mem/s=1355.51 GB/s, peak_mem=13.83 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240731202121, tok/s=102.86, mem/s=1359.26 GB/s, peak_mem=15.00 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8
+20240731202505, tok/s=103.44, mem/s=1366.91 GB/s, peak_mem=14.52 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8
+
+20240731212356, tok/s= 95.36, mem/s=1431.41 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240731212822, tok/s= 93.92, mem/s=1409.76 GB/s, peak_mem=16.41 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240731213330, tok/s= 89.84, mem/s=1348.49 GB/s, peak_mem=17.28 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8
+20240731213900, tok/s= 88.34, mem/s=1326.01 GB/s, peak_mem=17.24 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8
+20240731215214, tok/s= 80.92, mem/s=1214.62 GB/s, peak_mem=19.80 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8
+20240731220805, tok/s= 78.54, mem/s=1178.91 GB/s, peak_mem=19.30 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8
+
+20240731223923, tok/s= 95.52, mem/s=1433.68 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240731224400, tok/s= 90.39, mem/s=1356.73 GB/s, peak_mem=16.44 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240731224816, tok/s= 89.89, mem/s=1349.25 GB/s, peak_mem=17.26 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8
+20240731225411, tok/s= 84.96, mem/s=1275.21 GB/s, peak_mem=17.48 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8
+20240731230612, tok/s= 80.91, mem/s=1214.45 GB/s, peak_mem=19.77 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8
+20240731232130, tok/s= 69.10, mem/s=1037.25 GB/s, peak_mem=20.18 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8
+
+20240801010740, tok/s= 95.45, mem/s=1432.64 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240801011046, tok/s= 94.02, mem/s=1411.28 GB/s, peak_mem=16.41 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240801011513, tok/s= 89.96, mem/s=1350.32 GB/s, peak_mem=17.26 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8
+20240801011931, tok/s= 88.11, mem/s=1322.52 GB/s, peak_mem=17.20 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8
+
+20240801013354, tok/s= 95.45, mem/s=1432.67 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240801013812, tok/s= 92.15, mem/s=1383.16 GB/s, peak_mem=16.41 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240801025927, tok/s= 89.88, mem/s=1349.14 GB/s, peak_mem=17.26 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8
+20240801030347, tok/s= 87.32, mem/s=1310.69 GB/s, peak_mem=17.22 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8
+20240801031549, tok/s= 80.91, mem/s=1214.39 GB/s, peak_mem=19.77 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8
+20240801033011, tok/s= 74.72, mem/s=1121.50 GB/s, peak_mem=19.34 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8
+
+20240801093317, tok/s= 95.52, mem/s=1433.80 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240801093529, tok/s= 92.36, mem/s=1386.35 GB/s, peak_mem=16.41 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240801093944, tok/s= 89.88, mem/s=1349.13 GB/s, peak_mem=17.26 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8
+20240801094415, tok/s= 87.20, mem/s=1308.88 GB/s, peak_mem=17.22 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8
+20240801095615, tok/s= 80.87, mem/s=1213.82 GB/s, peak_mem=19.77 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8
+20240801100912, tok/s= 74.65, mem/s=1120.41 GB/s, peak_mem=19.29 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8
\ No newline at end of file
diff --git a/torchao/_models/llama/benchmarks.sh b/torchao/_models/llama/benchmarks.sh
index 614329b1e6..ed599de914 100644
--- a/torchao/_models/llama/benchmarks.sh
+++ b/torchao/_models/llama/benchmarks.sh
@@ -1,7 +1,7 @@
 export CHECKPOINT_PATH=../../../checkpoints # path to checkpoints folder
 
-export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
+# export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
 # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
 # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
 # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
@@ -12,7 +12,7 @@ export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
 # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
 # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
 
-# export MODEL_REPO=meta-llama/Meta-Llama-3-8B
+export MODEL_REPO=meta-llama/Meta-Llama-3-8B
 # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
 # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
 # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
@@ -26,5 +26,9 @@ export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
 #####
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt --kv_cache_quantization
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 2048
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 2048
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 8192
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 8192
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt --max_new_tokens 2048
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 2048
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index 8643eb646e..beb1569bbb 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -98,11 +98,15 @@ def generate(
     # create an empty tensor of the expected final shape and fill in the current tokens
     device = prompt.device
     T = prompt.numel()
-    T_new = T + max_new_tokens
-    seq = torch.empty(T_new, dtype=prompt.dtype, device=device)
+
+    # max_new_tokens can overflow block_size so we need to cap it
+    max_seq_length = min(T + max_new_tokens, model.config.block_size) if not interactive else 350
+    new_tokens = max_seq_length - T
+
+    # full prompt+output will be stored in seq
+    seq = torch.empty(max_seq_length, dtype=prompt.dtype, device=device)
     seq[:T] = prompt.view(-1)
     # setup model cache
-    max_seq_length = min(T_new, model.config.block_size) if not interactive else 350
     with torch.device(device):
         model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
         if kv_cache_quantization:
@@ -117,14 +121,14 @@
     # format model input
-    x, input_pos = prepare_inputs_for_model(prompt, max_new_tokens)
+    x, input_pos = prepare_inputs_for_model(prompt)
 
     # execute prefill
     next_token = prefill(model, x, input_pos, **sampling_kwargs).clone()
     seq[T] = next_token
 
     input_pos = torch.tensor([T], device=device, dtype=torch.int)
 
-    generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
+    generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs)
 
     seq[T + 1:] = torch.cat(generated_tokens)
     return seq
@@ -191,7 +195,9 @@
     encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
     prompt_length = encoded.size(0)
 
-    torch.manual_seed(1234)
+    torch.manual_seed(1234)
+
+
     if quantization:
         from torchao.quantization.quant_api import (
             quantize_,
@@ -303,7 +309,10 @@ def callback(x):
         t = time.perf_counter() - t0
 
         if not interactive:
-            print(tokenizer.decode(y.tolist()))
+            tok_list = y.tolist()
+            # truncate text after end of string token
+            tokens = tok_list if not tokenizer.eos_id() in y else tok_list[:tok_list.index(tokenizer.eos_id())]
+            print(tokenizer.decode(tokens))
         else:
             print()
         tokens_generated = y.size(0) - prompt_length
diff --git a/torchao/_models/llama/model.py b/torchao/_models/llama/model.py
index 37727d4cf0..5c37adcdd7 100644
--- a/torchao/_models/llama/model.py
+++ b/torchao/_models/llama/model.py
@@ -12,7 +12,7 @@
 from torch.nn import functional as F
 from torchao.utils import find_multiple
 
-def prepare_inputs_for_model(inps, max_new_tokens=1):
+def prepare_inputs_for_model(inps):
     # this is because input from lm-eval is 2d
     if inps.dim() > 2:
         raise ValueError(f"Expected input to be of dim 1 or 2, but got {inps.dim()}")
@@ -116,23 +116,19 @@ def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, scale_dtyp
         self.register_buffer('v_cache_scale', torch.ones(scale_shape, dtype=scale_dtype))
 
     def update(self, input_pos, k_val, v_val):
-        # k_out = self.k_cache*self.k_cache_scale
-        # v_out = self.v_cache*self.v_cache_scale
-        # k_out[:, :, input_pos] = k_val
-        # v_out[:, :, input_pos] = v_val
-
         q_k_val, k_scale = quantize_activation_per_token_absmax(k_val)
         self.k_cache[:, :, input_pos] = q_k_val
         self.k_cache_scale[:, :, input_pos] = k_scale.unsqueeze(-1)
-        del k_val
+        k_out = self.k_cache*self.k_cache_scale
+        k_out[:, :, input_pos] = k_val
 
         q_v_val, v_scale = quantize_activation_per_token_absmax(v_val)
-        self.k_cache[:, :, input_pos] = q_v_val
-        self.k_cache_scale[:, :, input_pos] = v_scale.unsqueeze(-1)
-        del v_val
-
-        # return k_out, v_out
-        return self.k_cache*self.k_cache_scale, self.v_cache*self.v_cache_scale
+        self.v_cache[:, :, input_pos] = q_v_val
+        self.v_cache_scale[:, :, input_pos] = v_scale.unsqueeze(-1)
+        v_out = self.v_cache*self.v_cache_scale
+        v_out[:, :, input_pos] = v_val
+
+        return k_out, v_out
 
     @classmethod
     def from_float(cls, kv_cache):
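
Note on the generate.py change above: seq was previously allocated with T + max_new_tokens entries even when that exceeded model.config.block_size, while the model cache was capped at block_size, so runs like --max_new_tokens 8192 could ask decode_n_tokens for more positions than the cache holds. A minimal sketch of the new capping arithmetic; the helper name capped_lengths is mine, not from the patch, and the block_size value is only illustrative (4096 is Llama-2's context length):

    def capped_lengths(T: int, max_new_tokens: int, block_size: int, interactive: bool = False):
        # cap prompt + generated tokens to the model's context window
        max_seq_length = min(T + max_new_tokens, block_size) if not interactive else 350
        new_tokens = max_seq_length - T  # how many tokens can actually be generated
        return max_seq_length, new_tokens

    # a 100-token prompt with max_new_tokens=8192 against a 4096-token context
    print(capped_lengths(100, 8192, 4096))  # (4096, 3996)

generate() then sizes seq with max_seq_length and hands new_tokens - 1 to decode_n_tokens, so the output buffer, the kv cache, and the decode loop all agree on one length.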
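Note on the output-truncation change in generate.py: once generations run long enough to finish naturally, the sampled buffer can contain the tokenizer's end-of-sequence token partway through, and everything after it is filler, so decoding now stops at the first EOS. The same logic on a plain Python token list; truncate_at_eos and the token values are made up for illustration:

    def truncate_at_eos(tok_list: list[int], eos_id: int) -> list[int]:
        # keep everything before the first end-of-sequence token, if present
        return tok_list[:tok_list.index(eos_id)] if eos_id in tok_list else tok_list

    print(truncate_at_eos([5, 9, 2, 7], eos_id=2))  # [5, 9]
    print(truncate_at_eos([5, 9, 7], eos_id=2))     # [5, 9, 7]

The patch checks membership against the tensor y but takes the index from tok_list; both see the same values, since tok_list is y.tolist().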
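Note on the model.py change: the old update() quantized v_val but wrote the result and its scales into self.k_cache / self.k_cache_scale, clobbering the key cache while the value cache kept its initial contents, and it returned the dequantized caches without restoring the just-written positions to full precision. A self-contained sketch of the corrected flow for the key side; quantize_per_token_absmax below is a simplified stand-in for the quantize_activation_per_token_absmax that model.py imports, and all shapes are illustrative:

    import torch

    def quantize_per_token_absmax(x):
        # stand-in quantizer: symmetric int8 with one scale per token
        scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5) / 127.0
        q = torch.round(x / scale).clamp(-128, 127).to(torch.int8)
        return q, scale.squeeze(-1)

    # cache shapes: (batch=1, n_heads=2, max_seq_length=8, head_dim=4)
    k_cache = torch.zeros(1, 2, 8, 4, dtype=torch.int8)
    k_cache_scale = torch.ones(1, 2, 8, 1)

    def update_k(input_pos, k_val):
        # store the quantized values and their per-token scales
        q_k, k_scale = quantize_per_token_absmax(k_val)
        k_cache[:, :, input_pos] = q_k
        k_cache_scale[:, :, input_pos] = k_scale.unsqueeze(-1)
        # dequantize the whole cache for attention ...
        k_out = k_cache * k_cache_scale
        # ... but keep the freshly written position in full precision
        k_out[:, :, input_pos] = k_val
        return k_out

    k_out = update_k(torch.tensor([3]), torch.randn(1, 2, 1, 4))
    print(k_out.shape)  # torch.Size([1, 2, 8, 4])

The value side is symmetric, writing into v_cache / v_cache_scale; the bug was that both writes targeted the k buffers.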