[Bugfix][Flashinfer] fix VLLM_USE_TRTLLM_ATTENTION issue for models with diff hyperparameters (#25924)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
elvischenv 2025-10-09 03:54:48 +08:00 committed by GitHub
parent 76879cc160
commit b82f4307c9
2 changed files with 40 additions and 36 deletions
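In short, the per-layer hyperparameter checks move out of infer_global_hyperparameters (where they were skipped only when the VLLM_USE_TRTLLM_ATTENTION environment variable was set) and into FlashInferMetadataBuilder.build, where they are enforced only when TRTLLM attention is not actually used. A rough, standalone contrast of the two gating conditions (simplified sketch; names mirror the diff below, not the actual vLLM code paths):

# Rough contrast of when the per-layer hyperparameter checks are enforced.
# Simplified, standalone sketch -- not the actual vLLM code.
def checks_enforced_before(vllm_use_trtllm_attention_env: bool) -> bool:
    # Old behavior: checks ran unless the env var was set, regardless of
    # whether TRTLLM attention was actually usable for the model.
    return not vllm_use_trtllm_attention_env

def checks_enforced_after(prefill_use_trtllm: bool, decode_use_trtllm: bool) -> bool:
    # New behavior: checks run only when TRTLLM attention is not actually
    # used for the batch being built.
    return not (prefill_use_trtllm and decode_use_trtllm)

print(checks_enforced_before(False))        # True
print(checks_enforced_after(True, True))    # False: TRTLLM path skips them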

View File

@@ -37,7 +37,6 @@ from vllm.utils import cdiv, is_pin_memory_available
from vllm.utils.flashinfer import (
can_use_trtllm_attention,
flashinfer_disable_q_quantization,
supports_trtllm_attention,
use_trtllm_attention,
)
from vllm.v1.attention.backends.utils import (
@@ -323,15 +322,13 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is set to 1. Otherwise, try to
# use fp8 q if kv cache is fp8, and will fall back to model dtype
# if TRTLLM attention kernel is not used when building attn metadata
if supports_trtllm_attention() and not flashinfer_disable_q_quantization():
can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
if can_use_trtllm and not flashinfer_disable_q_quantization():
self.q_data_type = self.kv_cache_dtype
else:
self.q_data_type = self.model_config.dtype
supports_spec_as_decode = can_use_trtllm_attention(
self.num_qo_heads, self.num_kv_heads
)
self._init_reorder_batch_threshold(1, supports_spec_as_decode)
self._init_reorder_batch_threshold(1, supports_spec_as_decode=can_use_trtllm)
self._cascade_wrapper = None # Wrapper for cascade attention
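The q dtype selection and the spec-decode reorder threshold now key off the same per-model flag instead of the platform-only supports_trtllm_attention(). A small illustration of why that matters, assuming can_use_trtllm_attention reduces to the platform check plus a head-divisibility test (an assumption, consistent with the expression it replaces in the FlashInferImpl hunk further down):

# Illustrative helper only; the real can_use_trtllm_attention lives in
# vllm.utils.flashinfer and is assumed (not verified here) to combine the
# platform check with a num_qo_heads % num_kv_heads == 0 test.
def can_use_trtllm_sketch(num_qo_heads: int, num_kv_heads: int,
                          platform_supports_trtllm: bool = True) -> bool:
    return platform_supports_trtllm and num_qo_heads % num_kv_heads == 0

# A platform-level "supported" answer can still be False for a given model:
print(can_use_trtllm_sketch(32, 8))   # True  -> fp8 q path may be taken
print(can_use_trtllm_sketch(28, 8))   # False -> q stays in model dtype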
@@ -344,7 +341,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.window_left = self.global_hyperparameters.window_left
self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap
self.has_sinks = self.global_hyperparameters.has_sinks
if self.has_sinks and not supports_trtllm_attention():
if self.has_sinks and not can_use_trtllm:
raise NotImplementedError(
"FlashInfer backend currently does not support attention "
"sinks, please use trtllm on blackwell or flash attention on "
@@ -548,16 +545,30 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
has_sinks=self.has_sinks,
has_spec=uses_spec_reorder,
)
if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm):
raise NotImplementedError(
"FlashInfer backend currently does not support attention "
"sinks, please use trtllm on blackwell or flash attention on "
"earlier GPUs."
if not (prefill_use_trtllm and decode_use_trtllm):
if self.has_sinks:
raise NotImplementedError(
"FlashInfer backend currently does not support attention "
"sinks, please use trtllm on blackwell or flash attention "
"on earlier GPUs."
)
if not self.global_hyperparameters.has_same_window_lefts:
raise ValueError(
"Window left is not the same for all layers. "
"One potential fix is to set disable_sliding_window=True"
)
assert self.global_hyperparameters.has_same_all_params, (
"FlashInfer backend currently only supports models in which "
"all layers share the same values for the following "
"hyperparameters: `window_left`, `logits_soft_cap`, "
"`sm_scale`."
)
# If TRTLLM attention is not used, the q quantization is not supported.
# Fall back to use model dtype.
if not (prefill_use_trtllm and decode_use_trtllm):
# The q quantization is not supported for non-trtllm attention,
# fall back to model dtype.
self.q_data_type = self.model_config.dtype
attn_metadata = FlashInferMetadata(
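Net effect of this restructured block: a model whose layers disagree on window_left (for example, hypothetical interleaved sliding-window layers) can still build metadata on the TRTLLM path, and the uniformity error surfaces only when falling back to the plain FlashInfer kernels. A hypothetical, standalone illustration with stand-in types, not the vLLM classes:

# Hypothetical standalone illustration of the deferred uniformity check.
from dataclasses import dataclass

@dataclass
class LayerParams:
    window_left: int

layers = [LayerParams(-1), LayerParams(4096)]       # mixed sliding windows
has_same_window_lefts = all(p.window_left == layers[0].window_left for p in layers)

def build(prefill_use_trtllm: bool, decode_use_trtllm: bool) -> str:
    if not (prefill_use_trtllm and decode_use_trtllm):
        if not has_same_window_lefts:
            raise ValueError("Window left is not the same for all layers. "
                             "One potential fix is to set disable_sliding_window=True")
    return "attn metadata built"

print(build(True, True))     # TRTLLM path: mixed windows are accepted
# build(False, True)         # fallback path: raises ValueError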
@@ -772,9 +783,7 @@ class FlashInferImpl(AttentionImpl):
)
self.sinks = sinks
self.support_trtllm_attn = (
supports_trtllm_attention() and num_heads % num_kv_heads == 0
)
self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads)
self.bmm1_scale: float | None = None
self.bmm2_scale: float | None = None
self.o_sf_scale: float | None = None

View File

@@ -4,7 +4,7 @@ import abc
import enum
import functools
from abc import abstractmethod
from dataclasses import dataclass, fields, make_dataclass
from dataclasses import dataclass, field, fields, make_dataclass
from typing import (
TYPE_CHECKING,
Any,
@@ -233,7 +233,7 @@ class AttentionCGSupport(enum.Enum):
"""Cudagraph always supported; supports mixed-prefill-decode"""
UNIFORM_BATCH = 2
"""Cudagraph supported for batches the only contain query lengths that are
the same, this can be used for spec-decode
i.e. "decodes" are 1 + num_speculative_tokens"""
UNIFORM_SINGLE_TOKEN_DECODE = 1
"""Cudagraph supported for batches the only contain query_len==1 decodes"""
@@ -395,6 +395,9 @@ class PerLayerParameters:
logits_soft_cap: Optional[float]
sm_scale: float
has_sinks: bool = False
# has same params for all layers
has_same_window_lefts: Optional[bool] = field(default=None, compare=False)
has_same_all_params: Optional[bool] = field(default=None, compare=False)
def get_per_layer_parameters(
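The two new flags use field(default=None, compare=False) so they do not take part in the dataclass equality that infer_global_hyperparameters relies on in the next hunk: params == global_params must keep comparing only the real hyperparameters. A minimal standalone demonstration of that dataclasses behavior, using a stand-in class rather than the actual PerLayerParameters:

# Minimal demonstration of compare=False: the bookkeeping flag is excluded
# from the generated __eq__.
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class Params:
    window_left: int
    has_same_all_params: Optional[bool] = field(default=None, compare=False)

a = Params(window_left=-1, has_same_all_params=True)
b = Params(window_left=-1)                 # flag left unset
print(a == b)                              # True: only window_left is compared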
@@ -446,20 +449,12 @@ def infer_global_hyperparameters(
param_sets = list(per_layer_params.values())
global_params = param_sets[0]
# trtllm attention doesn't need global hyper params so disable the check
if not envs.VLLM_USE_TRTLLM_ATTENTION:
for params in param_sets:
if params.window_left != global_params.window_left:
raise ValueError(
"Window left is not the same for all layers. "
"One potential fix is to set disable_sliding_window=True"
)
assert params == global_params, (
"FlashInfer backend currently only supports models in which all"
"layers share the same values "
"for the following hyperparameters:"
"`window_left`, `logits_soft_cap`, `sm_scale`."
)
global_params.has_same_window_lefts = all(
params.window_left == global_params.window_left for params in param_sets
)
global_params.has_same_all_params = all(
params == global_params for params in param_sets
)
return global_params
@@ -925,8 +920,8 @@ def create_fast_prefill_custom_backend(
):
def __init__(self, metadata, common_attn_metadata):
# Shallow copy all fields in metadata cls
for field in fields(metadata.__class__):
setattr(self, field.name, getattr(metadata, field.name))
for _field in fields(metadata.__class__):
setattr(self, _field.name, getattr(metadata, _field.name))
# Set additional fields that will be used in model code
assert (
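The rename from field to _field in the loop above presumably avoids shadowing dataclasses.field, which this file now imports at module level (see the import change at the top of this file's diff). A small standalone sketch of the hazard the rename sidesteps, using hypothetical classes:

# Illustrative only, not the vLLM class: reusing the name `field` as the loop
# variable would hide the dataclasses.field import inside that scope; `_field`
# keeps `field` usable.
from dataclasses import dataclass, field, fields

@dataclass
class Meta:
    x: int = 0
    y: int = field(default=1)

def copy_fields(dst: object, src: Meta) -> None:
    # Shallow-copy every dataclass field from src onto dst.
    for _field in fields(Meta):
        setattr(dst, _field.name, getattr(src, _field.name))

class Holder:
    pass

h = Holder()
copy_fields(h, Meta(x=3, y=4))
print(h.x, h.y)                            # 3 4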