From 9a31a817a85ac4249bf82dd8b6f90ef6b8e81fef Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 16 May 2024 15:42:29 -0700
Subject: [PATCH] [Bugfix] Fix FP8 KV cache support (#4869)

---
 vllm/attention/backends/flash_attn.py      | 10 +++++-----
 vllm/attention/backends/flashinfer.py      | 10 +++++-----
 vllm/attention/backends/rocm_flash_attn.py | 10 +++++-----
 vllm/attention/backends/torch_sdpa.py      | 10 +++++-----
 vllm/attention/backends/xformers.py        | 10 +++++-----
 vllm/attention/layer.py                    |  2 +-
 6 files changed, 26 insertions(+), 26 deletions(-)

diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 5d1f65819ed4..856f39974137 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -200,15 +200,15 @@ class FlashAttentionImpl(AttentionImpl):
         num_heads: int,
         head_size: int,
         scale: float,
-        num_kv_heads: Optional[int] = None,
-        alibi_slopes: Optional[List[float]] = None,
-        sliding_window: Optional[int] = None,
-        kv_cache_dtype: str = "auto",
+        num_kv_heads: int,
+        alibi_slopes: Optional[List[float]],
+        sliding_window: Optional[int],
+        kv_cache_dtype: str,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
-        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+        self.num_kv_heads = num_kv_heads
         if alibi_slopes is not None:
             alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
         self.alibi_slopes = alibi_slopes
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index 5f9fd586fb70..7210fefbd816 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -164,15 +164,15 @@ class FlashInferImpl(AttentionImpl):
         num_heads: int,
         head_size: int,
         scale: float,
-        num_kv_heads: Optional[int] = None,
-        alibi_slopes: Optional[List[float]] = None,
-        sliding_window: Optional[int] = None,
-        kv_cache_dtype: str = "auto",
+        num_kv_heads: int,
+        alibi_slopes: Optional[List[float]],
+        sliding_window: Optional[int],
+        kv_cache_dtype: str,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
-        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+        self.num_kv_heads = num_kv_heads
         if alibi_slopes is not None:
             alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
         self.alibi_slopes = alibi_slopes
diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index 1a94dc359635..bb828d6fc04f 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -197,15 +197,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         num_heads: int,
         head_size: int,
         scale: float,
-        num_kv_heads: Optional[int] = None,
-        alibi_slopes: Optional[List[float]] = None,
-        sliding_window: Optional[int] = None,
-        kv_cache_dtype: str = "auto",
+        num_kv_heads: int,
+        alibi_slopes: Optional[List[float]],
+        sliding_window: Optional[int],
+        kv_cache_dtype: str,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
-        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+        self.num_kv_heads = num_kv_heads
         if alibi_slopes is not None:
             alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
         self.alibi_slopes = alibi_slopes
diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py
index a3f72b9c9456..a19c97e1e0e3 100644
--- a/vllm/attention/backends/torch_sdpa.py
+++ b/vllm/attention/backends/torch_sdpa.py
@@ -96,15 +96,15 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
         num_heads: int,
         head_size: int,
         scale: float,
-        num_kv_heads: Optional[int] = None,
-        alibi_slopes: Optional[List[float]] = None,
-        sliding_window: Optional[int] = None,
-        kv_cache_dtype: str = "auto",
+        num_kv_heads: int,
+        alibi_slopes: Optional[List[float]],
+        sliding_window: Optional[int],
+        kv_cache_dtype: str,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
-        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+        self.num_kv_heads = num_kv_heads
         if alibi_slopes is not None:
             alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
         self.alibi_slopes = alibi_slopes
diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py
index fc46af054de4..96169da6cf92 100644
--- a/vllm/attention/backends/xformers.py
+++ b/vllm/attention/backends/xformers.py
@@ -208,15 +208,15 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
         num_heads: int,
         head_size: int,
         scale: float,
-        num_kv_heads: Optional[int] = None,
-        alibi_slopes: Optional[List[float]] = None,
-        sliding_window: Optional[int] = None,
-        kv_cache_dtype: str = "auto",
+        num_kv_heads: int,
+        alibi_slopes: Optional[List[float]],
+        sliding_window: Optional[int],
+        kv_cache_dtype: str,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
-        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+        self.num_kv_heads = num_kv_heads
         if alibi_slopes is not None:
             alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
         self.alibi_slopes = alibi_slopes
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 126692d8c9b4..4299726bdca4 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -48,7 +48,7 @@ class Attention(nn.Module):
                                         block_size)
         impl_cls = attn_backend.get_impl_cls()
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
-                             alibi_slopes, sliding_window)
+                             alibi_slopes, sliding_window, kv_cache_dtype)
 
     def forward(
         self,
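
Note: the sketch below is a hypothetical toy (ToyAttentionImpl / ToyAttention are not vLLM's real classes) that illustrates the failure mode this patch closes. Before the fix, the Attention layer constructed its backend impl without forwarding kv_cache_dtype, so the impl's keyword default of "auto" always won and a requested FP8 KV cache was silently ignored. Removing the defaults and passing the value through makes such a mismatch a loud error instead.

# Hypothetical toy classes (not vLLM's actual API) mirroring the post-patch
# wiring: the backend constructor takes kv_cache_dtype with no default, and
# the layer forwards it explicitly.
from typing import List, Optional


class ToyAttentionImpl:
    # Mirrors the patched backend constructors: no default values, so a
    # caller that forgets an argument fails loudly instead of silently
    # falling back to "auto".
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
    ) -> None:
        self.num_kv_heads = num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype


class ToyAttention:
    # Mirrors the patched layer: kv_cache_dtype is passed through to the impl.
    def __init__(self, num_heads: int, head_size: int, scale: float,
                 num_kv_heads: int, kv_cache_dtype: str) -> None:
        self.impl = ToyAttentionImpl(num_heads, head_size, scale,
                                     num_kv_heads, alibi_slopes=None,
                                     sliding_window=None,
                                     kv_cache_dtype=kv_cache_dtype)


layer = ToyAttention(num_heads=32, head_size=128, scale=128 ** -0.5,
                     num_kv_heads=8, kv_cache_dtype="fp8")
assert layer.impl.kv_cache_dtype == "fp8"  # was stuck at "auto" before the fix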