Move query quantization to attention layer for Flashinfer & Triton. (#26534)

Signed-off-by: adabeyta <aabeyta@redhat.com>
Signed-off-by: Adrian Abeyta <aabeyta@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Authored by Adrian Abeyta on 2025-10-15 18:01:38 -05:00; committed by GitHub.
parent e5b438a247
commit 0a9ef0cfce
6 changed files with 43 additions and 38 deletions


@@ -421,7 +421,9 @@ def test_attention_quant_pattern(
     ]
     if any(attn_fusion_supported):
         # Check quantization ops in the graph before and after fusion
-        test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=True)
+        # Note: fully_replaced=False because query quant ops remain in graph.
+        # Only output quant ops are fused into attention.
+        test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False)
         # access the underlying `AttnFusionPass` on the `LazyInitPass`
         assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)


@@ -41,14 +41,6 @@ class AttentionBackend(ABC):
     # makes sure the output tensor is allocated inside the cudagraph.
     accept_output_buffer: bool = False
 
-    # Whether this backend supports receiving pre-quantized query input.
-    # If True, the attention layer will handle query quantization instead
-    # of the backend, allowing torch.compile to fuse quantization with
-    # previous operations.
-    # Needs to be worked through for all backends
-    # https://github.com/vllm-project/vllm/issues/25584
-    supports_quant_query_input: bool = False
-
     @staticmethod
     @abstractmethod
     def get_name() -> str:
@@ -199,6 +191,22 @@ class AttentionImpl(ABC, Generic[T]):
         """
         return False
 
+    def supports_quant_query_input(self) -> bool:
+        """
+        Check if this attention implementation supports pre-quantized query input.
+
+        When True, the attention layer will quantize queries before passing them
+        to this backend, allowing torch.compile to fuse the quantization with
+        previous operations. This is typically supported when using FP8 KV cache
+        with compatible attention kernels (e.g., TRT-LLM).
+
+        TODO add support to more backends:
+        https://github.com/vllm-project/vllm/issues/25584
+
+        Returns:
+            bool: True if the implementation can accept pre-quantized queries.
+        """
+        return False
+
 
 class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
     @abstractmethod
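
The hook added above is the core of this change: the attention layer now asks the implementation for the capability instead of reading a class attribute on the backend. Below is a minimal, self-contained sketch of that handshake; the class and function names are invented for illustration, only the gating rule itself comes from the diff.

class ToyAttentionImpl:
    # Mirrors the new default on AttentionImpl: backends must opt in explicitly.
    def supports_quant_query_input(self) -> bool:
        return False


class ToyTrtllmLikeImpl(ToyAttentionImpl):
    # Stand-in for an implementation whose kernels accept FP8 queries directly.
    def supports_quant_query_input(self) -> bool:
        return True


def build_query_quant(impl: ToyAttentionImpl, kv_cache_dtype: str):
    # Layer-level rule from this commit: set up a query quantizer only for
    # FP8 KV caches *and* implementations that accept pre-quantized queries.
    if kv_cache_dtype.startswith("fp8") and impl.supports_quant_query_input():
        return "per-tensor static FP8 quantizer"  # stands in for QuantFP8(...)
    return None


assert build_query_quant(ToyTrtllmLikeImpl(), "fp8_e4m3") is not None
assert build_query_quant(ToyAttentionImpl(), "fp8_e4m3") is None
assert build_query_quant(ToyTrtllmLikeImpl(), "auto") is None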


@@ -36,6 +36,7 @@ from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.platforms import current_platform
 from vllm.utils import GiB_bytes, direct_register_custom_op
 
+FP8_DTYPE = current_platform.fp8_dtype()
 logger = init_logger(__name__)
 
 USE_XFORMERS_OPS = None

@@ -304,7 +305,7 @@ class Attention(nn.Module, AttentionLayerBase):
         self.query_quant = None
         if (
             self.kv_cache_dtype.startswith("fp8")
-            and self.attn_backend.supports_quant_query_input
+            and self.impl.supports_quant_query_input()
         ):
             self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)

@@ -329,7 +330,6 @@ class Attention(nn.Module, AttentionLayerBase):
         """
         if self.calculate_kv_scales:
             torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
-
         output_dtype = query.dtype
         if self.query_quant is not None:
             # quantizing with a simple torch operation enables

@@ -338,6 +338,9 @@ class Attention(nn.Module, AttentionLayerBase):
             # Otherwise queries are quantized using custom ops
             # which causes decoding overheads
             assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
-            query, _ = self.query_quant(query, self._q_scale)
+
+            # check if query quantization is supported
+            if self.impl.supports_quant_query_input():
+                query, _ = self.query_quant(query, self._q_scale)
 
         if self.use_output:
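
The "simple torch operation" comment above is the motivation for the whole move: expressed as plain torch ops in the layer, the quantization is visible to torch.compile and can be fused with whatever produced the query, instead of being hidden inside a per-backend custom op. A torch-only sketch of what the per-tensor static FP8 path amounts to follows; QuantFP8 is the real vLLM helper, and the function below is an invented approximation for illustration only.

import torch

FP8 = torch.float8_e4m3fn


def quant_query_per_tensor_static(query: torch.Tensor, q_scale: torch.Tensor):
    # Scale, saturate to the e4m3 range, then cast. Written as plain torch ops,
    # this is the kind of graph torch.compile can fuse with the producers of
    # `query`, unlike an opaque per-backend custom quant op.
    fp8_max = torch.finfo(FP8).max
    q = (query.float() / q_scale).clamp(min=-fp8_max, max=fp8_max).to(FP8)
    return q, q_scale


query = torch.randn(8, 32, 128, dtype=torch.bfloat16)
q_fp8, scale = quant_query_per_tensor_static(query, torch.tensor(0.1))
print(q_fp8.dtype, tuple(q_fp8.shape))  # torch.float8_e4m3fn (8, 32, 128)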


@@ -49,7 +49,6 @@ logger = init_logger(__name__)
 
 class FlashAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
-    supports_quant_query_input: bool = True
 
     @classmethod
     def get_supported_dtypes(cls) -> list[torch.dtype]:

@@ -494,6 +493,9 @@ class FlashAttentionImpl(AttentionImpl):
                 "heads in the layer"
             )
 
+    def supports_quant_query_input(self) -> bool:
+        return True
+
     def forward(
         self,
         layer: torch.nn.Module,


@@ -16,7 +16,6 @@ from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
 from flashinfer.prefill import trtllm_batch_context_with_kv_cache
 from flashinfer.utils import FP4Tensor
 
-from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,

@@ -828,6 +827,12 @@ class FlashInferImpl(AttentionImpl):
             and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)
         )
 
+    def supports_quant_query_input(self) -> bool:
+        if flashinfer_disable_q_quantization():
+            return False
+
+        return self.support_trtllm_attn
+
     def forward(
         self,
         layer: torch.nn.Module,

@@ -859,6 +864,12 @@ class FlashInferImpl(AttentionImpl):
             # Profiling run.
             return output.fill_(0)
 
+        # Ensure query dtype matches the expected dtype from attention metadata
+        assert attn_metadata.q_data_type == query.dtype, (
+            f"Query dtype mismatch: expected {attn_metadata.q_data_type}, "
+            f"got {query.dtype}"
+        )
+
         if self.bmm1_scale is None:
             self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale

@@ -899,15 +910,6 @@ class FlashInferImpl(AttentionImpl):
         elif output.dtype == FP4_DTYPE:
             self.o_sf_scale = layer._o_scale_float
 
-        # Insert FP8 quant for query
-        if attn_metadata.q_data_type == FP8_DTYPE:
-            num_tokens, num_heads, head_size = query.shape
-            query, _ = ops.scaled_fp8_quant(
-                query.reshape((num_tokens, num_heads * head_size)).contiguous(),
-                layer._q_scale,
-            )
-            query = query.reshape((num_tokens, num_heads, head_size))
-
         # IMPORTANT!
         # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
         # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
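
The block removed above is the per-backend query quantization that the layer now performs once. Because per-tensor quantization is element-wise, flattening the head dimensions before quantizing (as the backend did via ops.scaled_fp8_quant) and quantizing the 3-D query directly (as the layer does now) produce identical values. A small sketch, reusing the same plain-torch stand-in for the custom op as in the earlier example (the stand-in is illustrative, not vLLM's implementation):

import torch

FP8 = torch.float8_e4m3fn


def fake_scaled_fp8_quant(x: torch.Tensor, scale: torch.Tensor):
    # Plain-torch approximation of vLLM's ops.scaled_fp8_quant (illustration only).
    fp8_max = torch.finfo(FP8).max
    return (x.float() / scale).clamp(-fp8_max, fp8_max).to(FP8), scale


query = torch.randn(4, 8, 64)
scale = torch.tensor(0.07)

# Old backend-side path: flatten to (num_tokens, num_heads * head_size),
# quantize, then reshape back.
num_tokens, num_heads, head_size = query.shape
q_old, _ = fake_scaled_fp8_quant(
    query.reshape(num_tokens, num_heads * head_size).contiguous(), scale
)
q_old = q_old.reshape(num_tokens, num_heads, head_size)

# New layer-side path: quantize the 3-D query directly, no reshapes needed.
q_new, _ = fake_scaled_fp8_quant(query, scale)

torch.testing.assert_close(q_old.float(), q_new.float())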


@@ -32,11 +32,6 @@ from vllm.v1.attention.backends.utils import (
 )
 from vllm.v1.kv_cache_interface import AttentionSpec
 
-if current_platform.is_cuda_alike():
-    from vllm import _custom_ops as ops
-elif current_platform.is_xpu():
-    from vllm._ipex_ops import ipex_ops as ops
-
 logger = init_logger(__name__)

@@ -210,6 +205,9 @@ class TritonAttentionImpl(AttentionImpl):
     def fused_output_quant_supported(self, quant_key: QuantKey):
         return quant_key == kFp8StaticTensorSym
 
+    def supports_quant_query_input(self) -> bool:
+        return current_platform.is_cuda()
+
     def __init__(
         self,
         num_heads: int,

@@ -338,19 +336,9 @@ class TritonAttentionImpl(AttentionImpl):
             if key_cache.dtype != self.fp8_dtype:
                 key_cache = key_cache.view(self.fp8_dtype)
                 value_cache = value_cache.view(self.fp8_dtype)
-            num_tokens, num_heads, head_size = query.shape
             assert layer._q_scale_float == 1.0, (
                 "A non 1.0 q_scale is not currently supported."
             )
-            if current_platform.is_cuda():
-                # Skip Q quantization on ROCm and XPU, enable this on cuda
-                # only, since dequantizing back to f32 in the attention kernel
-                # is not supported.
-                query, _ = ops.scaled_fp8_quant(
-                    query.reshape((num_tokens, num_heads * head_size)).contiguous(),
-                    layer._q_scale,
-                )
-                query = query.reshape((num_tokens, num_heads, head_size))
 
         cu_seqlens_q = attn_metadata.query_start_loc
         seqused_k = attn_metadata.seq_lens