Move query quantization to attention layer for Flashinfer & Triton. (#26534)

Signed-off-by: adabeyta <aabeyta@redhat.com>
Signed-off-by: Adrian Abeyta <aabeyta@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>

This commit is contained in:
parent e5b438a247
commit 0a9ef0cfce
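In brief: for FP8 KV caches, query quantization now happens in the Attention layer rather than inside the FlashInfer and Triton backends. Backend implementations advertise support through a new AttentionImpl.supports_quant_query_input() hook; when it returns True, the layer quantizes the query with a plain QuantFP8 op so torch.compile can fuse the quantization with preceding operations. A condensed, hedged sketch of the resulting flow (names follow the diff below; the helper itself is illustrative, not part of the commit):

def maybe_quantize_query(layer, query):
    # layer.query_quant is built in __init__ only when the KV cache is fp8 and
    # the backend impl reports supports_quant_query_input(); see the layer.py
    # hunks below.
    if layer.query_quant is not None and layer.impl.supports_quant_query_input():
        query, _ = layer.query_quant(query, layer._q_scale)
    return query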
@@ -421,7 +421,9 @@ def test_attention_quant_pattern(
     ]
     if any(attn_fusion_supported):
         # Check quantization ops in the graph before and after fusion
-        test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=True)
+        # Note: fully_replaced=False because query quant ops remain in graph.
+        # Only output quant ops are fused into attention.
+        test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False)
 
         # access the underlying `AttnFusionPass` on the `LazyInitPass`
         assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
@@ -41,14 +41,6 @@ class AttentionBackend(ABC):
     # makes sure the output tensor is allocated inside the cudagraph.
     accept_output_buffer: bool = False
 
-    # Whether this backend supports receiving pre-quantized query input.
-    # If True, the attention layer will handle query quantization instead
-    # of the backend, allowing torch.compile to fuse quantization with
-    # previous operations.
-    # Needs to be worked through for all backends
-    # https://github.com/vllm-project/vllm/issues/25584
-    supports_quant_query_input: bool = False
-
     @staticmethod
     @abstractmethod
     def get_name() -> str:
@@ -199,6 +191,22 @@ class AttentionImpl(ABC, Generic[T]):
         """
         return False
 
+    def supports_quant_query_input(self) -> bool:
+        """
+        Check if this attention implementation supports pre-quantized query input.
+
+        When True, the attention layer will quantize queries before passing them
+        to this backend, allowing torch.compile to fuse the quantization with
+        previous operations. This is typically supported when using FP8 KV cache
+        with compatible attention kernels (e.g., TRT-LLM).
+        TODO add support to more backends:
+        https://github.com/vllm-project/vllm/issues/25584
+
+        Returns:
+            bool: True if the implementation can accept pre-quantized queries.
+        """
+        return False
+
 
 class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
     @abstractmethod
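For a backend that can consume FP8 queries directly, opting in is a single override of this hook, as the FlashAttention, FlashInfer, and Triton hunks below do. A minimal sketch with a hypothetical implementation class (not part of this diff):

from vllm.attention.backends.abstract import AttentionImpl

class MyFP8FriendlyImpl(AttentionImpl):
    def supports_quant_query_input(self) -> bool:
        # Return True only if forward() can take a query that is already
        # quantized to the fp8 KV-cache dtype; the base class defaults to False.
        return True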
@@ -36,6 +36,7 @@ from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.platforms import current_platform
 from vllm.utils import GiB_bytes, direct_register_custom_op
 
+FP8_DTYPE = current_platform.fp8_dtype()
 logger = init_logger(__name__)
 USE_XFORMERS_OPS = None
 
@@ -304,7 +305,7 @@ class Attention(nn.Module, AttentionLayerBase):
         self.query_quant = None
         if (
             self.kv_cache_dtype.startswith("fp8")
-            and self.attn_backend.supports_quant_query_input
+            and self.impl.supports_quant_query_input()
         ):
             self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
 
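QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) applies static per-tensor FP8 quantization with the layer's precomputed _q_scale and returns the quantized tensor plus its scale. As a rough plain-torch illustration of the math (an approximation for intuition, not the QuantFP8 implementation):

import torch

def per_tensor_fp8_quant(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Static per-tensor quantization: divide by a fixed scale, clamp to the
    # representable e4m3 range, and cast to the FP8 dtype.
    finfo = torch.finfo(torch.float8_e4m3fn)
    x_scaled = (x.float() / scale).clamp(min=finfo.min, max=finfo.max)
    return x_scaled.to(torch.float8_e4m3fn)

Because this is expressed with ordinary torch operations inside the layer, torch.compile can fuse it with the preceding projection instead of running a separate custom quantization op inside the backend.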
@@ -329,7 +330,6 @@ class Attention(nn.Module, AttentionLayerBase):
         """
         if self.calculate_kv_scales:
             torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
-
         output_dtype = query.dtype
         if self.query_quant is not None:
             # quantizing with a simple torch operation enables
@@ -338,7 +338,10 @@ class Attention(nn.Module, AttentionLayerBase):
             # Otherwise queries are quantized using custom ops
             # which causes decoding overheads
             assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
-            query, _ = self.query_quant(query, self._q_scale)
+
+            # check if query quantization is supported
+            if self.impl.supports_quant_query_input():
+                query, _ = self.query_quant(query, self._q_scale)
 
         if self.use_output:
             output_shape = output_shape if output_shape is not None else query.shape
@@ -49,7 +49,6 @@ logger = init_logger(__name__)
 
 class FlashAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
-    supports_quant_query_input: bool = True
 
     @classmethod
     def get_supported_dtypes(cls) -> list[torch.dtype]:
@@ -494,6 +493,9 @@ class FlashAttentionImpl(AttentionImpl):
                 "heads in the layer"
             )
 
+    def supports_quant_query_input(self) -> bool:
+        return True
+
     def forward(
         self,
         layer: torch.nn.Module,
@@ -16,7 +16,6 @@ from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
 from flashinfer.prefill import trtllm_batch_context_with_kv_cache
 from flashinfer.utils import FP4Tensor
 
-from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
@@ -828,6 +827,12 @@ class FlashInferImpl(AttentionImpl):
             and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)
         )
 
+    def supports_quant_query_input(self) -> bool:
+        if flashinfer_disable_q_quantization():
+            return False
+
+        return self.support_trtllm_attn
+
     def forward(
         self,
         layer: torch.nn.Module,
@@ -859,6 +864,12 @@ class FlashInferImpl(AttentionImpl):
             # Profiling run.
             return output.fill_(0)
 
+        # Ensure query dtype matches the expected dtype from attention metadata
+        assert attn_metadata.q_data_type == query.dtype, (
+            f"Query dtype mismatch: expected {attn_metadata.q_data_type}, "
+            f"got {query.dtype}"
+        )
+
         if self.bmm1_scale is None:
             self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
 
@@ -899,15 +910,6 @@
         elif output.dtype == FP4_DTYPE:
             self.o_sf_scale = layer._o_scale_float
 
-        # Insert FP8 quant for query
-        if attn_metadata.q_data_type == FP8_DTYPE:
-            num_tokens, num_heads, head_size = query.shape
-            query, _ = ops.scaled_fp8_quant(
-                query.reshape((num_tokens, num_heads * head_size)).contiguous(),
-                layer._q_scale,
-            )
-            query = query.reshape((num_tokens, num_heads, head_size))
-
         # IMPORTANT!
         # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
         # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
@@ -32,11 +32,6 @@ from vllm.v1.attention.backends.utils import (
 )
 from vllm.v1.kv_cache_interface import AttentionSpec
 
-if current_platform.is_cuda_alike():
-    from vllm import _custom_ops as ops
-elif current_platform.is_xpu():
-    from vllm._ipex_ops import ipex_ops as ops
-
 logger = init_logger(__name__)
 
 
@@ -210,6 +205,9 @@ class TritonAttentionImpl(AttentionImpl):
     def fused_output_quant_supported(self, quant_key: QuantKey):
         return quant_key == kFp8StaticTensorSym
 
+    def supports_quant_query_input(self) -> bool:
+        return current_platform.is_cuda()
+
     def __init__(
         self,
         num_heads: int,
@@ -338,19 +336,9 @@ class TritonAttentionImpl(AttentionImpl):
             if key_cache.dtype != self.fp8_dtype:
                 key_cache = key_cache.view(self.fp8_dtype)
                 value_cache = value_cache.view(self.fp8_dtype)
-            num_tokens, num_heads, head_size = query.shape
             assert layer._q_scale_float == 1.0, (
                 "A non 1.0 q_scale is not currently supported."
            )
-            if current_platform.is_cuda():
-                # Skip Q quantization on ROCm and XPU, enable this on cuda
-                # only, since dequantizing back to f32 in the attention kernel
-                # is not supported.
-                query, _ = ops.scaled_fp8_quant(
-                    query.reshape((num_tokens, num_heads * head_size)).contiguous(),
-                    layer._q_scale,
-                )
-                query = query.reshape((num_tokens, num_heads, head_size))
 
         cu_seqlens_q = attn_metadata.query_start_loc
         seqused_k = attn_metadata.seq_lens
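A note on the Triton hunks above: TritonAttentionImpl reports supports_quant_query_input() only on CUDA, so on ROCm and XPU the layer keeps passing an unquantized query and the existing q_scale == 1.0 assertion continues to guard that path; only the backend-local ops.scaled_fp8_quant call is removed. A hypothetical helper (not in the diff) spelling out when the layer ends up pre-quantizing the query after this change:

def layer_prequantizes_query(kv_cache_dtype: str, impl_supports_quant_query: bool) -> bool:
    # Mirrors the layer.py gate above: an fp8 KV cache plus an impl that opts in.
    # For Triton that means CUDA only; for FlashInfer, TRT-LLM attention with
    # query quantization not disabled.
    return kv_cache_dtype.startswith("fp8") and impl_supports_quant_query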