[Bugfix] Refactor Flashinfer TRTLLM attention kernel selection logic (#24600)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
parent 9f882d8791
commit e67a79db03
vllm/envs.py
@@ -1223,9 +1223,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_USE_CUDNN_PREFILL":
     lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))),
 
-    # If set to 1, use the TRTLLM attention backend in flashinfer.
+    # If set to 1/True, use the TRTLLM attention backend in flashinfer.
+    # If set to 0/False, use the default attention backend in flashinfer.
+    # If not set, auto-detect the attention backend in flashinfer.
     "VLLM_USE_TRTLLM_ATTENTION":
-    lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),
+    lambda: (None if "VLLM_USE_TRTLLM_ATTENTION" not in os.environ else
+             os.environ["VLLM_USE_TRTLLM_ATTENTION"].lower() in ("1", "true")),
 
     # If set to 1, when we use fp8 kv, we do not quantize Q to fp8
     "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION":
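Note that the env value is now tri-state rather than a raw string. A minimal standalone sketch (not part of the commit) of the parsing behaviour the new lambda implements; parse_use_trtllm_attention is a hypothetical stand-in for the lambda:

import os
from typing import Optional

def parse_use_trtllm_attention() -> Optional[bool]:
    # Mirrors the new lambda: unset -> None, "1"/"true" (any case) -> True,
    # any other value -> False.
    if "VLLM_USE_TRTLLM_ATTENTION" not in os.environ:
        return None
    return os.environ["VLLM_USE_TRTLLM_ATTENTION"].lower() in ("1", "true")

for value in (None, "1", "True", "0", "false", "off"):
    if value is None:
        os.environ.pop("VLLM_USE_TRTLLM_ATTENTION", None)
    else:
        os.environ["VLLM_USE_TRTLLM_ATTENTION"] = value
    print(repr(value), "->", parse_use_trtllm_attention())
# None -> None, '1'/'True' -> True, '0'/'false'/'off' -> False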
vllm/utils/flashinfer.py
@@ -154,28 +154,31 @@ def has_nvidia_artifactory() -> bool:
 
 
 @functools.cache
-def supports_trtllm_attention() -> tuple[bool, Optional[str]]:
-    """Cache result which only depends on the environment"""
-    # This is a lambda, call it once
-    env_value = envs.VLLM_USE_TRTLLM_ATTENTION
-
+def supports_trtllm_attention() -> bool:
+    """
+    TRTLLM attention is supported if the platform is SM100 and
+    NVIDIA artifactory is accessible
+    """
     # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
-    if not (current_platform.is_device_capability(100)
-            and has_nvidia_artifactory()):
-        return False, env_value
+    return current_platform.is_device_capability(
+        100) and has_nvidia_artifactory()
+
+
+@functools.cache
+def _force_use_trtllm_attention(env_value: Optional[bool]) -> Optional[bool]:
+    """Cache the env value for VLLM_USE_TRTLLM_ATTENTION"""
     if env_value is not None:
         logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
-        # Environment variable is set - respect it
-        # Making the conditional check for zero because
-        # the path is automatically enabled if the batch size condition
-        # is satisfied.
-        use_trtllm = (env_value == "1")
-        if use_trtllm:
-            logger.info_once("Using TRTLLM attention.")
-        return use_trtllm, env_value
+    return env_value
 
-    return True, None
+
+def force_use_trtllm_attention() -> Optional[bool]:
+    """
+    Return ``None`` if VLLM_USE_TRTLLM_ATTENTION is not set,
+    return ``True`` if TRTLLM attention is forced to be used,
+    return ``False`` if TRTLLM attention is forced to be not used.
+    """
+    return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)
 
 
 def use_trtllm_attention(
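A rough standalone sketch (not part of the commit) of how the split divides responsibilities: supports_trtllm_attention() answers "can this platform run the TRTLLM kernels at all?", while force_use_trtllm_attention() only reports the user's intent from the environment variable. The stubs below are simplified stand-ins for the real vLLM functions and platform checks:

import functools
import os
from typing import Optional

def supports_trtllm_attention() -> bool:
    # Stand-in for the real check (SM100 + reachable NVIDIA artifactory).
    return True

@functools.cache
def _force_use_trtllm_attention(env_value: Optional[bool]) -> Optional[bool]:
    # Cached so any "env var is set to ..." logging happens only once.
    return env_value

def force_use_trtllm_attention() -> Optional[bool]:
    # Stand-in for reading envs.VLLM_USE_TRTLLM_ATTENTION (already parsed
    # to None/True/False by the envs.py lambda above).
    raw = os.environ.get("VLLM_USE_TRTLLM_ATTENTION")
    env_value = None if raw is None else raw.lower() in ("1", "true")
    return _force_use_trtllm_attention(env_value)

# None -> defer to auto-detection, False -> never use TRTLLM,
# True -> use TRTLLM wherever the platform and head layout allow it.
print(force_use_trtllm_attention(), supports_trtllm_attention())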
@@ -185,18 +188,38 @@ def use_trtllm_attention(
     max_seq_len: int,
     kv_cache_dtype: str,
     q_dtype: torch.dtype,
     is_prefill: bool,
     has_sinks: bool = False,
 ) -> bool:
-    use_trtllm, env_value = supports_trtllm_attention()
-    if not use_trtllm:
+    """Return ``True`` if TRTLLM attention is used."""
+    force_use_trtllm = force_use_trtllm_attention()
+
+    # Environment variable is set to 0 - respect it
+    if force_use_trtllm is not None and not force_use_trtllm:
+        return False
+
+    # The platform is not supported
+    if not supports_trtllm_attention():
+        if force_use_trtllm:
+            logger.warning_once(
+                "TRTLLM attention is not supported on this platform, "
+                "but VLLM_USE_TRTLLM_ATTENTION is set to 1")
         return False
 
+    # The combination of query and key heads is not supported
+    if num_qo_heads % num_kv_heads != 0:
+        if force_use_trtllm:
+            logger.warning_once(
+                "TRTLLM attention is not supported for this combination of "
+                "query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1"
+            )
+        return False
+
     # Must use TRTLLM attention if query is FP8 quantized
     if q_dtype == current_platform.fp8_dtype():
         if has_sinks:
             raise RuntimeError(
                 "TRTLLM FP8-qkv kernel is not supported for attention sinks. "
                 "Use kv_cache_dtype=auto for now.")
         logger.info_once("Using TRTLLM attention (query is quantized).")
         return True
@@ -207,15 +230,17 @@ def use_trtllm_attention(
             "Using TRTLLM attention (required for attention sinks).")
         return True
 
-    if env_value is None:
+    if force_use_trtllm is None:
         # Environment variable not set - use auto-detection
-        use_trtllm = (num_tokens <= 256 and max_seq_len < 131072
+        use_trtllm = (num_tokens <= 256 and max_seq_len <= 131072
                       and kv_cache_dtype == "auto")
         if use_trtllm:
             logger.warning_once("Using TRTLLM attention (auto-detected).")
         return use_trtllm
 
     # Environment variable is set to 1 - respect it
     logger.info_once(
         "Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)")
     return True
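A condensed sketch (not part of the commit) of the selection order the refactored use_trtllm_attention() now follows; pick_trtllm and its arguments are illustrative stand-ins, and the fp8-query-with-sinks error path is elided:

from typing import Optional

def pick_trtllm(forced: Optional[bool], platform_ok: bool, heads_ok: bool,
                q_is_fp8: bool, has_sinks: bool, num_tokens: int,
                max_seq_len: int, kv_cache_dtype: str) -> bool:
    if forced is False:                  # env var set to 0/False: never use it
        return False
    if not platform_ok or not heads_ok:  # unsupported even if forced
        return False
    if q_is_fp8 or has_sinks:            # these paths require the TRTLLM kernels
        return True
    if forced is None:                   # auto-detect on small batches
        return (num_tokens <= 256 and max_seq_len <= 131072
                and kv_cache_dtype == "auto")
    return True                          # env var set to 1/True

# e.g. a 128-token batch with seq len 8192 and "auto" kv cache dtype:
print(pick_trtllm(None, True, True, False, False, 128, 8192, "auto"))   # True
print(pick_trtllm(False, True, True, False, False, 128, 8192, "auto"))  # False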
@@ -367,6 +392,7 @@ __all__ = [
     "has_nvidia_artifactory",
     "supports_trtllm_attention",
     "use_trtllm_attention",
     "flashinfer_disable_q_quantization",
     "flashinfer_scaled_fp4_mm",
     "flashinfer_scaled_fp8_mm",
 ]
vllm/v1/attention/backends/flashinfer.py
@@ -282,7 +282,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             assert self.kv_cache_spec.dtype == self.model_config.dtype
             self.kv_cache_dtype = self.kv_cache_spec.dtype
 
-        if supports_trtllm_attention()[0] and \
+        # Use model dtype as q dtype when TRTLLM attn is not supported, or
+        # VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is set to 1. Otherwise, try to
+        # use fp8 q if kv cache is fp8, and will fall back to model dtype
+        # if TRTLLM attention kernel is not used when building attn metadata
+        if supports_trtllm_attention() and \
             not flashinfer_disable_q_quantization():
             self.q_data_type = self.kv_cache_dtype
         else:
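A small sketch (not part of the commit) of the builder's init-time choice of query dtype; initial_q_dtype and its parameters are illustrative stand-ins for the branch in this hunk:

import torch

def initial_q_dtype(kv_cache_dtype: torch.dtype, model_dtype: torch.dtype,
                    trtllm_supported: bool, q_quant_disabled: bool) -> torch.dtype:
    # Optimistically quantize q to the (possibly fp8) kv-cache dtype only when
    # the TRTLLM kernels may be used and q quantization is not disabled;
    # otherwise keep the model dtype.
    if trtllm_supported and not q_quant_disabled:
        return kv_cache_dtype
    return model_dtype

print(initial_q_dtype(torch.float8_e4m3fn, torch.bfloat16,
                      trtllm_supported=True, q_quant_disabled=False))
# torch.float8_e4m3fn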
@@ -298,7 +302,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         self.window_left = self.global_hyperparameters.window_left
         self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap
         self.has_sinks = self.global_hyperparameters.has_sinks
-        if self.has_sinks and not supports_trtllm_attention()[0]:
+        if self.has_sinks and not supports_trtllm_attention():
             raise NotImplementedError(
                 "FlashInfer backend currently does not support attention "
                 "sinks, please use trtllm on blackwell or flash attention on "
@@ -477,14 +481,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             paged_kv_last_page_len_np,
         )
 
-        # Check if any layer uses sinks (requires TRTLLM attention)
         prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                   self.num_kv_heads,
                                                   num_prefill_tokens,
                                                   max_seq_len,
                                                   self.cache_dtype,
                                                   self.q_data_type,
                                                   is_prefill=True,
                                                   has_sinks=self.has_sinks)
         decode_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                  self.num_kv_heads,
@@ -492,13 +494,18 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                                                  max_seq_len,
                                                  self.cache_dtype,
                                                  self.q_data_type,
                                                  is_prefill=False,
                                                  has_sinks=self.has_sinks)
         if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm):
             raise NotImplementedError(
                 "FlashInfer backend currently does not support attention "
                 "sinks, please use trtllm on blackwell or flash attention on "
                 "earlier GPUs.")
 
+        # If TRTLLM attention is not used, the q quantization is not supported.
+        # Fall back to use model dtype.
+        if not (prefill_use_trtllm and decode_use_trtllm):
+            self.q_data_type = self.model_config.dtype
+
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
             q_data_type=self.q_data_type,
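A small sketch (not part of the commit) of the build-time fall-back added here: fp8 query quantization is only valid when both the prefill and decode paths actually use the TRTLLM kernels, so otherwise q reverts to the model dtype. Names are illustrative:

import torch

def resolve_q_dtype(q_data_type: torch.dtype, model_dtype: torch.dtype,
                    prefill_use_trtllm: bool, decode_use_trtllm: bool) -> torch.dtype:
    # If either phase falls back to the default FlashInfer kernels, drop the
    # optimistic fp8 query dtype chosen at init time.
    if not (prefill_use_trtllm and decode_use_trtllm):
        return model_dtype
    return q_data_type

print(resolve_q_dtype(torch.float8_e4m3fn, torch.bfloat16,
                      prefill_use_trtllm=True, decode_use_trtllm=False))
# torch.bfloat16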