[NVIDIA] flashinfer TRTLLM attention prefill token limit (#25998)

Signed-off-by: jasonlizhengjian <jason.li@centml.ai>
Signed-off-by: jasonlizhengjian <jasonlizhengjian@gmail.com>
Author: Jason Li
Date: 2025-10-05 16:24:37 -04:00 (committed by GitHub)
Parent: 9c3c21c519
Commit: 6b6e98775f


@@ -283,11 +283,18 @@ def use_trtllm_attention(
     if force_use_trtllm is None:
         # Environment variable not set - use auto-detection
-        use_trtllm = (
-            num_tokens <= 256 and max_seq_len <= 131072 and kv_cache_dtype == "auto"
-        )
-        if use_trtllm:
-            logger.warning_once("Using TRTLLM attention (auto-detected).")
+        if is_prefill:
+            # Prefill auto-detection
+            use_trtllm = max_seq_len <= 131072 and kv_cache_dtype == "auto"
+            if use_trtllm:
+                logger.warning_once("Using TRTLLM prefill attention (auto-detected).")
+        else:
+            # Decode auto-detection
+            use_trtllm = (
+                num_tokens <= 256 and max_seq_len <= 131072 and kv_cache_dtype == "auto"
+            )
+            if use_trtllm:
+                logger.warning_once("Using TRTLLM decode attention (auto-detected).")
         return use_trtllm

     # Environment variable is set to 1 - respect it
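
For context, the sketch below reproduces the updated auto-detection in a self-contained form: prefill batches are no longer gated on the 256-token cap, while decode batches still are. This is a minimal approximation, not the patched function itself; the real use_trtllm_attention takes additional parameters, and force_use_trtllm is derived from an environment variable per the comments in the diff. The name use_trtllm_attention_sketch and the bare logging setup are illustrative.

import logging
from typing import Optional

logger = logging.getLogger(__name__)

def use_trtllm_attention_sketch(
    num_tokens: int,
    max_seq_len: int,
    kv_cache_dtype: str,
    is_prefill: bool,
    force_use_trtllm: Optional[bool] = None,
) -> bool:
    # Hypothetical, simplified stand-in for the patched function.
    if force_use_trtllm is None:
        if is_prefill:
            # Prefill: the num_tokens <= 256 cap no longer applies.
            use_trtllm = max_seq_len <= 131072 and kv_cache_dtype == "auto"
        else:
            # Decode: still capped at 256 tokens per batch.
            use_trtllm = (
                num_tokens <= 256
                and max_seq_len <= 131072
                and kv_cache_dtype == "auto"
            )
        if use_trtllm:
            logger.warning("Using TRTLLM attention (auto-detected).")
        return use_trtllm
    # Environment override was set - respect it.
    return force_use_trtllm

# A long prefill batch now auto-selects TRTLLM attention...
assert use_trtllm_attention_sketch(4096, 8192, "auto", is_prefill=True)
# ...while the same batch during decode still falls back.
assert not use_trtllm_attention_sketch(4096, 8192, "auto", is_prefill=False)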