From 26904dd78495ad1b18e43d9e52ee62e05cb71d04 Mon Sep 17 00:00:00 2001
From: Pavani Majety
Date: Thu, 29 Aug 2024 01:04:53 -0700
Subject: [PATCH] Update vllm/attention/backends/flashinfer.py

Co-authored-by: Cody Yu
---
 vllm/attention/backends/flashinfer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index 8493f8d9eeda..f554fa2805bd 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -686,7 +686,7 @@ def forward(
             )
             # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
             # to process the cache when the kv_cache_dtype is fp8
-            if self.kv_cache_dtype in ["fp8", "fp8_e4m3", "fp8_e5m2"]:
+            if self.kv_cache_dtype.startswith("fp8"):
                 torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                     self.kv_cache_dtype)
                 kv_cache = kv_cache.view(torch_dtype)
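
For reference, the old membership test and the new prefix test accept the same kv_cache_dtype values that the removed list enumerated ("fp8", "fp8_e4m3", "fp8_e5m2"). Below is a minimal standalone Python sketch, not part of the patch, comparing the two conditions; the helper names old_check and new_check are illustrative only.

# Illustrative comparison of the two conditions touched by this patch.
# old_check/new_check are hypothetical helpers, not vLLM code.
def old_check(kv_cache_dtype: str) -> bool:
    return kv_cache_dtype in ["fp8", "fp8_e4m3", "fp8_e5m2"]

def new_check(kv_cache_dtype: str) -> bool:
    return kv_cache_dtype.startswith("fp8")

# Both tests agree on the fp8 variants and on non-fp8 dtypes.
for dtype in ["fp8", "fp8_e4m3", "fp8_e5m2", "auto", "bfloat16"]:
    assert old_check(dtype) == new_check(dtype)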