diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py
index a8d78daa32a1..4d6f4b471a3a 100644
--- a/tests/compile/test_fusion_attn.py
+++ b/tests/compile/test_fusion_attn.py
@@ -421,7 +421,9 @@ def test_attention_quant_pattern(
     ]
     if any(attn_fusion_supported):
         # Check quantization ops in the graph before and after fusion
-        test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=True)
+        # Note: fully_replaced=False because query quant ops remain in the graph.
+        # Only output quant ops are fused into attention.
+        test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False)

         # access the underlying `AttnFusionPass` on the `LazyInitPass`
         assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 421b0c4beb37..fb2db4d0b0ec 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -41,14 +41,6 @@ class AttentionBackend(ABC):
     # makes sure the output tensor is allocated inside the cudagraph.
     accept_output_buffer: bool = False

-    # Whether this backend supports receiving pre-quantized query input.
-    # If True, the attention layer will handle query quantization instead
-    # of the backend, allowing torch.compile to fuse quantization with
-    # previous operations.
-    # Needs to be worked through for all backends
-    # https://github.com/vllm-project/vllm/issues/25584
-    supports_quant_query_input: bool = False
-
     @staticmethod
     @abstractmethod
     def get_name() -> str:
@@ -199,6 +191,22 @@ class AttentionImpl(ABC, Generic[T]):
         """
         return False

+    def supports_quant_query_input(self) -> bool:
+        """
+        Check if this attention implementation supports pre-quantized query input.
+
+        When True, the attention layer will quantize queries before passing them
+        to this backend, allowing torch.compile to fuse the quantization with
+        previous operations. This is typically supported when using an FP8 KV cache
+        with compatible attention kernels (e.g., TRT-LLM).
+        TODO: Add support for more backends:
+        https://github.com/vllm-project/vllm/issues/25584
+
+        Returns:
+            bool: True if the implementation can accept pre-quantized queries.
+ """ + return False + class MLAAttentionImpl(AttentionImpl[T], Generic[T]): @abstractmethod diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 16c5799f7d0b..9f879f7272e2 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -36,6 +36,7 @@ from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.platforms import current_platform from vllm.utils import GiB_bytes, direct_register_custom_op +FP8_DTYPE = current_platform.fp8_dtype() logger = init_logger(__name__) USE_XFORMERS_OPS = None @@ -304,7 +305,7 @@ class Attention(nn.Module, AttentionLayerBase): self.query_quant = None if ( self.kv_cache_dtype.startswith("fp8") - and self.attn_backend.supports_quant_query_input + and self.impl.supports_quant_query_input() ): self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) @@ -329,7 +330,6 @@ class Attention(nn.Module, AttentionLayerBase): """ if self.calculate_kv_scales: torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name) - output_dtype = query.dtype if self.query_quant is not None: # quantizing with a simple torch operation enables @@ -338,7 +338,10 @@ class Attention(nn.Module, AttentionLayerBase): # Otherwise queries are quantized using custom ops # which causes decoding overheads assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"} - query, _ = self.query_quant(query, self._q_scale) + + # check if query quantization is supported + if self.impl.supports_quant_query_input(): + query, _ = self.query_quant(query, self._q_scale) if self.use_output: output_shape = output_shape if output_shape is not None else query.shape diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 9e0c125d9edb..087f995e0528 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -49,7 +49,6 @@ logger = init_logger(__name__) class FlashAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - supports_quant_query_input: bool = True @classmethod def get_supported_dtypes(cls) -> list[torch.dtype]: @@ -494,6 +493,9 @@ class FlashAttentionImpl(AttentionImpl): "heads in the layer" ) + def supports_quant_query_input(self) -> bool: + return True + def forward( self, layer: torch.nn.Module, diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index ee32f7e2904f..eb9f6a280d8f 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -16,7 +16,6 @@ from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache from flashinfer.prefill import trtllm_batch_context_with_kv_cache from flashinfer.utils import FP4Tensor -from vllm import _custom_ops as ops from vllm.attention.backends.abstract import ( AttentionBackend, AttentionImpl, @@ -828,6 +827,12 @@ class FlashInferImpl(AttentionImpl): and quant_key in (kFp8StaticTensorSym, kNvfp4Quant) ) + def supports_quant_query_input(self) -> bool: + if flashinfer_disable_q_quantization(): + return False + + return self.support_trtllm_attn + def forward( self, layer: torch.nn.Module, @@ -859,6 +864,12 @@ class FlashInferImpl(AttentionImpl): # Profiling run. 
             return output.fill_(0)

+        # Ensure query dtype matches the expected dtype from attention metadata
+        assert attn_metadata.q_data_type == query.dtype, (
+            f"Query dtype mismatch: expected {attn_metadata.q_data_type}, "
+            f"got {query.dtype}"
+        )
+
         if self.bmm1_scale is None:
             self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale

@@ -899,15 +910,6 @@ class FlashInferImpl(AttentionImpl):
         elif output.dtype == FP4_DTYPE:
             self.o_sf_scale = layer._o_scale_float

-        # Insert FP8 quant for query
-        if attn_metadata.q_data_type == FP8_DTYPE:
-            num_tokens, num_heads, head_size = query.shape
-            query, _ = ops.scaled_fp8_quant(
-                query.reshape((num_tokens, num_heads * head_size)).contiguous(),
-                layer._q_scale,
-            )
-            query = query.reshape((num_tokens, num_heads, head_size))
-
         # IMPORTANT!
         # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
         # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index 9746a0eb58bd..b1d34dbfd172 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -32,11 +32,6 @@ from vllm.v1.attention.backends.utils import (
 )
 from vllm.v1.kv_cache_interface import AttentionSpec

-if current_platform.is_cuda_alike():
-    from vllm import _custom_ops as ops
-elif current_platform.is_xpu():
-    from vllm._ipex_ops import ipex_ops as ops
-
 logger = init_logger(__name__)


@@ -210,6 +205,9 @@ class TritonAttentionImpl(AttentionImpl):
     def fused_output_quant_supported(self, quant_key: QuantKey):
         return quant_key == kFp8StaticTensorSym

+    def supports_quant_query_input(self) -> bool:
+        return current_platform.is_cuda()
+
     def __init__(
         self,
         num_heads: int,
@@ -338,19 +336,9 @@ class TritonAttentionImpl(AttentionImpl):
             if key_cache.dtype != self.fp8_dtype:
                 key_cache = key_cache.view(self.fp8_dtype)
                 value_cache = value_cache.view(self.fp8_dtype)
-            num_tokens, num_heads, head_size = query.shape
             assert layer._q_scale_float == 1.0, (
                 "A non 1.0 q_scale is not currently supported."
             )
-            if current_platform.is_cuda():
-                # Skip Q quantization on ROCm and XPU, enable this on cuda
-                # only, since dequantizing back to f32 in the attention kernel
-                # is not supported.
-                query, _ = ops.scaled_fp8_quant(
-                    query.reshape((num_tokens, num_heads * head_size)).contiguous(),
-                    layer._q_scale,
-                )
-                query = query.reshape((num_tokens, num_heads, head_size))

         cu_seqlens_q = attn_metadata.query_start_loc
         seqused_k = attn_metadata.seq_lens
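
Illustrative sketch (not part of the diff): the contract introduced above is that the Attention layer only builds a QuantFP8 query quantizer when the KV cache is FP8 and the backend impl reports supports_quant_query_input(), and otherwise leaves the query unquantized for the backend to handle. The standalone toy below mirrors that wiring; FakeImpl, FakeLayer, and the placeholder string for the quantizer are hypothetical names used only for illustration, not vLLM classes.

# Toy sketch of the layer/backend contract (hypothetical names, not vLLM code).
class FakeImpl:
    """Stands in for an AttentionImpl that accepts pre-quantized queries."""

    def supports_quant_query_input(self) -> bool:
        return True


class FakeLayer:
    """Mirrors the Attention.__init__ wiring shown in the layer.py hunks."""

    def __init__(self, impl, kv_cache_dtype: str):
        self.impl = impl
        # Only create a query quantizer when the KV cache is FP8 and the
        # impl accepts pre-quantized queries (string stands in for QuantFP8).
        self.query_quant = (
            "QuantFP8(static=True, group_shape=PER_TENSOR)"
            if kv_cache_dtype.startswith("fp8")
            and impl.supports_quant_query_input()
            else None
        )


print(FakeLayer(FakeImpl(), "fp8").query_quant is not None)   # True
print(FakeLayer(FakeImpl(), "auto").query_quant is not None)  # False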