Support int8 KVCache Quant in Vllm #1507

Closed
wants to merge 57 commits

Changes from all commits (57 commits)
ce271bc
support kv cache quantization
Sep 19, 2023
f8b0b05
fix python code
Sep 19, 2023
b1560db
merge and reformat
Sep 20, 2023
5c672ec
support generating kv quant parameters and evaluting kv quant models
Sep 27, 2023
f8d6b99
modify test functions
Sep 28, 2023
f8427e3
fix test code
Sep 28, 2023
df286fe
fix test attention
Sep 28, 2023
b2d9b8c
modify attention kernel test using pytest
Oct 12, 2023
c5a1a73
fix quant parameter passing
Oct 16, 2023
fbed95c
code clean
Oct 30, 2023
f396ed3
code clean
Oct 30, 2023
ad8f950
Merge branch 'main' into kv_quant
AniZpZ Nov 2, 2023
2543722
code format
Nov 3, 2023
4226683
code format
Nov 3, 2023
df15d44
fix merge
Nov 15, 2023
872d156
fix reshape_and_cache_quantized
Nov 20, 2023
8c29013
tmp fix
Nov 22, 2023
8b5278d
tmp fix2
Nov 22, 2023
d8a9d4a
update kv-quant kernels
Nov 23, 2023
0b06f96
add kv-quant kernel tests
Nov 23, 2023
734dcc6
support kv-quant
Nov 23, 2023
31c4083
code format
Nov 24, 2023
16bccc4
fix work bugs
Nov 24, 2023
dd527fc
fix unit test
Nov 27, 2023
104fb9b
fix unit test
Nov 29, 2023
580566c
fix kv-quant args
Dec 5, 2023
88ba3c0
fix attention params
Dec 18, 2023
e2ff5a6
Merge tag 'v0.2.7' into kv_quant_v0.2.7
Jan 16, 2024
3065a32
format code
Jan 16, 2024
a896eb3
add .buildkite
Jan 16, 2024
4072871
merge with remote branch 'vllm/main'
Feb 4, 2024
c0d3895
Merge branch 'kv_quant_merge' into kv_quant
Feb 5, 2024
f670d3c
Merge pull request #13 in wm_ai/project_v from tmp to kv_quant - <mer…
Feb 5, 2024
666549d
Merge branch 'main' into kv_quant
AniZpZ Feb 5, 2024
16bb483
fix compile issue
Feb 5, 2024
ca1fcb3
fix unit test issue
Feb 5, 2024
33f9d53
fix issues
Feb 7, 2024
594ec3f
support exporting kv quant params for transformers>=4.36.0
Feb 7, 2024
c37770b
fix benchmarks for kv cache int8
Feb 7, 2024
815eda7
Merge branch 'main' into kv_quant
HandH1998 Feb 7, 2024
14ec0ca
fix supporting kv cache int8 for specified models
Feb 7, 2024
2ff0e20
add int8_kv_cache.rst
Feb 7, 2024
5744c38
code format
Feb 8, 2024
cf7d939
code format
Feb 8, 2024
d79a96e
code format
Feb 19, 2024
9a2c2c6
code format
Feb 19, 2024
b1d4ce3
modify int8 kv cache doc
Feb 19, 2024
74013b7
fix conflicts
Mar 25, 2024
128cbae
fix conflicts
Mar 25, 2024
e24d431
fix conflicts
Mar 26, 2024
2f38a1c
fix rocm compile
Mar 26, 2024
74d706e
code format
Mar 26, 2024
a999930
fix rocm compile
Mar 26, 2024
98ef941
fix param passing
Mar 26, 2024
95f8cc7
fix param passing
Mar 26, 2024
02c949a
add int8_kv_cache.rst to toctree
Mar 26, 2024
f9fed66
relax int8 kv quant tolerance
Mar 26, 2024
9 changes: 8 additions & 1 deletion benchmarks/benchmark_latency.py
@@ -25,6 +25,7 @@ def main(args: argparse.Namespace):
 dtype=args.dtype,
 enforce_eager=args.enforce_eager,
 kv_cache_dtype=args.kv_cache_dtype,
+kv_quant_params_path=args.kv_quant_params_path,
 device=args.device,
 ray_workers_use_nsight=args.ray_workers_use_nsight,
 )
@@ -126,10 +127,16 @@ def run_to_completion(profile_dir: Optional[str] = None):
 parser.add_argument(
 "--kv-cache-dtype",
 type=str,
-choices=['auto', 'fp8_e5m2'],
+choices=['auto', 'fp8_e5m2', 'int8'],
 default='auto',
 help=
 'Data type for kv cache storage. If "auto", will use model data type.')
+parser.add_argument(
+"--kv-quant-params-path",
+type=str,
+default=None,
+help='Path to scales and zero points of kv cache quantization '
+'when kv cache dtype is int8.')
 parser.add_argument(
 '--profile',
 action='store_true',
9 changes: 8 additions & 1 deletion benchmarks/benchmark_throughput.py
@@ -88,6 +88,7 @@ def run_vllm(
 gpu_memory_utilization=gpu_memory_utilization,
 enforce_eager=enforce_eager,
 kv_cache_dtype=kv_cache_dtype,
+kv_quant_params_path=args.kv_quant_params_path,
 device=device,
 enable_prefix_caching=enable_prefix_caching)

@@ -300,10 +301,16 @@ def main(args: argparse.Namespace):
 parser.add_argument(
 "--kv-cache-dtype",
 type=str,
-choices=["auto", "fp8_e5m2"],
+choices=["auto", "fp8_e5m2", "int8"],
 default="auto",
 help=
 'Data type for kv cache storage. If "auto", will use model data type.')
+parser.add_argument(
+"--kv-quant-params-path",
+type=str,
+default=None,
+help='Path to scales and zero points of kv cache quantization '
+'when kv cache dtype is int8.')
 parser.add_argument(
 "--device",
 type=str,
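Both benchmark scripts above gain the same pair of options: --kv-cache-dtype int8 selects the quantized cache, and --kv-quant-params-path points at the exported scales and zero points. For orientation, here is a minimal sketch of how these options reach the engine on this branch; the kv_quant_params_path keyword exists only in this PR, and the model name and path are placeholders:

from vllm import LLM, SamplingParams

# Placeholder model and path; kv_quant_params_path is added by this PR and is
# not part of mainline vLLM. kv_cache_dtype="int8" enables the quantized cache.
llm = LLM(
    model="meta-llama/Llama-2-13b-hf",
    kv_cache_dtype="int8",
    kv_quant_params_path="/path/to/kv_quant_params",
)

outputs = llm.generate(
    ["Hello, my name is"],
    SamplingParams(temperature=0.0, max_tokens=16),
)
print(outputs[0].outputs[0].text)

The benchmark scripts do essentially this, with both values taken from argparse.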
22 changes: 21 additions & 1 deletion benchmarks/kernels/benchmark_paged_attention.py
@@ -74,6 +74,18 @@ def main(
 device=device)
 key_cache, value_cache = key_caches[0], value_caches[0]

+# Prepare kv quant parameters for kv_cache_dtype=int8.
+# NOTE(zhangying): These parameters only work when kv_cache_dtype is int8.
+# They have no influence on other kv_cache_dtypes, like auto and fp8_e5m2.
+# For Llama-13B, we find that the key scale distribution range is [0.05, 0.15],
+# the value scale distribution range is [0.005, 0.10],
+# the key zero point distribution range is [-1.5, 1.5],
+# the value zero point distribution range is [-2.0, 2.0].
+k_scale = random.random() * 0.10 + 0.05
+v_scale = random.random() * 0.095 + 0.005
+k_zp = random.random() * 3.0 - 1.5
+v_zp = random.random() * 4.0 - 2.0
+
 # Prepare for the paged attention kernel.
 output = torch.empty_like(query)
 if version == "v2":
@@ -112,6 +124,10 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
 max_context_len,
 alibi_slopes,
 kv_cache_dtype,
+k_scale,
+k_zp,
+v_scale,
+v_zp,
 )
 elif version == "v2":
 ops.paged_attention_v2(
@@ -130,6 +146,10 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
 max_context_len,
 alibi_slopes,
 kv_cache_dtype,
+k_scale,
+k_zp,
+v_scale,
+v_zp,
 )
 else:
 raise ValueError(f"Invalid version: {version}")
@@ -179,7 +199,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
 parser.add_argument(
 "--kv-cache-dtype",
 type=str,
-choices=["auto", "fp8_e5m2"],
+choices=["auto", "fp8_e5m2", "int8"],
 default="auto",
 help=
 'Data type for kv cache storage. If "auto", will use model data type.')
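The NOTE in the hunk above only documents empirically observed ranges for the scales and zero points; the quantization itself happens inside the attention kernels, which now take k_scale, k_zp, v_scale, v_zp after kv_cache_dtype. As a rough illustration of what those four scalars mean, here is a minimal PyTorch sketch of asymmetric int8 quantization, assuming the zero point is expressed in the original floating-point domain (x is reconstructed as q * scale + zp); the convention used by the CUDA kernels in this PR may differ:

import torch

def quantize_kv_int8(x: torch.Tensor, scale: float, zp: float) -> torch.Tensor:
    # Map float values to int8 with a per-tensor scale and zero point.
    # Assumption: zp lives in the float domain, so x is recovered as q * scale + zp.
    return torch.round((x - zp) / scale).clamp(-128, 127).to(torch.int8)

def dequantize_kv_int8(q: torch.Tensor, scale: float, zp: float) -> torch.Tensor:
    return q.to(torch.float32) * scale + zp

# Toy check with a key scale/zero point drawn from the ranges quoted in the NOTE.
k = torch.randn(16) * 0.2
k_scale, k_zp = 0.10, 0.5
k_hat = dequantize_kv_int8(quantize_kv_int8(k, k_scale, k_zp), k_scale, k_zp)
print((k - k_hat).abs().max())  # rounding error, at most ~scale / 2 when nothing clips

The benchmark only needs plausible scalar values, which is why it samples them from the quoted ranges with random.random() instead of reading them from a --kv-quant-params-path file.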
1 change: 1 addition & 0 deletions csrc/attention/attention_dtypes.h
@@ -4,4 +4,5 @@
 #include "dtype_float16.cuh"
 #include "dtype_float32.cuh"
 #include "dtype_bfloat16.cuh"
+#include "dtype_int8.cuh"
 #include "dtype_fp8_e5m2.cuh"