[PERF] Remove TRTLLM Gen attn kernel limitation max_seq_len <=131072 (#28755)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Vadim Gimpelson 2025-11-15 14:13:41 +04:00 committed by GitHub
parent 638e4196d1
commit 173b356abf
2 changed files with 2 additions and 19 deletions


@@ -483,21 +483,6 @@ class VllmConfig:
                     "Overriding cudagraph_mode to PIECEWISE."
                 )
                 self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
-            elif (
-                current_platform.is_cuda()
-                and current_platform.is_device_capability(100)
-                and self.model_config.max_model_len > 131072
-                and not self.model_config.use_mla
-            ):
-                # Refer to vllm/utils/flashinfer.py::use_trtllm_attention()
-                logger.warning_once(
-                    "NVIDIA Blackwell TRTLLM attention cannot support "
-                    "max_model_len >= 131072 (found "
-                    f"{self.model_config.max_model_len}), causing dynamic "
-                    "dispatching that breaks full cudagraphs. "
-                    "Overriding cudagraph_mode to PIECEWISE."
-                )
-                self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
         # disable cudagraph when enforce eager execution
         if self.model_config is not None and self.model_config.enforce_eager:
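For context, the branch deleted above downgraded full CUDA graphs to piecewise whenever a Blackwell GPU (compute capability 10.0) ran a non-MLA model with max_model_len above 131072, because TRTLLM attention would then be skipped at runtime and the resulting dynamic dispatch broke full-cudagraph capture. The following is a minimal standalone sketch of that guard, not the vLLM code itself; FakeModelConfig, pick_cudagraph_mode, and is_blackwell are illustrative stand-ins for the real platform and config objects.

    from dataclasses import dataclass

    # Illustrative stand-in for vLLM's model config; not the real class.
    @dataclass
    class FakeModelConfig:
        max_model_len: int
        use_mla: bool = False

    def pick_cudagraph_mode(
        is_blackwell: bool, model: FakeModelConfig, requested: str = "FULL"
    ) -> str:
        """Sketch of the guard removed by this commit: long-context non-MLA
        models on Blackwell were forced from FULL to PIECEWISE cudagraphs
        because TRTLLM attention rejected sequences longer than 131072."""
        if is_blackwell and model.max_model_len > 131072 and not model.use_mla:
            return "PIECEWISE"
        return requested

    # With the TRTLLM length limit removed, this downgrade is no longer needed.
    print(pick_cudagraph_mode(True, FakeModelConfig(max_model_len=262144)))  # PIECEWISE (old behaviour)
    print(pick_cudagraph_mode(True, FakeModelConfig(max_model_len=65536)))   # FULL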


@@ -319,14 +319,12 @@ def use_trtllm_attention(
         # Environment variable not set - use auto-detection
         if is_prefill:
             # Prefill auto-detection
-            use_trtllm = max_seq_len <= 131072 and kv_cache_dtype == "auto"
+            use_trtllm = kv_cache_dtype == "auto"
             if use_trtllm:
                 logger.warning_once("Using TRTLLM prefill attention (auto-detected).")
         else:
             # Decode auto-detection
-            use_trtllm = (
-                num_tokens <= 256 and max_seq_len <= 131072 and kv_cache_dtype == "auto"
-            )
+            use_trtllm = num_tokens <= 256 and kv_cache_dtype == "auto"
             if use_trtllm:
                 logger.warning_once("Using TRTLLM decode attention (auto-detected).")
     return use_trtllm
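Below is a simplified standalone sketch of the auto-detection that remains after this change. It is not the real use_trtllm_attention(), which takes more arguments and first checks an environment-variable override; the point is only that the max_seq_len <= 131072 cap no longer gates either path, while the decode path keeps its small-batch requirement and both paths still require the auto KV-cache dtype.

    def should_use_trtllm(is_prefill: bool, num_tokens: int, kv_cache_dtype: str) -> bool:
        # Sketch of the post-change heuristic; names here are illustrative.
        if is_prefill:
            # Prefill: only the KV-cache dtype still gates the kernel.
            return kv_cache_dtype == "auto"
        # Decode: the num_tokens <= 256 requirement stays, the length cap is gone.
        return num_tokens <= 256 and kv_cache_dtype == "auto"

    assert should_use_trtllm(True, num_tokens=8192, kv_cache_dtype="auto")
    assert not should_use_trtllm(False, num_tokens=512, kv_cache_dtype="auto")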