From b82f4307c9b9910b77cdb5043fc0a81ce4c459a1 Mon Sep 17 00:00:00 2001
From: elvischenv <219235043+elvischenv@users.noreply.github.com>
Date: Thu, 9 Oct 2025 03:54:48 +0800
Subject: [PATCH] [Bugfix][Flashinfer] fix VLLM_USE_TRTLLM_ATTENTION issue for
 models with diff hyperparameters (#25924)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
---
 vllm/v1/attention/backends/flashinfer.py | 45 ++++++++++++++----------
 vllm/v1/attention/backends/utils.py      | 31 +++++++---------
 2 files changed, 40 insertions(+), 36 deletions(-)

diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 38cf0ca567331..55186e2938c3d 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -37,7 +37,6 @@ from vllm.utils import cdiv, is_pin_memory_available
 from vllm.utils.flashinfer import (
     can_use_trtllm_attention,
     flashinfer_disable_q_quantization,
-    supports_trtllm_attention,
     use_trtllm_attention,
 )
 from vllm.v1.attention.backends.utils import (
@@ -323,15 +322,13 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         # VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is set to 1. Otherwise, try to
         # use fp8 q if kv cache is fp8, and will fall back to model dtype
         # if TRTLLM attention kernel is not used when building attn metadata
-        if supports_trtllm_attention() and not flashinfer_disable_q_quantization():
+        can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
+        if can_use_trtllm and not flashinfer_disable_q_quantization():
             self.q_data_type = self.kv_cache_dtype
         else:
             self.q_data_type = self.model_config.dtype
 
-        supports_spec_as_decode = can_use_trtllm_attention(
-            self.num_qo_heads, self.num_kv_heads
-        )
-        self._init_reorder_batch_threshold(1, supports_spec_as_decode)
+        self._init_reorder_batch_threshold(1, supports_spec_as_decode=can_use_trtllm)
 
         self._cascade_wrapper = None  # Wrapper for cascade attention
 
@@ -344,7 +341,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         self.window_left = self.global_hyperparameters.window_left
         self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap
         self.has_sinks = self.global_hyperparameters.has_sinks
-        if self.has_sinks and not supports_trtllm_attention():
+        if self.has_sinks and not can_use_trtllm:
             raise NotImplementedError(
                 "FlashInfer backend currently does not support attention "
                 "sinks, please use trtllm on blackwell or flash attention on "
@@ -548,16 +545,30 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             has_sinks=self.has_sinks,
             has_spec=uses_spec_reorder,
         )
-        if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm):
-            raise NotImplementedError(
-                "FlashInfer backend currently does not support attention "
-                "sinks, please use trtllm on blackwell or flash attention on "
-                "earlier GPUs."
+
+        if not (prefill_use_trtllm and decode_use_trtllm):
+            if self.has_sinks:
+                raise NotImplementedError(
+                    "FlashInfer backend currently does not support attention "
+                    "sinks, please use trtllm on blackwell or flash attention "
+                    "on earlier GPUs."
+                )
+
+            if not self.global_hyperparameters.has_same_window_lefts:
+                raise ValueError(
+                    "Window left is not the same for all layers. "
+                    "One potential fix is to set disable_sliding_window=True"
+                )
+
+            assert self.global_hyperparameters.has_same_all_params, (
+                "FlashInfer backend currently only supports models in which "
+                "all layers share the same values for the following "
+                "hyperparameters: `window_left`, `logits_soft_cap`, "
+                "`sm_scale`."
             )
 
-        # If TRTLLM attention is not used, the q quantization is not supported.
-        # Fall back to use model dtype.
-        if not (prefill_use_trtllm and decode_use_trtllm):
+            # The q quantization is not supported for non-trtllm attention,
+            # fall back to model dtype.
             self.q_data_type = self.model_config.dtype
 
         attn_metadata = FlashInferMetadata(
@@ -772,9 +783,7 @@ class FlashInferImpl(AttentionImpl):
             )
         self.sinks = sinks
 
-        self.support_trtllm_attn = (
-            supports_trtllm_attention() and num_heads % num_kv_heads == 0
-        )
+        self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads)
         self.bmm1_scale: float | None = None
         self.bmm2_scale: float | None = None
         self.o_sf_scale: float | None = None
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index add2c3cb8d593..3b71a6505b30b 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -4,7 +4,7 @@ import abc
 import enum
 import functools
 from abc import abstractmethod
-from dataclasses import dataclass, fields, make_dataclass
+from dataclasses import dataclass, field, fields, make_dataclass
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -233,7 +233,7 @@ class AttentionCGSupport(enum.Enum):
     """Cudagraph always supported; supports mixed-prefill-decode"""
     UNIFORM_BATCH = 2
     """Cudagraph supported for batches the only contain query lengths that are
-    the same, this can be used for spec-decode 
+    the same, this can be used for spec-decode
     i.e. "decodes" are 1 + num_speculative_tokens"""
     UNIFORM_SINGLE_TOKEN_DECODE = 1
     """Cudagraph supported for batches the only contain query_len==1 decodes"""
@@ -395,6 +395,9 @@ class PerLayerParameters:
     logits_soft_cap: Optional[float]
     sm_scale: float
     has_sinks: bool = False
+    # has same params for all layers
+    has_same_window_lefts: Optional[bool] = field(default=None, compare=False)
+    has_same_all_params: Optional[bool] = field(default=None, compare=False)
 
 
 def get_per_layer_parameters(
@@ -446,20 +449,12 @@ def infer_global_hyperparameters(
     param_sets = list(per_layer_params.values())
     global_params = param_sets[0]
 
-    # trtllm attention doesn't need global hyper params so disable the check
-    if not envs.VLLM_USE_TRTLLM_ATTENTION:
-        for params in param_sets:
-            if params.window_left != global_params.window_left:
-                raise ValueError(
-                    "Window left is not the same for all layers. "
-                    "One potential fix is to set disable_sliding_window=True"
-                )
-            assert params == global_params, (
-                "FlashInfer backend currently only supports models in which all"
-                "layers share the same values "
-                "for the following hyperparameters:"
-                "`window_left`, `logits_soft_cap`, `sm_scale`."
-            )
+    global_params.has_same_window_lefts = all(
+        params.window_left == global_params.window_left for params in param_sets
+    )
+    global_params.has_same_all_params = all(
+        params == global_params for params in param_sets
+    )
 
     return global_params
 
@@ -925,8 +920,8 @@ def create_fast_prefill_custom_backend(
     ):
         def __init__(self, metadata, common_attn_metadata):
            # Shallow copy all fields in metadata cls
-            for field in fields(metadata.__class__):
-                setattr(self, field.name, getattr(metadata, field.name))
+            for _field in fields(metadata.__class__):
+                setattr(self, _field.name, getattr(metadata, _field.name))
 
             # Set additional fields that will be used in model code
             assert (