mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 08:04:58 +08:00
[torch.compile] Make Query Quantization Fusable (#24914)
Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
This commit is contained in:
parent
6c340da4df
commit
69a8c8e99a
@ -31,6 +31,14 @@ class AttentionBackend(ABC):
|
||||
# makes sure the output tensor is allocated inside the cudagraph.
|
||||
accept_output_buffer: bool = False
|
||||
|
||||
# Whether this backend supports receiving pre-quantized query input.
|
||||
# If True, the attention layer will handle query quantization instead
|
||||
# of the backend, allowing torch.compile to fuse quantization with
|
||||
# previous operations.
|
||||
# Needs to be worked through for all backends
|
||||
# https://github.com/vllm-project/vllm/issues/25584
|
||||
supports_quant_query_input: bool = False
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_name() -> str:
|
||||
|
||||
@ -22,7 +22,10 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
from vllm.model_executor.models.vision import get_vit_attn_backend
|
||||
from vllm.platforms import _Backend, current_platform
|
||||
from vllm.utils import GiB_bytes, direct_register_custom_op
|
||||
@ -247,6 +250,13 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
"This may be caused by insufficient memory to allocate "
|
||||
"kv cache.") from e
|
||||
|
||||
# for attn backends supporting query quantization
|
||||
self.query_quant = None
|
||||
if self.kv_cache_dtype.startswith(
|
||||
"fp8") and self.attn_backend.supports_quant_query_input:
|
||||
self.query_quant = QuantFP8(static=True,
|
||||
group_shape=GroupShape.PER_TENSOR)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
@ -270,11 +280,22 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
if attn_metadata.enable_kv_scales_calculation:
|
||||
self.calc_kv_scales(query, key, value)
|
||||
|
||||
output_dtype = query.dtype
|
||||
if self.query_quant is not None:
|
||||
# quantizing with a simple torch operation enables
|
||||
# torch.compile to fuse this into previous ops
|
||||
# which reduces overheads during decoding.
|
||||
# Otherwise queries are quantized using custom ops
|
||||
# which causes decoding overheads
|
||||
assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
|
||||
query, _ = self.query_quant(query, self._q_scale)
|
||||
|
||||
if self.use_output:
|
||||
output_shape = (output_shape
|
||||
if output_shape is not None else query.shape)
|
||||
output = torch.zeros(output_shape,
|
||||
dtype=query.dtype,
|
||||
dtype=output_dtype,
|
||||
device=query.device)
|
||||
hidden_size = output_shape[-1]
|
||||
# We skip reshaping query, key and value tensors for the MLA
|
||||
|
||||
@ -7,7 +7,6 @@ from typing import Optional
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType,
|
||||
@ -38,6 +37,7 @@ logger = init_logger(__name__)
|
||||
class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
accept_output_buffer: bool = True
|
||||
supports_quant_query_input: bool = True
|
||||
|
||||
@classmethod
|
||||
def get_supported_dtypes(cls) -> list[torch.dtype]:
|
||||
@ -506,16 +506,11 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
# queries are quantized in the attention layer
|
||||
dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
|
||||
self.kv_cache_dtype)
|
||||
key_cache = key_cache.view(dtype)
|
||||
value_cache = value_cache.view(dtype)
|
||||
num_tokens, num_heads, head_size = query.shape
|
||||
query, _ = ops.scaled_fp8_quant(
|
||||
query.reshape(
|
||||
(num_tokens, num_heads * head_size)).contiguous(),
|
||||
layer._q_scale)
|
||||
query = query.reshape((num_tokens, num_heads, head_size))
|
||||
|
||||
if not attn_metadata.use_cascade:
|
||||
cu_seqlens_q = attn_metadata.query_start_loc
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user