Move query quantization to attention layer for Flashinfer & Triton. (#26534)
Signed-off-by: adabeyta <aabeyta@redhat.com>
Signed-off-by: Adrian Abeyta <aabeyta@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
parent e5b438a247
commit 0a9ef0cfce
@@ -421,7 +421,9 @@ def test_attention_quant_pattern(
     ]
     if any(attn_fusion_supported):
         # Check quantization ops in the graph before and after fusion
-        test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=True)
+        # Note: fully_replaced=False because query quant ops remain in graph.
+        # Only output quant ops are fused into attention.
+        test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False)
 
         # access the underlying `AttnFusionPass` on the `LazyInitPass`
         assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
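The switch to `fully_replaced=False` encodes the new expectation: after the fusion pass, the output-quant call sites should be gone while the per-layer query-quant call sites stay in the graph. A minimal, self-contained sketch of that kind of count-based check over a `torch.fx` graph; `count_call_sites` is a hypothetical helper for illustration, not the test harness's real API:

```python
import torch
from torch import fx


def count_call_sites(gm: fx.GraphModule, target) -> int:
    """Count call_function nodes in an FX graph whose target is `target`."""
    return sum(
        1 for node in gm.graph.nodes
        if node.op == "call_function" and node.target is target
    )


# Shape of the assertion the test now makes (pseudocode over hypothetical graphs):
#   before = count_call_sites(graph_before_fusion, quant_op)
#   after = count_call_sites(graph_after_fusion, quant_op)
#   assert 0 < after < before  # query quant survives, output quant is fused away
```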
@@ -41,14 +41,6 @@ class AttentionBackend(ABC):
     # makes sure the output tensor is allocated inside the cudagraph.
     accept_output_buffer: bool = False
 
-    # Whether this backend supports receiving pre-quantized query input.
-    # If True, the attention layer will handle query quantization instead
-    # of the backend, allowing torch.compile to fuse quantization with
-    # previous operations.
-    # Needs to be worked through for all backends
-    # https://github.com/vllm-project/vllm/issues/25584
-    supports_quant_query_input: bool = False
-
     @staticmethod
     @abstractmethod
     def get_name() -> str:
@@ -199,6 +191,22 @@ class AttentionImpl(ABC, Generic[T]):
         """
         return False
 
+    def supports_quant_query_input(self) -> bool:
+        """
+        Check if this attention implementation supports pre-quantized query input.
+
+        When True, the attention layer will quantize queries before passing them
+        to this backend, allowing torch.compile to fuse the quantization with
+        previous operations. This is typically supported when using FP8 KV cache
+        with compatible attention kernels (e.g., TRT-LLM).
+        TODO add support to more backends:
+        https://github.com/vllm-project/vllm/issues/25584
+
+        Returns:
+            bool: True if the implementation can accept pre-quantized queries.
+        """
+        return False
+
 
 class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
     @abstractmethod
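The hook replaces the `supports_quant_query_input` class attribute removed from `AttentionBackend` above: the decision now lives on the implementation instance, where it can depend on runtime state. A self-contained sketch of the contract using plain Python stand-ins (`ImplSketch` and `TrtllmLikeImplSketch` are hypothetical, not real vLLM classes): the base hook defaults to False, a capable implementation overrides it, and the caller keys off the method.

```python
class ImplSketch:
    # mirrors AttentionImpl's default
    def supports_quant_query_input(self) -> bool:
        return False


class TrtllmLikeImplSketch(ImplSketch):
    # mirrors e.g. FlashAttentionImpl, which opts in unconditionally
    def supports_quant_query_input(self) -> bool:
        return True


def layer_should_prequantize(impl: ImplSketch, kv_cache_dtype: str) -> bool:
    # Same shape as the check added to Attention.__init__ below.
    return kv_cache_dtype.startswith("fp8") and impl.supports_quant_query_input()


assert layer_should_prequantize(TrtllmLikeImplSketch(), "fp8_e4m3")
assert not layer_should_prequantize(ImplSketch(), "fp8_e4m3")
```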
@@ -36,6 +36,7 @@ from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.platforms import current_platform
 from vllm.utils import GiB_bytes, direct_register_custom_op
 
+FP8_DTYPE = current_platform.fp8_dtype()
 logger = init_logger(__name__)
 USE_XFORMERS_OPS = None
 
@@ -304,7 +305,7 @@ class Attention(nn.Module, AttentionLayerBase):
         self.query_quant = None
         if (
             self.kv_cache_dtype.startswith("fp8")
-            and self.attn_backend.supports_quant_query_input
+            and self.impl.supports_quant_query_input()
         ):
             self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
 
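`QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)` amounts to scaling the whole tensor by one precomputed scale and saturating into FP8. A rough sketch of those numerics in plain torch (a simplified stand-in, not the vLLM kernel), contrasting the static case with what a dynamic scale would look like:

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn


def dynamic_scale(x: torch.Tensor) -> torch.Tensor:
    # Dynamic quantization would derive the scale from the data itself...
    return x.abs().amax().float() / FP8_MAX


def static_per_tensor_quant(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # ...but static=True means the scale (layer._q_scale) is precomputed, so
    # quantization is just a scale, clamp, and cast over the whole tensor.
    return (x.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)


q = static_per_tensor_quant(torch.randn(4, 8, 64, dtype=torch.bfloat16), torch.tensor(1.0))
assert q.dtype == torch.float8_e4m3fn
```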
@@ -329,7 +330,6 @@ class Attention(nn.Module, AttentionLayerBase):
         """
         if self.calculate_kv_scales:
             torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
-
         output_dtype = query.dtype
         if self.query_quant is not None:
             # quantizing with a simple torch operation enables
@@ -338,6 +338,9 @@ class Attention(nn.Module, AttentionLayerBase):
             # Otherwise queries are quantized using custom ops
             # which causes decoding overheads
             assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
-            query, _ = self.query_quant(query, self._q_scale)
+
+            # check if query quantization is supported
+            if self.impl.supports_quant_query_input():
+                query, _ = self.query_quant(query, self._q_scale)
 
         if self.use_output:
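The comments in this block (quantize with plain torch ops rather than custom ops that add decode overhead) are the motivation for doing the quantization here: expressed as ordinary tensor ops, it sits inside the compiled region and can be fused with the preceding projection. A toy sketch of that shape, under the assumption that the region is wrapped by `torch.compile` (hypothetical function, not the vLLM graph):

```python
import torch


def proj_then_quant(x: torch.Tensor, w: torch.Tensor, q_scale: torch.Tensor) -> torch.Tensor:
    q = x @ w  # stands in for the preceding projection
    fmax = torch.finfo(torch.float8_e4m3fn).max
    # saturate to the e4m3 range, then cast; both are plain, fusable tensor ops
    return (q / q_scale).clamp(-fmax, fmax).to(torch.float8_e4m3fn)


out = proj_then_quant(torch.randn(8, 64), torch.randn(64, 64), torch.tensor(1.0))
assert out.dtype == torch.float8_e4m3fn
# Inside vLLM this region runs under the compiler, roughly: torch.compile(proj_then_quant)
```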
@@ -49,7 +49,6 @@ logger = init_logger(__name__)
 
 class FlashAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
-    supports_quant_query_input: bool = True
 
     @classmethod
     def get_supported_dtypes(cls) -> list[torch.dtype]:
@@ -494,6 +493,9 @@ class FlashAttentionImpl(AttentionImpl):
                 "heads in the layer"
             )
 
+    def supports_quant_query_input(self) -> bool:
+        return True
+
     def forward(
         self,
         layer: torch.nn.Module,
@@ -16,7 +16,6 @@ from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
 from flashinfer.prefill import trtllm_batch_context_with_kv_cache
 from flashinfer.utils import FP4Tensor
 
-from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
@@ -828,6 +827,12 @@ class FlashInferImpl(AttentionImpl):
             and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)
         )
 
+    def supports_quant_query_input(self) -> bool:
+        if flashinfer_disable_q_quantization():
+            return False
+
+        return self.support_trtllm_attn
+
     def forward(
         self,
         layer: torch.nn.Module,
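FlashInfer only accepts a pre-quantized query when the TRT-LLM attention path is in use and the user has not opted out via `flashinfer_disable_q_quantization()`. A small sketch of that gate; the environment-variable name below is an assumption for illustration only, the real check goes through vLLM's own env helper:

```python
import os
from functools import cache


@cache
def q_quant_disabled() -> bool:
    # Assumed opt-out knob for illustration; vLLM reads its own env helper.
    return os.environ.get("VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", "0") == "1"


def supports_quant_query_input(support_trtllm_attn: bool) -> bool:
    if q_quant_disabled():
        return False
    return support_trtllm_attn
```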
@@ -859,6 +864,12 @@ class FlashInferImpl(AttentionImpl):
             # Profiling run.
             return output.fill_(0)
 
+        # Ensure query dtype matches the expected dtype from attention metadata
+        assert attn_metadata.q_data_type == query.dtype, (
+            f"Query dtype mismatch: expected {attn_metadata.q_data_type}, "
+            f"got {query.dtype}"
+        )
+
         if self.bmm1_scale is None:
             self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
 
@@ -899,15 +910,6 @@ class FlashInferImpl(AttentionImpl):
         elif output.dtype == FP4_DTYPE:
             self.o_sf_scale = layer._o_scale_float
 
-        # Insert FP8 quant for query
-        if attn_metadata.q_data_type == FP8_DTYPE:
-            num_tokens, num_heads, head_size = query.shape
-            query, _ = ops.scaled_fp8_quant(
-                query.reshape((num_tokens, num_heads * head_size)).contiguous(),
-                layer._q_scale,
-            )
-            query = query.reshape((num_tokens, num_heads, head_size))
-
         # IMPORTANT!
         # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
         # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
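For reference, the removed block flattened the heads, quantized with the static query scale through the `ops.scaled_fp8_quant` custom op, and restored the `(num_tokens, num_heads, head_size)` shape. A plain-torch sketch of roughly that transformation (not the custom kernel), which now happens once in the `Attention` layer before dispatch:

```python
import torch


def quantize_query(query: torch.Tensor, q_scale: torch.Tensor) -> torch.Tensor:
    num_tokens, num_heads, head_size = query.shape
    fmax = torch.finfo(torch.float8_e4m3fn).max
    flat = query.reshape(num_tokens, num_heads * head_size).float()
    q = (flat / q_scale).clamp(-fmax, fmax).to(torch.float8_e4m3fn)
    # restore the per-head layout expected by the attention kernel
    return q.reshape(num_tokens, num_heads, head_size)


q = quantize_query(torch.randn(2, 8, 128, dtype=torch.bfloat16), torch.tensor(1.0))
assert q.dtype == torch.float8_e4m3fn and q.shape == (2, 8, 128)
```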
@@ -32,11 +32,6 @@ from vllm.v1.attention.backends.utils import (
 )
 from vllm.v1.kv_cache_interface import AttentionSpec
 
-if current_platform.is_cuda_alike():
-    from vllm import _custom_ops as ops
-elif current_platform.is_xpu():
-    from vllm._ipex_ops import ipex_ops as ops
-
 logger = init_logger(__name__)
 
 
@@ -210,6 +205,9 @@ class TritonAttentionImpl(AttentionImpl):
     def fused_output_quant_supported(self, quant_key: QuantKey):
         return quant_key == kFp8StaticTensorSym
 
+    def supports_quant_query_input(self) -> bool:
+        return current_platform.is_cuda()
+
     def __init__(
         self,
         num_heads: int,
@@ -338,19 +336,9 @@ class TritonAttentionImpl(AttentionImpl):
             if key_cache.dtype != self.fp8_dtype:
                 key_cache = key_cache.view(self.fp8_dtype)
                 value_cache = value_cache.view(self.fp8_dtype)
-            num_tokens, num_heads, head_size = query.shape
             assert layer._q_scale_float == 1.0, (
                 "A non 1.0 q_scale is not currently supported."
            )
-            if current_platform.is_cuda():
-                # Skip Q quantization on ROCm and XPU, enable this on cuda
-                # only, since dequantizing back to f32 in the attention kernel
-                # is not supported.
-                query, _ = ops.scaled_fp8_quant(
-                    query.reshape((num_tokens, num_heads * head_size)).contiguous(),
-                    layer._q_scale,
-                )
-                query = query.reshape((num_tokens, num_heads, head_size))
 
         cu_seqlens_q = attn_metadata.query_start_loc
         seqused_k = attn_metadata.seq_lens
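With the CUDA-only `scaled_fp8_quant` call hoisted into the layer (and gated by `supports_quant_query_input()` returning `current_platform.is_cuda()`), the remaining Triton-side FP8 handling just reinterprets the KV cache storage as FP8 without copying. A small sketch of what `view(dtype)` does here, assuming e4m3 as the FP8 format:

```python
import torch

# view() reinterprets the same 1-byte-per-element buffer as FP8; no data is moved.
cache = torch.zeros(2, 16, dtype=torch.uint8)
fp8_cache = cache.view(torch.float8_e4m3fn)
assert fp8_cache.dtype == torch.float8_e4m3fn
assert fp8_cache.data_ptr() == cache.data_ptr()  # shared storage, zero-copy
```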