Move query quantization to attention layer for Flashinfer & Triton. (#26534)

Signed-off-by: adabeyta <aabeyta@redhat.com>
Signed-off-by: Adrian Abeyta <aabeyta@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Adrian Abeyta 2025-10-15 18:01:38 -05:00 committed by GitHub
parent e5b438a247
commit 0a9ef0cfce
6 changed files with 43 additions and 38 deletions
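
Before the file-by-file diff, a minimal sketch (outside this diff) of the pattern the commit adopts. The classes below are illustrative stand-ins, not vLLM's actual `Attention`/`AttentionImpl` classes: the layer asks the implementation whether it accepts a pre-quantized query and, if so, quantizes the query itself with plain torch ops so torch.compile can fuse the quantization with earlier operations.

```python
# Minimal, self-contained sketch of the ownership change. Toy classes only.
from abc import ABC, abstractmethod

import torch

FP8 = torch.float8_e4m3fn


class ToyAttentionImpl(ABC):
    def supports_quant_query_input(self) -> bool:
        # Backends whose kernels accept an FP8 query override this.
        return False

    @abstractmethod
    def forward(self, query: torch.Tensor) -> torch.Tensor: ...


class ToyFP8QueryImpl(ToyAttentionImpl):
    def supports_quant_query_input(self) -> bool:
        return True

    def forward(self, query: torch.Tensor) -> torch.Tensor:
        # The kernel now receives the query already quantized; there is no
        # per-call quantization left inside the backend.
        assert query.dtype == FP8
        return query.to(torch.float32)  # placeholder for the real kernel


class ToyAttentionLayer(torch.nn.Module):
    def __init__(self, impl: ToyAttentionImpl, q_scale: float = 1.0):
        super().__init__()
        self.impl = impl
        self.q_scale = q_scale

    def forward(self, query: torch.Tensor) -> torch.Tensor:
        if self.impl.supports_quant_query_input():
            # Static per-tensor FP8 quantization written as ordinary torch
            # expressions, so torch.compile can fuse it with whatever
            # produced `query`.
            finfo = torch.finfo(FP8)
            query = (query / self.q_scale).clamp(finfo.min, finfo.max).to(FP8)
        return self.impl.forward(query)


out = ToyAttentionLayer(ToyFP8QueryImpl())(torch.randn(2, 8, 64))
```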


@@ -421,7 +421,9 @@ def test_attention_quant_pattern(
]
if any(attn_fusion_supported):
# Check quantization ops in the graph before and after fusion
test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=True)
# Note: fully_replaced=False because query quant ops remain in graph.
# Only output quant ops are fused into attention.
test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False)
# access the underlying `AttnFusionPass` on the `LazyInitPass`
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)


@@ -41,14 +41,6 @@ class AttentionBackend(ABC):
# makes sure the output tensor is allocated inside the cudagraph.
accept_output_buffer: bool = False
# Whether this backend supports receiving pre-quantized query input.
# If True, the attention layer will handle query quantization instead
# of the backend, allowing torch.compile to fuse quantization with
# previous operations.
# Needs to be worked through for all backends
# https://github.com/vllm-project/vllm/issues/25584
supports_quant_query_input: bool = False
@staticmethod
@abstractmethod
def get_name() -> str:
@@ -199,6 +191,22 @@ class AttentionImpl(ABC, Generic[T]):
"""
return False
def supports_quant_query_input(self) -> bool:
"""
Check if this attention implementation supports pre-quantized query input.
When True, the attention layer will quantize queries before passing them
to this backend, allowing torch.compile to fuse the quantization with
previous operations. This is typically supported when using FP8 KV cache
with compatible attention kernels (e.g., TRT-LLM).
TODO: Add support for more backends:
https://github.com/vllm-project/vllm/issues/25584
Returns:
bool: True if the implementation can accept pre-quantized queries.
"""
return False
class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
@abstractmethod
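
A hedged sketch of how a backend implementation might satisfy this contract. The class and the environment flag below are made up for illustration; the point is that, unlike the old class attribute on `AttentionBackend`, an instance method lets the answer depend on runtime state such as kernel availability or an opt-out flag.

```python
# Illustrative only: the capability check can depend on per-instance state,
# which the old class-level attribute could not express. Nothing here is a
# real vLLM class; the env var name is invented for the sketch.
import os


class HypotheticalImpl:
    def __init__(self, kv_cache_dtype: str, kernel_accepts_fp8_query: bool):
        self.kv_cache_dtype = kv_cache_dtype
        self.kernel_accepts_fp8_query = kernel_accepts_fp8_query

    def supports_quant_query_input(self) -> bool:
        # Pre-quantized queries only make sense with an FP8 KV cache, a
        # kernel that accepts them, and no explicit opt-out.
        if os.environ.get("SKETCH_DISABLE_Q_QUANT") == "1":
            return False
        return self.kv_cache_dtype.startswith("fp8") and self.kernel_accepts_fp8_query


impl = HypotheticalImpl(kv_cache_dtype="fp8_e4m3", kernel_accepts_fp8_query=True)
print(impl.supports_quant_query_input())  # True unless the opt-out flag is set
```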


@@ -36,6 +36,7 @@ from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import current_platform
from vllm.utils import GiB_bytes, direct_register_custom_op
FP8_DTYPE = current_platform.fp8_dtype()
logger = init_logger(__name__)
USE_XFORMERS_OPS = None
@@ -304,7 +305,7 @@ class Attention(nn.Module, AttentionLayerBase):
self.query_quant = None
if (
self.kv_cache_dtype.startswith("fp8")
and self.attn_backend.supports_quant_query_input
and self.impl.supports_quant_query_input()
):
self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
@@ -329,7 +330,6 @@ class Attention(nn.Module, AttentionLayerBase):
"""
if self.calculate_kv_scales:
torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
output_dtype = query.dtype
if self.query_quant is not None:
# quantizing with a simple torch operation enables
@@ -338,7 +338,10 @@
# Otherwise queries are quantized using custom ops
# which causes decoding overheads
assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
query, _ = self.query_quant(query, self._q_scale)
# check if query quantization is supported
if self.impl.supports_quant_query_input():
query, _ = self.query_quant(query, self._q_scale)
if self.use_output:
output_shape = output_shape if output_shape is not None else query.shape
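
The comment above ("quantizing with a simple torch operation ... Otherwise queries are quantized using custom ops which causes decoding overheads") is the heart of the change: at the layer level, static per-tensor FP8 quantization can be written as ordinary tensor ops that live inside the compiled graph, instead of an opaque custom op inside the backend. A rough sketch under that assumption; the helper below is not vLLM's `QuantFP8`, just the per-tensor static case written out:

```python
import torch

FP8 = torch.float8_e4m3fn


def per_tensor_fp8_quant(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Static per-tensor quantization as plain torch expressions. Inside a
    # torch.compile'd region these ops can be fused with whatever produced
    # `x` (e.g. the preceding QKV projection); an opaque custom op at this
    # point would instead break the graph and add per-decode overhead.
    finfo = torch.finfo(FP8)
    return (x / scale).clamp(finfo.min, finfo.max).to(FP8)


# Toy stand-in for "projection followed by query quantization".
hidden = torch.randn(4, 64)
w_q = torch.randn(64, 64)
q_scale = torch.tensor(0.05)

q_fp8 = per_tensor_fp8_quant(hidden @ w_q, q_scale)
print(q_fp8.dtype)  # torch.float8_e4m3fn
```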


@@ -49,7 +49,6 @@ logger = init_logger(__name__)
class FlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supports_quant_query_input: bool = True
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
@@ -494,6 +493,9 @@ class FlashAttentionImpl(AttentionImpl):
"heads in the layer"
)
def supports_quant_query_input(self) -> bool:
return True
def forward(
self,
layer: torch.nn.Module,


@@ -16,7 +16,6 @@ from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
from flashinfer.utils import FP4Tensor
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
@@ -828,6 +827,12 @@ class FlashInferImpl(AttentionImpl):
and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)
)
def supports_quant_query_input(self) -> bool:
if flashinfer_disable_q_quantization():
return False
return self.support_trtllm_attn
def forward(
self,
layer: torch.nn.Module,
@@ -859,6 +864,12 @@ class FlashInferImpl(AttentionImpl):
# Profiling run.
return output.fill_(0)
# Ensure query dtype matches the expected dtype from attention metadata
assert attn_metadata.q_data_type == query.dtype, (
f"Query dtype mismatch: expected {attn_metadata.q_data_type}, "
f"got {query.dtype}"
)
if self.bmm1_scale is None:
self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
@@ -899,15 +910,6 @@ class FlashInferImpl(AttentionImpl):
elif output.dtype == FP4_DTYPE:
self.o_sf_scale = layer._o_scale_float
# Insert FP8 quant for query
if attn_metadata.q_data_type == FP8_DTYPE:
num_tokens, num_heads, head_size = query.shape
query, _ = ops.scaled_fp8_quant(
query.reshape((num_tokens, num_heads * head_size)).contiguous(),
layer._q_scale,
)
query = query.reshape((num_tokens, num_heads, head_size))
# IMPORTANT!
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
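
With quantization owned by the layer, the FlashInfer impl no longer quantizes the query; the added assert only verifies that the query it receives matches the dtype its metadata planned for. A toy sketch of that handshake; the dataclass below is a stand-in, not FlashInfer's real metadata:

```python
from dataclasses import dataclass

import torch


@dataclass
class ToyAttnMetadata:
    # Dtype the backend planned for the query when building its metadata.
    q_data_type: torch.dtype


def toy_backend_forward(query: torch.Tensor, meta: ToyAttnMetadata) -> None:
    # Mirrors the added check: the layer is now responsible for handing over
    # a query of the planned dtype, so a mismatch is a handshake bug rather
    # than something the backend should silently fix by re-quantizing.
    assert meta.q_data_type == query.dtype, (
        f"Query dtype mismatch: expected {meta.q_data_type}, got {query.dtype}"
    )


meta = ToyAttnMetadata(q_data_type=torch.float8_e4m3fn)
toy_backend_forward(torch.randn(2, 8, 64).to(torch.float8_e4m3fn), meta)  # ok
# toy_backend_forward(torch.randn(2, 8, 64), meta)  # would raise AssertionError
```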


@@ -32,11 +32,6 @@ from vllm.v1.attention.backends.utils import (
)
from vllm.v1.kv_cache_interface import AttentionSpec
if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops
logger = init_logger(__name__)
@@ -210,6 +205,9 @@ class TritonAttentionImpl(AttentionImpl):
def fused_output_quant_supported(self, quant_key: QuantKey):
return quant_key == kFp8StaticTensorSym
def supports_quant_query_input(self) -> bool:
return current_platform.is_cuda()
def __init__(
self,
num_heads: int,
@@ -338,19 +336,9 @@
if key_cache.dtype != self.fp8_dtype:
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
num_tokens, num_heads, head_size = query.shape
assert layer._q_scale_float == 1.0, (
"A non 1.0 q_scale is not currently supported."
)
if current_platform.is_cuda():
# Skip Q quantization on ROCm and XPU, enable this on cuda
# only, since dequantizing back to f32 in the attention kernel
# is not supported.
query, _ = ops.scaled_fp8_quant(
query.reshape((num_tokens, num_heads * head_size)).contiguous(),
layer._q_scale,
)
query = query.reshape((num_tokens, num_heads, head_size))
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
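
Finally, a short numeric illustration (my own reading of the new assertion, not code from the diff) of why the Triton path now asserts `layer._q_scale_float == 1.0`: dividing the query by a per-tensor scale divides the Q·Kᵀ logits by the same factor, so unless the kernel folds `q_scale` back into its softmax scale, any non-unit value would change the attention result. Restricting to `q_scale == 1.0` sidesteps that until the kernel handles it.

```python
import torch

torch.manual_seed(0)
q = torch.randn(4, 64)
k = torch.randn(4, 64)
sm_scale = 64 ** -0.5
q_scale = 0.5  # hypothetical non-unit per-tensor query scale

logits_ref = (q @ k.T) * sm_scale
# Quantizing divides Q by q_scale; if the kernel never compensates, the
# attention logits come out divided by q_scale as well ...
logits_unscaled = ((q / q_scale) @ k.T) * sm_scale
# ... unless q_scale is folded back into the kernel's softmax scale.
logits_folded = ((q / q_scale) @ k.T) * (sm_scale * q_scale)

print(torch.allclose(logits_ref, logits_unscaled))  # False for q_scale != 1.0
print(torch.allclose(logits_ref, logits_folded))    # True
```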