[V1][Bugfix] Standardize quantized kv cache rejection for attention backends (#14221)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin 2025-03-06 17:18:29 -05:00 committed by GitHub
parent 6b2ef5cd17
commit 6832707e90
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 59 additions and 20 deletions

View File

@ -294,3 +294,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
return kv_cache_dtype != "auto"

View File

@ -8,11 +8,15 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
import torch
from vllm import _custom_ops as ops
# yapf conflicts with isort for this block
# yapf: disable
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType)
AttentionType,
is_quantized_kv_cache)
# yapf: enable
from vllm.attention.backends.utils import (
PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
compute_slot_mapping_start_idx, get_flash_attn_version,
@ -626,6 +630,9 @@ class FlashAttentionImpl(AttentionImpl):
self.sliding_window = ((sliding_window - 1,
0) if sliding_window is not None else (-1, -1))
self.kv_cache_dtype = kv_cache_dtype
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"FlashAttention with FP8 KV cache not yet supported")
if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0

View File

@ -6,7 +6,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
import torch
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.abstract import (AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
@ -207,6 +208,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"are not implemented for "
"FlashMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"FlashMLA with FP8 KV cache not yet supported")
def _forward_decode(
self,
q_nope: torch.Tensor,
@ -215,8 +220,6 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
attn_metadata: FlashMLAMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 FlashMLA not yet supported")
decode_meta = attn_metadata.decode_metadata
assert decode_meta is not None

View File

@ -15,7 +15,8 @@ from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
HPUPagedAttentionMetadata)
@ -158,6 +159,10 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
"are not implemented for "
"HPUAttentionImpl")
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"HPUAttention with FP8 KV cache not yet supported")
def forward(
self,
layer: AttentionLayer,

View File

@ -9,7 +9,8 @@ import torch
from vllm._ipex_ops import ipex_ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
@ -145,7 +146,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.")
if kv_cache_dtype != "auto":
if is_quantized_kv_cache(kv_cache_dtype):
raise NotImplementedError(
"IPEX backend does not support FP8 KV cache. "
"Please use xFormers backend instead.")

View File

@ -8,7 +8,8 @@ import torch_xla.experimental.custom_kernel # Required to register custom ops.
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.utils import CommonAttentionState
@ -119,7 +120,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
raise NotImplementedError("Alibi slopes is not supported.")
if sliding_window is not None:
raise NotImplementedError("Sliding window is not supported.")
if kv_cache_dtype != "auto":
if is_quantized_kv_cache(kv_cache_dtype):
raise NotImplementedError("FP8 KV cache dtype is not supported.")
if blocksparse_params is not None:
raise NotImplementedError("Blocksparse is not supported.")

View File

@ -7,11 +7,15 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import torch
from torch.nn.functional import scaled_dot_product_attention
# yapf conflicts with isort for this block
# yapf: disable
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType)
AttentionType,
is_quantized_kv_cache)
# yapf: enable
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.ipex_attn import PagedAttention
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
@ -427,7 +431,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.")
if kv_cache_dtype != "auto":
if is_quantized_kv_cache(kv_cache_dtype):
raise NotImplementedError(
"Torch SDPA backend does not support FP8 KV cache. "
"Please use xFormers backend instead.")

View File

@ -4,7 +4,8 @@ from typing import Any, Dict, List, Optional, Type
import torch
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.abstract import (AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata)
@ -58,6 +59,10 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
"are not implemented for "
"TritonMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"TritonMLA with FP8 KV cache not yet supported")
def _forward_decode(
self,
q_nope: torch.Tensor,
@ -66,8 +71,6 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
attn_metadata: MLACommonMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 Triton MLA not yet supported")
decode_meta = attn_metadata.decode_metadata
assert decode_meta is not None

View File

@ -7,7 +7,8 @@ import numpy as np
import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.utils import get_flash_attn_version
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
from vllm.logger import init_logger
@ -180,6 +181,9 @@ class FlashAttentionImpl(AttentionImpl):
else:
self.sliding_window = (sliding_window - 1, 0)
self.kv_cache_dtype = kv_cache_dtype
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"FlashAttention V1 with FP8 KV cache not yet supported")
if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0

View File

@ -5,7 +5,8 @@ from typing import Any, Optional
import torch
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.abstract import (AttentionType,
is_quantized_kv_cache)
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata,
is_flashmla_supported)
@ -115,6 +116,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"are not implemented for "
"FlashMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"FlashMLA V1 with FP8 KV cache not yet supported")
def _forward_decode(
self,
q_nope: torch.Tensor,
@ -125,9 +130,6 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 FlashMLA not yet supported")
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)

View File

@ -4,7 +4,8 @@ from typing import Any, Optional
import torch
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.abstract import (AttentionType,
is_quantized_kv_cache)
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
@ -61,6 +62,10 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
"are not implemented for "
"TritonMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"TritonMLA V1 with FP8 KV cache not yet supported")
def _forward_decode(
self,
q_nope: torch.Tensor,