diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 5f0a540135402..0cd95e0749d13 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -294,3 +294,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         raise NotImplementedError
+
+
+def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
+    return kv_cache_dtype != "auto"
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 5aca10079f9be..0e331efa6a392 100755
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -8,11 +8,15 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
 import torch

 from vllm import _custom_ops as ops
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
-                                              AttentionType)
+                                              AttentionType,
+                                              is_quantized_kv_cache)
+# yapf: enable
 from vllm.attention.backends.utils import (
     PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
     compute_slot_mapping_start_idx, get_flash_attn_version,
@@ -626,6 +630,9 @@ class FlashAttentionImpl(AttentionImpl):
         self.sliding_window = ((sliding_window - 1, 0) if sliding_window
                                is not None else (-1, -1))
         self.kv_cache_dtype = kv_cache_dtype
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "FlashAttention with FP8 KV cache not yet supported")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0
diff --git a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/flashmla.py
index 273c69b63ec63..5d0c230933105 100644
--- a/vllm/attention/backends/flashmla.py
+++ b/vllm/attention/backends/flashmla.py
@@ -6,7 +6,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

 import torch

-from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.backends.abstract import (AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.mla.common import (MLACommonBackend,
                                                 MLACommonImpl,
                                                 MLACommonMetadata,
@@ -207,6 +208,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
                                       "are not implemented for "
                                       "FlashMLAImpl")

+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "FlashMLA with FP8 KV cache not yet supported")
+
     def _forward_decode(
         self,
         q_nope: torch.Tensor,
@@ -215,8 +220,6 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
         attn_metadata: FlashMLAMetadata,
     ) -> torch.Tensor:
         assert kv_c_and_k_pe_cache.numel() > 0
-        if self.kv_cache_dtype.startswith("fp8"):
-            raise NotImplementedError("FP8 FlashMLA not yet supported")

         decode_meta = attn_metadata.decode_metadata
         assert decode_meta is not None
diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py
index 9eb533685dbd2..f948fbc0a1096 100644
--- a/vllm/attention/backends/hpu_attn.py
+++ b/vllm/attention/backends/hpu_attn.py
@@ -15,7 +15,8 @@ from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,

 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata, AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
                                                HPUPagedAttentionMetadata)
@@ -158,6 +159,10 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
                                       "are not implemented for "
                                       "HPUAttentionImpl")

+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "HPUAttention with FP8 KV cache not yet supported")
+
     def forward(
         self,
         layer: AttentionLayer,
diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py
index b4879af4cf20e..d3c61ea26a02a 100644
--- a/vllm/attention/backends/ipex_attn.py
+++ b/vllm/attention/backends/ipex_attn.py
@@ -9,7 +9,8 @@ import torch
 from vllm._ipex_ops import ipex_ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata, AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.paged_attn import (PagedAttention,
                                            PagedAttentionMetadata)
@@ -145,7 +146,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
                 f"Supported head sizes are: {supported_head_sizes}.")
-        if kv_cache_dtype != "auto":
+        if is_quantized_kv_cache(kv_cache_dtype):
             raise NotImplementedError(
                 "IPEX backend does not support FP8 KV cache. "
                 "Please use xFormers backend instead.")
diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py
index b61dfe63ddcaa..2ee66ab9e966e 100644
--- a/vllm/attention/backends/pallas.py
+++ b/vllm/attention/backends/pallas.py
@@ -8,7 +8,8 @@ import torch_xla.experimental.custom_kernel  # Required to register custom ops.

 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata, AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.utils import CommonAttentionState
@@ -119,7 +120,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
             raise NotImplementedError("Alibi slopes is not supported.")
         if sliding_window is not None:
             raise NotImplementedError("Sliding window is not supported.")
-        if kv_cache_dtype != "auto":
+        if is_quantized_kv_cache(kv_cache_dtype):
             raise NotImplementedError("FP8 KV cache dtype is not supported.")
         if blocksparse_params is not None:
             raise NotImplementedError("Blocksparse is not supported.")
diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py
index 25fe6ed95c5df..37dd75da27596 100644
--- a/vllm/attention/backends/torch_sdpa.py
+++ b/vllm/attention/backends/torch_sdpa.py
@@ -7,11 +7,15 @@ from typing import Any, Dict, List, Optional, Tuple, Type
 import torch
 from torch.nn.functional import scaled_dot_product_attention

+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
-                                              AttentionType)
+                                              AttentionType,
+                                              is_quantized_kv_cache)
+# yapf: enable
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.ipex_attn import PagedAttention
 from vllm.attention.ops.paged_attn import PagedAttentionMetadata
@@ -427,7 +431,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
                 f"Supported head sizes are: {supported_head_sizes}.")
-        if kv_cache_dtype != "auto":
+        if is_quantized_kv_cache(kv_cache_dtype):
             raise NotImplementedError(
                 "Torch SDPA backend does not support FP8 KV cache. "
                 "Please use xFormers backend instead.")
diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py
index 08e8226ab04c0..61e5c76d9fda3 100644
--- a/vllm/attention/backends/triton_mla.py
+++ b/vllm/attention/backends/triton_mla.py
@@ -4,7 +4,8 @@ from typing import Any, Dict, List, Optional, Type

 import torch

-from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.backends.abstract import (AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.mla.common import (MLACommonBackend,
                                                 MLACommonImpl,
                                                 MLACommonMetadata)
@@ -58,6 +59,10 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
                                       "are not implemented for "
                                       "TritonMLAImpl")

+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "TritonMLA with FP8 KV cache not yet supported")
+
     def _forward_decode(
         self,
         q_nope: torch.Tensor,
@@ -66,8 +71,6 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
         attn_metadata: MLACommonMetadata,
     ) -> torch.Tensor:
         assert kv_c_and_k_pe_cache.numel() > 0
-        if self.kv_cache_dtype.startswith("fp8"):
-            raise NotImplementedError("FP8 Triton MLA not yet supported")

         decode_meta = attn_metadata.decode_metadata
         assert decode_meta is not None
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index e7c2fd412eb2d..db80e52bf0738 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -7,7 +7,8 @@ import numpy as np
 import torch

 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata, AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.utils import get_flash_attn_version
 from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
 from vllm.logger import init_logger
@@ -180,6 +181,9 @@ class FlashAttentionImpl(AttentionImpl):
         else:
             self.sliding_window = (sliding_window - 1, 0)
         self.kv_cache_dtype = kv_cache_dtype
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "FlashAttention V1 with FP8 KV cache not yet supported")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index d5bf9cd22f1c6..143bfe35bb5e5 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -5,7 +5,8 @@ from typing import Any, Optional

 import torch

-from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.backends.abstract import (AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          get_mla_metadata,
                                          is_flashmla_supported)
@@ -115,6 +116,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
                                       "are not implemented for "
                                       "FlashMLAImpl")

+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "FlashMLA V1 with FP8 KV cache not yet supported")
+
     def _forward_decode(
         self,
         q_nope: torch.Tensor,
@@ -125,9 +130,6 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
         assert kv_c_and_k_pe_cache.numel() > 0
         assert attn_metadata.decode is not None

-        if self.kv_cache_dtype.startswith("fp8"):
-            raise NotImplementedError("FP8 FlashMLA not yet supported")
-
         q = torch.cat([q_nope, q_pe], dim=-1)\
             .unsqueeze(1)  # Add seqlen dim of 1 (decode)
diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py
index cef7a3a9a7274..8e7e4f10b81b8 100644
--- a/vllm/v1/attention/backends/mla/triton_mla.py
+++ b/vllm/v1/attention/backends/mla/triton_mla.py
@@ -4,7 +4,8 @@ from typing import Any, Optional

 import torch

-from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.backends.abstract import (AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
@@ -61,6 +62,10 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
                                       "are not implemented for "
                                       "TritonMLAImpl")

+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "TritonMLA V1 with FP8 KV cache not yet supported")
+
     def _forward_decode(
         self,
         q_nope: torch.Tensor,
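
For reference, a minimal sketch (illustrative only, not part of the patch) of what the new helper checks, assuming the kv_cache_dtype strings vLLM already accepts (e.g. "auto", "fp8", "fp8_e5m2", "fp8_e4m3"):

    # Illustrative only: mirrors is_quantized_kv_cache() as added in abstract.py above.
    from vllm.attention.backends.abstract import is_quantized_kv_cache

    assert not is_quantized_kv_cache("auto")      # "auto" means an unquantized KV cache
    assert is_quantized_kv_cache("fp8")           # any non-"auto" dtype is treated as quantized
    assert is_quantized_kv_cache("fp8_e5m2")

With this helper, backends that lack quantized-KV-cache kernels reject the configuration at construction time instead of failing later inside _forward_decode, which is why the per-backend fp8 checks above move out of the decode path.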