[V1][Bugfix] Standardize quantized kv cache rejection for attention backends (#14221)
Signed-off-by: mgoin <mgoin64@gmail.com>
parent 6b2ef5cd17
commit 6832707e90
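
In short: the commit adds a single helper, is_quantized_kv_cache, to vllm.attention.backends.abstract and routes every backend's rejection of quantized KV caches through it at construction time, replacing the ad-hoc kv_cache_dtype != "auto" and kv_cache_dtype.startswith("fp8") checks scattered across backends. Below is a minimal sketch of the resulting pattern, assuming a vllm checkout that includes this change; ExampleAttentionImpl is hypothetical and only stands in for the real backends touched in the hunks that follow.

    # Hypothetical backend shown only to illustrate the standardized check;
    # is_quantized_kv_cache is the real helper added by this commit.
    from vllm.attention.backends.abstract import is_quantized_kv_cache


    class ExampleAttentionImpl:

        def __init__(self, kv_cache_dtype: str) -> None:
            self.kv_cache_dtype = kv_cache_dtype
            # Fail fast in the constructor instead of re-checking the dtype
            # string on every forward/decode call.
            if is_quantized_kv_cache(self.kv_cache_dtype):
                raise NotImplementedError(
                    "ExampleAttention with FP8 KV cache not yet supported")
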
@@ -294,3 +294,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         raise NotImplementedError
+
+
+def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
+    return kv_cache_dtype != "auto"
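
Since the helper only compares the dtype string against "auto", every quantized KV cache dtype is rejected uniformly, not only the fp8 variants that the old startswith("fp8") checks caught. A quick sanity check of the behavior defined above (the dtype strings are just examples):

    from vllm.attention.backends.abstract import is_quantized_kv_cache

    assert is_quantized_kv_cache("auto") is False     # "auto" keeps the model dtype
    assert is_quantized_kv_cache("fp8") is True       # fp8 cache is quantized
    assert is_quantized_kv_cache("fp8_e5m2") is True  # any non-"auto" value counts
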
@@ -8,11 +8,15 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
 import torch

 from vllm import _custom_ops as ops
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
-                                              AttentionType)
+                                              AttentionType,
+                                              is_quantized_kv_cache)
+# yapf: enable
 from vllm.attention.backends.utils import (
     PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
     compute_slot_mapping_start_idx, get_flash_attn_version,
@@ -626,6 +630,9 @@ class FlashAttentionImpl(AttentionImpl):
         self.sliding_window = ((sliding_window - 1,
                                 0) if sliding_window is not None else (-1, -1))
         self.kv_cache_dtype = kv_cache_dtype
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "FlashAttention with FP8 KV cache not yet supported")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0
@@ -6,7 +6,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

 import torch

-from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.backends.abstract import (AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.mla.common import (MLACommonBackend,
                                                 MLACommonImpl,
                                                 MLACommonMetadata,
@@ -207,6 +208,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
                                       "are not implemented for "
                                       "FlashMLAImpl")

+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "FlashMLA with FP8 KV cache not yet supported")
+
     def _forward_decode(
         self,
         q_nope: torch.Tensor,
@@ -215,8 +220,6 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
         attn_metadata: FlashMLAMetadata,
     ) -> torch.Tensor:
         assert kv_c_and_k_pe_cache.numel() > 0
-        if self.kv_cache_dtype.startswith("fp8"):
-            raise NotImplementedError("FP8 FlashMLA not yet supported")

         decode_meta = attn_metadata.decode_metadata
         assert decode_meta is not None
@@ -15,7 +15,8 @@ from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,

 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata, AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
                                                HPUPagedAttentionMetadata)
@@ -158,6 +159,10 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
                                       "are not implemented for "
                                       "HPUAttentionImpl")

+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "HPUAttention with FP8 KV cache not yet supported")
+
     def forward(
         self,
         layer: AttentionLayer,
@@ -9,7 +9,8 @@ import torch
 from vllm._ipex_ops import ipex_ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata, AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.paged_attn import (PagedAttention,
                                            PagedAttentionMetadata)
@@ -145,7 +146,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
                 f"Supported head sizes are: {supported_head_sizes}.")
-        if kv_cache_dtype != "auto":
+        if is_quantized_kv_cache(kv_cache_dtype):
             raise NotImplementedError(
                 "IPEX backend does not support FP8 KV cache. "
                 "Please use xFormers backend instead.")
@@ -8,7 +8,8 @@ import torch_xla.experimental.custom_kernel  # Required to register custom ops.

 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata, AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.utils import CommonAttentionState


@@ -119,7 +120,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
             raise NotImplementedError("Alibi slopes is not supported.")
         if sliding_window is not None:
             raise NotImplementedError("Sliding window is not supported.")
-        if kv_cache_dtype != "auto":
+        if is_quantized_kv_cache(kv_cache_dtype):
             raise NotImplementedError("FP8 KV cache dtype is not supported.")
         if blocksparse_params is not None:
             raise NotImplementedError("Blocksparse is not supported.")
@@ -7,11 +7,15 @@ from typing import Any, Dict, List, Optional, Tuple, Type
 import torch
 from torch.nn.functional import scaled_dot_product_attention

+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
-                                              AttentionType)
+                                              AttentionType,
+                                              is_quantized_kv_cache)
+# yapf: enable
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.ipex_attn import PagedAttention
 from vllm.attention.ops.paged_attn import PagedAttentionMetadata
@@ -427,7 +431,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
                 f"Supported head sizes are: {supported_head_sizes}.")
-        if kv_cache_dtype != "auto":
+        if is_quantized_kv_cache(kv_cache_dtype):
             raise NotImplementedError(
                 "Torch SDPA backend does not support FP8 KV cache. "
                 "Please use xFormers backend instead.")
@@ -4,7 +4,8 @@ from typing import Any, Dict, List, Optional, Type

 import torch

-from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.backends.abstract import (AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.mla.common import (MLACommonBackend,
                                                 MLACommonImpl,
                                                 MLACommonMetadata)
@@ -58,6 +59,10 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
                                       "are not implemented for "
                                       "TritonMLAImpl")

+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "TritonMLA with FP8 KV cache not yet supported")
+
     def _forward_decode(
         self,
         q_nope: torch.Tensor,
@@ -66,8 +71,6 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
         attn_metadata: MLACommonMetadata,
     ) -> torch.Tensor:
         assert kv_c_and_k_pe_cache.numel() > 0
-        if self.kv_cache_dtype.startswith("fp8"):
-            raise NotImplementedError("FP8 Triton MLA not yet supported")

         decode_meta = attn_metadata.decode_metadata
         assert decode_meta is not None
@@ -7,7 +7,8 @@ import numpy as np
 import torch

 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata, AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.utils import get_flash_attn_version
 from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
 from vllm.logger import init_logger
@@ -180,6 +181,9 @@ class FlashAttentionImpl(AttentionImpl):
         else:
             self.sliding_window = (sliding_window - 1, 0)
         self.kv_cache_dtype = kv_cache_dtype
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "FlashAttention V1 with FP8 KV cache not yet supported")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0
@@ -5,7 +5,8 @@ from typing import Any, Optional

 import torch

-from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.backends.abstract import (AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          get_mla_metadata,
                                          is_flashmla_supported)
@@ -115,6 +116,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
                                       "are not implemented for "
                                       "FlashMLAImpl")

+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "FlashMLA V1 with FP8 KV cache not yet supported")
+
     def _forward_decode(
         self,
         q_nope: torch.Tensor,
@@ -125,9 +130,6 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
         assert kv_c_and_k_pe_cache.numel() > 0
         assert attn_metadata.decode is not None

-        if self.kv_cache_dtype.startswith("fp8"):
-            raise NotImplementedError("FP8 FlashMLA not yet supported")
-
         q = torch.cat([q_nope, q_pe], dim=-1)\
             .unsqueeze(1)  # Add seqlen dim of 1 (decode)

@@ -4,7 +4,8 @@ from typing import Any, Optional

 import torch

-from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.backends.abstract import (AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
@@ -61,6 +62,10 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
                                       "are not implemented for "
                                       "TritonMLAImpl")

+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "TritonMLA V1 with FP8 KV cache not yet supported")
+
     def _forward_decode(
         self,
         q_nope: torch.Tensor,