[torch.compile] Make Query Quantization Fusable (#24914)

Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
This commit is contained in:
Jonas M. Kübler 2025-09-25 15:25:12 +02:00 committed by GitHub
parent 6c340da4df
commit 69a8c8e99a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 32 additions and 8 deletions

View File

@ -31,6 +31,14 @@ class AttentionBackend(ABC):
# makes sure the output tensor is allocated inside the cudagraph.
accept_output_buffer: bool = False
# Whether this backend supports receiving pre-quantized query input.
# If True, the attention layer will handle query quantization instead
# of the backend, allowing torch.compile to fuse quantization with
# previous operations.
# Needs to be worked through for all backends
# https://github.com/vllm-project/vllm/issues/25584
supports_quant_query_input: bool = False
@staticmethod
@abstractmethod
def get_name() -> str:

View File

@ -22,7 +22,10 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import _Backend, current_platform
from vllm.utils import GiB_bytes, direct_register_custom_op
@ -247,6 +250,13 @@ class Attention(nn.Module, AttentionLayerBase):
"This may be caused by insufficient memory to allocate "
"kv cache.") from e
# for attn backends supporting query quantization
self.query_quant = None
if self.kv_cache_dtype.startswith(
"fp8") and self.attn_backend.supports_quant_query_input:
self.query_quant = QuantFP8(static=True,
group_shape=GroupShape.PER_TENSOR)
def forward(
self,
query: torch.Tensor,
@ -270,11 +280,22 @@ class Attention(nn.Module, AttentionLayerBase):
attn_metadata = get_forward_context().attn_metadata
if attn_metadata.enable_kv_scales_calculation:
self.calc_kv_scales(query, key, value)
output_dtype = query.dtype
if self.query_quant is not None:
# quantizing with a simple torch operation enables
# torch.compile to fuse this into previous ops
# which reduces overheads during decoding.
# Otherwise queries are quantized using custom ops
# which causes decoding overheads
assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
query, _ = self.query_quant(query, self._q_scale)
if self.use_output:
output_shape = (output_shape
if output_shape is not None else query.shape)
output = torch.zeros(output_shape,
dtype=query.dtype,
dtype=output_dtype,
device=query.device)
hidden_size = output_shape[-1]
# We skip reshaping query, key and value tensors for the MLA

View File

@ -7,7 +7,6 @@ from typing import Optional
import numpy as np
import torch
from vllm import _custom_ops as ops
from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType,
@ -38,6 +37,7 @@ logger = init_logger(__name__)
class FlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supports_quant_query_input: bool = True
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
@ -506,16 +506,11 @@ class FlashAttentionImpl(AttentionImpl):
)
if self.kv_cache_dtype.startswith("fp8"):
# queries are quantized in the attention layer
dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
self.kv_cache_dtype)
key_cache = key_cache.view(dtype)
value_cache = value_cache.view(dtype)
num_tokens, num_heads, head_size = query.shape
query, _ = ops.scaled_fp8_quant(
query.reshape(
(num_tokens, num_heads * head_size)).contiguous(),
layer._q_scale)
query = query.reshape((num_tokens, num_heads, head_size))
if not attn_metadata.use_cascade:
cu_seqlens_q = attn_metadata.query_start_loc