Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-14 09:05:28 +08:00)
[torch.compile] Make Query Quantization Fusable (#24914)
Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
parent: 6c340da4df
commit: 69a8c8e99a
@@ -31,6 +31,14 @@ class AttentionBackend(ABC):
     # makes sure the output tensor is allocated inside the cudagraph.
     accept_output_buffer: bool = False

+    # Whether this backend supports receiving pre-quantized query input.
+    # If True, the attention layer will handle query quantization instead
+    # of the backend, allowing torch.compile to fuse quantization with
+    # previous operations.
+    # Needs to be worked through for all backends
+    # https://github.com/vllm-project/vllm/issues/25584
+    supports_quant_query_input: bool = False
+
     @staticmethod
     @abstractmethod
     def get_name() -> str:
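The new flag is opt-in per backend. A minimal sketch of how a backend would declare support; the class name and the get_name value are illustrative only, and a real backend must still implement the remaining abstract methods of AttentionBackend (get_impl_cls, get_metadata_cls, and so on):

from vllm.attention.backends.abstract import AttentionBackend


class FusedQueryQuantBackend(AttentionBackend):
    # Output buffer is allocated by the caller (inside the cudagraph).
    accept_output_buffer: bool = True
    # The attention layer quantizes the query before dispatching here,
    # so torch.compile can fuse that quantization with preceding ops.
    supports_quant_query_input: bool = True

    @staticmethod
    def get_name() -> str:
        return "FUSED_QUERY_QUANT_EXAMPLE"  # illustrative name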
@@ -22,7 +22,10 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape)
 from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.platforms import _Backend, current_platform
 from vllm.utils import GiB_bytes, direct_register_custom_op
@@ -247,6 +250,13 @@ class Attention(nn.Module, AttentionLayerBase):
                 "This may be caused by insufficient memory to allocate "
                 "kv cache.") from e

+        # for attn backends supporting query quantization
+        self.query_quant = None
+        if self.kv_cache_dtype.startswith(
+                "fp8") and self.attn_backend.supports_quant_query_input:
+            self.query_quant = QuantFP8(static=True,
+                                        group_shape=GroupShape.PER_TENSOR)
+
     def forward(
         self,
         query: torch.Tensor,
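For context, a standalone sketch of the per-tensor static FP8 quantizer the layer now builds. The constructor and call signature mirror the lines above; the tensor shape, dtype, scale value, and CUDA device are assumptions for illustration:

import torch

from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

# Static per-tensor FP8 quantizer, as constructed in Attention.__init__.
query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)

# Example query of shape (num_tokens, num_heads * head_size); shapes assumed.
query = torch.randn(8, 32 * 128, dtype=torch.bfloat16, device="cuda")
q_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")  # placeholder scale

# Returns the FP8 query together with the scale that was applied.
query_fp8, _ = query_quant(query, q_scale)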
@@ -270,11 +280,22 @@ class Attention(nn.Module, AttentionLayerBase):
             attn_metadata = get_forward_context().attn_metadata
             if attn_metadata.enable_kv_scales_calculation:
                 self.calc_kv_scales(query, key, value)
+
+        output_dtype = query.dtype
+        if self.query_quant is not None:
+            # quantizing with a simple torch operation enables
+            # torch.compile to fuse this into previous ops
+            # which reduces overheads during decoding.
+            # Otherwise queries are quantized using custom ops
+            # which causes decoding overheads
+            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
+            query, _ = self.query_quant(query, self._q_scale)
+
         if self.use_output:
             output_shape = (output_shape
                             if output_shape is not None else query.shape)
             output = torch.zeros(output_shape,
-                                 dtype=query.dtype,
+                                 dtype=output_dtype,
                                  device=query.device)
             hidden_size = output_shape[-1]
             # We skip reshaping query, key and value tensors for the MLA
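The comment block above is the heart of the change: expressing query quantization with plain torch operations lets torch.compile fuse it into the preceding ops, instead of paying a separate custom-op launch on every decode step. A simplified stand-in (not the actual QuantFP8 implementation) that shows the fusable pattern:

import torch


def quantize_query_fp8(query: torch.Tensor, q_scale: torch.Tensor) -> torch.Tensor:
    # Per-tensor static FP8 quantization written with ordinary torch ops.
    finfo = torch.finfo(torch.float8_e4m3fn)
    q = (query.to(torch.float32) / q_scale).clamp(min=finfo.min, max=finfo.max)
    return q.to(torch.float8_e4m3fn)


@torch.compile
def project_and_quantize(x: torch.Tensor, w_q: torch.Tensor,
                         q_scale: torch.Tensor) -> torch.Tensor:
    # The projection and the quantization land in one compiled region, so the
    # scale/clamp/cast can be fused with the matmul rather than run separately.
    return quantize_query_fp8(x @ w_q, q_scale)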
@ -7,7 +7,6 @@ from typing import Optional
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm import _custom_ops as ops
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||||
AttentionMetadata, AttentionType,
|
AttentionMetadata, AttentionType,
|
||||||
@@ -38,6 +37,7 @@ logger = init_logger(__name__)
 class FlashAttentionBackend(AttentionBackend):

     accept_output_buffer: bool = True
+    supports_quant_query_input: bool = True

     @classmethod
     def get_supported_dtypes(cls) -> list[torch.dtype]:
@@ -506,16 +506,11 @@ class FlashAttentionImpl(AttentionImpl):
         )

         if self.kv_cache_dtype.startswith("fp8"):
+            # queries are quantized in the attention layer
             dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
                 self.kv_cache_dtype)
             key_cache = key_cache.view(dtype)
             value_cache = value_cache.view(dtype)
-            num_tokens, num_heads, head_size = query.shape
-            query, _ = ops.scaled_fp8_quant(
-                query.reshape(
-                    (num_tokens, num_heads * head_size)).contiguous(),
-                layer._q_scale)
-            query = query.reshape((num_tokens, num_heads, head_size))

         if not attn_metadata.use_cascade:
             cu_seqlens_q = attn_metadata.query_start_loc
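With queries now arriving pre-quantized from the attention layer, the backend's remaining FP8 work is just reinterpreting the caches. A small sketch, assuming the "fp8" kv-cache dtype maps to torch.float8_e4m3fn and using made-up cache shapes:

import torch

# KV caches stored as raw bytes are reinterpreted as FP8 without a copy.
key_cache = torch.zeros(4, 16, 8, 64, dtype=torch.uint8)
value_cache = torch.zeros(4, 16, 8, 64, dtype=torch.uint8)

fp8_dtype = torch.float8_e4m3fn  # assumed mapping for kv_cache_dtype="fp8"
key_cache = key_cache.view(fp8_dtype)
value_cache = value_cache.view(fp8_dtype)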