[Bugfix] Refactor Flashinfer TRTLLM attention kernel selection logic (#24600)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Author: elvischenv
Date:   2025-09-18 06:36:29 +08:00 (committed by GitHub)
Parent: 9f882d8791
Commit: e67a79db03
3 changed files with 65 additions and 29 deletions


@@ -1223,9 +1223,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_USE_CUDNN_PREFILL":
     lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))),
-    # If set to 1, use the TRTLLM attention backend in flashinfer.
+    # If set to 1/True, use the TRTLLM attention backend in flashinfer.
+    # If set to 0/False, use the default attention backend in flashinfer.
+    # If not set, auto-detect the attention backend in flashinfer.
     "VLLM_USE_TRTLLM_ATTENTION":
-    lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),
+    lambda: (None if "VLLM_USE_TRTLLM_ATTENTION" not in os.environ else
+             os.environ["VLLM_USE_TRTLLM_ATTENTION"].lower() in ("1", "true")),
     # If set to 1, when we use fp8 kv, we do not quantize Q to fp8
     "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION":


@@ -154,28 +154,31 @@ def has_nvidia_artifactory() -> bool:
 @functools.cache
-def supports_trtllm_attention() -> tuple[bool, Optional[str]]:
-    """Cache result which only depends on the environment"""
-    # This is a lambda, call it once
-    env_value = envs.VLLM_USE_TRTLLM_ATTENTION
+def supports_trtllm_attention() -> bool:
+    """
+    TRTLLM attention is supported if the platform is SM100 and
+    NVIDIA artifactory is accessible
+    """
     # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
-    if not (current_platform.is_device_capability(100)
-            and has_nvidia_artifactory()):
-        return False, env_value
+    return current_platform.is_device_capability(
+        100) and has_nvidia_artifactory()
+
+
+@functools.cache
+def _force_use_trtllm_attention(env_value: Optional[bool]) -> Optional[bool]:
+    """Cache the env value for VLLM_USE_TRTLLM_ATTENTION"""
     if env_value is not None:
         logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
-        # Environment variable is set - respect it
-        # Making the conditional check for zero because
-        # the path is automatically enabled if the batch size condition
-        # is satisfied.
-        use_trtllm = (env_value == "1")
-        if use_trtllm:
-            logger.info_once("Using TRTLLM attention.")
-        return use_trtllm, env_value
-    return True, None
+    return env_value
+
+
+def force_use_trtllm_attention() -> Optional[bool]:
+    """
+    Return ``None`` if VLLM_USE_TRTLLM_ATTENTION is not set,
+    return ``True`` if TRTLLM attention is forced to be used,
+    return ``False`` if TRTLLM attention is forced to be not used.
+    """
+    return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)


 def use_trtllm_attention(
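In short, the old helper that returned a (supported, env_value) tuple is split in two: a pure capability check and a cached read of the user override. A condensed sketch of how a caller can combine them (simplified, hypothetical wrapper; the import path is assumed from the diff context):

# Import path is an assumption based on the diff context.
from vllm.utils.flashinfer import (force_use_trtllm_attention,
                                   supports_trtllm_attention)


def trtllm_allowed() -> bool:
    # Hypothetical wrapper for illustration only.
    forced = force_use_trtllm_attention()  # None = auto, True/False = override
    if forced is False:
        return False  # explicit opt-out always wins
    # Capability is now independent of the env var:
    # SM100 plus a reachable NVIDIA artifactory.
    return supports_trtllm_attention()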
@@ -185,18 +188,38 @@ def use_trtllm_attention(
     max_seq_len: int,
     kv_cache_dtype: str,
     q_dtype: torch.dtype,
-    is_prefill: bool,
     has_sinks: bool = False,
 ) -> bool:
-    use_trtllm, env_value = supports_trtllm_attention()
-    if not use_trtllm:
+    """Return ``True`` if TRTLLM attention is used."""
+    force_use_trtllm = force_use_trtllm_attention()
+
+    # Environment variable is set to 0 - respect it
+    if force_use_trtllm is not None and not force_use_trtllm:
         return False
+
+    # The platform is not supported
+    if not supports_trtllm_attention():
+        if force_use_trtllm:
+            logger.warning_once(
+                "TRTLLM attention is not supported on this platform, "
+                "but VLLM_USE_TRTLLM_ATTENTION is set to 1")
+        return False
+
+    # The combination of query and key heads is not supported
     if num_qo_heads % num_kv_heads != 0:
+        if force_use_trtllm:
+            logger.warning_once(
+                "TRTLLM attention is not supported for this combination of "
+                "query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1"
+            )
         return False
+
     # Must use TRTLLM attention if query is FP8 quantized
     if q_dtype == current_platform.fp8_dtype():
+        if has_sinks:
+            raise RuntimeError(
+                "TRTLLM FP8-qkv kernel is not supported for attention sinks. "
+                "Use kv_cache_dtype=auto for now.")
         logger.info_once("Using TRTLLM attention (query is quantized).")
         return True
@@ -207,15 +230,17 @@ def use_trtllm_attention(
             "Using TRTLLM attention (required for attention sinks).")
         return True

-    if env_value is None:
+    if force_use_trtllm is None:
         # Environment variable not set - use auto-detection
-        use_trtllm = (num_tokens <= 256 and max_seq_len < 131072
+        use_trtllm = (num_tokens <= 256 and max_seq_len <= 131072
                       and kv_cache_dtype == "auto")
         if use_trtllm:
             logger.warning_once("Using TRTLLM attention (auto-detected).")
         return use_trtllm

     # Environment variable is set to 1 - respect it
+    logger.info_once(
+        "Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)")
     return True
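Taken together, the hunks above give the kernel selection a clear precedence. A simplified sketch of that order (illustrative only; it drops the warnings and the FP8-with-sinks error, and takes the FP8 dtype as a parameter instead of querying current_platform; the import path is assumed):

# Import path is an assumption based on the diff context.
from vllm.utils.flashinfer import (force_use_trtllm_attention,
                                   supports_trtllm_attention)


def select_trtllm(num_qo_heads, num_kv_heads, num_tokens, max_seq_len,
                  kv_cache_dtype, q_dtype, fp8_dtype, has_sinks=False) -> bool:
    # Hypothetical summary of use_trtllm_attention(), in priority order.
    forced = force_use_trtllm_attention()   # None / True / False
    if forced is False:
        return False                        # explicit opt-out wins first
    if not supports_trtllm_attention():
        return False                        # not SM100, or artifactory unreachable
    if num_qo_heads % num_kv_heads != 0:
        return False                        # unsupported head layout
    if q_dtype == fp8_dtype:
        return True                         # FP8 query requires the TRTLLM kernel
    if has_sinks:
        return True                         # attention sinks require the TRTLLM kernel
    if forced is None:
        # Auto-detection: small batches with the default KV-cache dtype.
        return (num_tokens <= 256 and max_seq_len <= 131072
                and kv_cache_dtype == "auto")
    return True                             # forced is True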
@@ -367,6 +392,7 @@ __all__ = [
     "has_nvidia_artifactory",
     "supports_trtllm_attention",
     "use_trtllm_attention",
+    "flashinfer_disable_q_quantization",
     "flashinfer_scaled_fp4_mm",
     "flashinfer_scaled_fp8_mm",
 ]


@@ -282,7 +282,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             assert self.kv_cache_spec.dtype == self.model_config.dtype
             self.kv_cache_dtype = self.kv_cache_spec.dtype

-        if supports_trtllm_attention()[0] and \
+        # Use model dtype as q dtype when TRTLLM attn is not supported, or
+        # VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is set to 1. Otherwise, try to
+        # use fp8 q if kv cache is fp8, and will fall back to model dtype
+        # if TRTLLM attention kernel is not used when building attn metadata
+        if supports_trtllm_attention() and \
                 not flashinfer_disable_q_quantization():
             self.q_data_type = self.kv_cache_dtype
         else:
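The new comment amounts to a small decision rule for the initial query dtype, plus a per-batch fallback applied later when the attention metadata is built. A hedged sketch of the init-time rule (hypothetical helper; import path assumed):

# Import path is an assumption based on the diff context.
from vllm.utils.flashinfer import (flashinfer_disable_q_quantization,
                                   supports_trtllm_attention)


def choose_initial_q_dtype(kv_cache_dtype, model_dtype):
    # Optimistically plan to quantize Q to the (possibly FP8) KV-cache dtype
    # when the TRTLLM kernel could run and Q quantization is not disabled;
    # the metadata builder may still fall back to the model dtype per batch.
    if supports_trtllm_attention() and not flashinfer_disable_q_quantization():
        return kv_cache_dtype
    return model_dtype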
@@ -298,7 +302,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         self.window_left = self.global_hyperparameters.window_left
         self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap
         self.has_sinks = self.global_hyperparameters.has_sinks
-        if self.has_sinks and not supports_trtllm_attention()[0]:
+        if self.has_sinks and not supports_trtllm_attention():
             raise NotImplementedError(
                 "FlashInfer backend currently does not support attention "
                 "sinks, please use trtllm on blackwell or flash attention on "
@@ -477,14 +481,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             paged_kv_last_page_len_np,
         )

-        # Check if any layer uses sinks (requires TRTLLM attention)
         prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                   self.num_kv_heads,
                                                   num_prefill_tokens,
                                                   max_seq_len,
                                                   self.cache_dtype,
                                                   self.q_data_type,
-                                                  is_prefill=True,
                                                   has_sinks=self.has_sinks)
         decode_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                  self.num_kv_heads,
@@ -492,13 +494,18 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                                                  max_seq_len,
                                                  self.cache_dtype,
                                                  self.q_data_type,
-                                                 is_prefill=False,
                                                  has_sinks=self.has_sinks)
         if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm):
             raise NotImplementedError(
                 "FlashInfer backend currently does not support attention "
                 "sinks, please use trtllm on blackwell or flash attention on "
                 "earlier GPUs.")
+
+        # If TRTLLM attention is not used, the q quantization is not supported.
+        # Fall back to use model dtype.
+        if not (prefill_use_trtllm and decode_use_trtllm):
+            self.q_data_type = self.model_config.dtype
+
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
             q_data_type=self.q_data_type,
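For context, the per-batch fallback above reduces to one condition: if either the prefill or the decode path ends up not using TRTLLM, the quantized query dtype planned at init time is dropped for that batch, since q quantization is only supported on the TRTLLM path. A minimal sketch (hypothetical helper, not the exact vLLM code):

def effective_q_dtype(prefill_use_trtllm, decode_use_trtllm,
                      planned_q_dtype, model_dtype):
    # Keep the (possibly FP8) planned dtype only when both paths use TRTLLM;
    # otherwise fall back to the model dtype for this batch.
    if prefill_use_trtllm and decode_use_trtllm:
        return planned_q_dtype
    return model_dtype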