[Bugfix] Refactor Flashinfer TRTLLM attention kernel selection logic (#24600)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
commit e67a79db03
parent 9f882d8791
@@ -1223,9 +1223,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_USE_CUDNN_PREFILL":
     lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))),

-    # If set to 1, use the TRTLLM attention backend in flashinfer.
+    # If set to 1/True, use the TRTLLM attention backend in flashinfer.
+    # If set to 0/False, use the default attention backend in flashinfer.
+    # If not set, auto-detect the attention backend in flashinfer.
     "VLLM_USE_TRTLLM_ATTENTION":
-    lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),
+    lambda: (None if "VLLM_USE_TRTLLM_ATTENTION" not in os.environ else
+             os.environ["VLLM_USE_TRTLLM_ATTENTION"].lower() in ("1", "true")),

     # If set to 1, when we use fp8 kv, we do not quantize Q to fp8
     "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION":
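The new lambda turns the raw environment string into a tri-state value: unset means None (auto-detect), "1"/"true" in any letter case means True, and any other value means False. A minimal standalone sketch of that behaviour follows; the parse_trtllm_flag helper and its dict argument are illustrative only, not part of the change:

# Sketch: tri-state parsing of VLLM_USE_TRTLLM_ATTENTION (hypothetical helper).
from typing import Optional

def parse_trtllm_flag(environ: dict[str, str]) -> Optional[bool]:
    # Unset -> None (auto-detect later); "1"/"true" (any case) -> True;
    # every other value -> False.
    if "VLLM_USE_TRTLLM_ATTENTION" not in environ:
        return None
    return environ["VLLM_USE_TRTLLM_ATTENTION"].lower() in ("1", "true")

assert parse_trtllm_flag({}) is None
assert parse_trtllm_flag({"VLLM_USE_TRTLLM_ATTENTION": "True"}) is True
assert parse_trtllm_flag({"VLLM_USE_TRTLLM_ATTENTION": "0"}) is False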
@@ -154,28 +154,31 @@ def has_nvidia_artifactory() -> bool:


 @functools.cache
-def supports_trtllm_attention() -> tuple[bool, Optional[str]]:
-    """Cache result which only depends on the environment"""
-    # This is a lambda, call it once
-    env_value = envs.VLLM_USE_TRTLLM_ATTENTION
-
+def supports_trtllm_attention() -> bool:
+    """
+    TRTLLM attention is supported if the platform is SM100 and
+    NVIDIA artifactory is accessible
+    """
     # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
-    if not (current_platform.is_device_capability(100)
-            and has_nvidia_artifactory()):
-        return False, env_value
+    return current_platform.is_device_capability(
+        100) and has_nvidia_artifactory()

+
+@functools.cache
+def _force_use_trtllm_attention(env_value: Optional[bool]) -> Optional[bool]:
+    """Cache the env value for VLLM_USE_TRTLLM_ATTENTION"""
     if env_value is not None:
         logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
-        # Environment variable is set - respect it
-        # Making the conditional check for zero because
-        # the path is automatically enabled if the batch size condition
-        # is satisfied.
-        use_trtllm = (env_value == "1")
-        if use_trtllm:
-            logger.info_once("Using TRTLLM attention.")
-        return use_trtllm, env_value
+    return env_value

-    return True, None
+
+def force_use_trtllm_attention() -> Optional[bool]:
+    """
+    Return ``None`` if VLLM_USE_TRTLLM_ATTENTION is not set,
+    return ``True`` if TRTLLM attention is forced to be used,
+    return ``False`` if TRTLLM attention is forced to be not used.
+    """
+    return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)


 def use_trtllm_attention(
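The net effect of this hunk is to split the old tuple-returning helper into a pure, cached capability check plus a cached reader of the user's explicit override. A runnable sketch of that division of responsibilities, with stub bodies standing in for the real GPU and artifactory checks (the real force_use_trtllm_attention reads envs.VLLM_USE_TRTLLM_ATTENTION rather than taking an argument):

# Sketch only: same signatures as the diff above, stubbed internals.
import functools
from typing import Optional

def _is_sm100() -> bool:          # stand-in for current_platform.is_device_capability(100)
    return True

def _has_artifactory() -> bool:   # stand-in for has_nvidia_artifactory()
    return True

@functools.cache
def supports_trtllm_attention() -> bool:
    # Pure capability check; no env-var logic here any more.
    return _is_sm100() and _has_artifactory()

@functools.cache
def _force_use_trtllm_attention(env_value: Optional[bool]) -> Optional[bool]:
    # Caches (and, in vLLM, logs once) the user's explicit choice; None = unset.
    return env_value

def force_use_trtllm_attention(env_value: Optional[bool] = None) -> Optional[bool]:
    return _force_use_trtllm_attention(env_value)

assert supports_trtllm_attention() is True
assert force_use_trtllm_attention() is None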
@@ -185,18 +188,38 @@ def use_trtllm_attention(
     max_seq_len: int,
     kv_cache_dtype: str,
     q_dtype: torch.dtype,
-    is_prefill: bool,
     has_sinks: bool = False,
 ) -> bool:
-    use_trtllm, env_value = supports_trtllm_attention()
-    if not use_trtllm:
+    """Return ``True`` if TRTLLM attention is used."""
+    force_use_trtllm = force_use_trtllm_attention()
+
+    # Environment variable is set to 0 - respect it
+    if force_use_trtllm is not None and not force_use_trtllm:
         return False

+    # The platform is not supported
+    if not supports_trtllm_attention():
+        if force_use_trtllm:
+            logger.warning_once(
+                "TRTLLM attention is not supported on this platform, "
+                "but VLLM_USE_TRTLLM_ATTENTION is set to 1")
+        return False
+
+    # The combination of query and key heads is not supported
     if num_qo_heads % num_kv_heads != 0:
+        if force_use_trtllm:
+            logger.warning_once(
+                "TRTLLM attention is not supported for this combination of "
+                "query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1"
+            )
         return False

     # Must use TRTLLM attention if query is FP8 quantized
     if q_dtype == current_platform.fp8_dtype():
+        if has_sinks:
+            raise RuntimeError(
+                "TRTLLM FP8-qkv kernel is not supported for attention sinks. "
+                "Use kv_cache_dtype=auto for now.")
         logger.info_once("Using TRTLLM attention (query is quantized).")
         return True

@@ -207,15 +230,17 @@ def use_trtllm_attention(
             "Using TRTLLM attention (required for attention sinks).")
         return True

-    if env_value is None:
+    if force_use_trtllm is None:
         # Environment variable not set - use auto-detection
-        use_trtllm = (num_tokens <= 256 and max_seq_len < 131072
+        use_trtllm = (num_tokens <= 256 and max_seq_len <= 131072
                       and kv_cache_dtype == "auto")
         if use_trtllm:
             logger.warning_once("Using TRTLLM attention (auto-detected).")
         return use_trtllm

     # Environment variable is set to 1 - respect it
+    logger.info_once(
+        "Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)")
     return True


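Taken together, the two hunks above give use_trtllm_attention a fixed precedence order: explicit opt-out, then hard requirements (platform, head divisibility), then kernels only TRTLLM provides (FP8 query, attention sinks), then auto-detection or an explicit opt-in. A condensed standalone restatement of that order as a pure function (argument names are illustrative; the real function also logs its decision and raises on the FP8-query + sinks combination):

# Sketch: decision order of the refactored selection logic.
from typing import Optional

def pick_trtllm(force: Optional[bool], platform_ok: bool, heads_ok: bool,
                q_is_fp8: bool, has_sinks: bool, num_tokens: int,
                max_seq_len: int, kv_cache_dtype: str) -> bool:
    if force is False:                      # explicit opt-out always wins
        return False
    if not platform_ok or not heads_ok:     # hard requirements
        return False
    if q_is_fp8 or has_sinks:               # cases only the TRTLLM kernel handles
        return True
    if force is None:                       # auto-detection heuristics
        return (num_tokens <= 256 and max_seq_len <= 131072
                and kv_cache_dtype == "auto")
    return True                             # force is True

# Small decode batch on a supported platform: auto-detection picks TRTLLM.
assert pick_trtllm(None, True, True, False, False, 128, 8192, "auto")
# Explicit VLLM_USE_TRTLLM_ATTENTION=0 overrides everything else.
assert not pick_trtllm(False, True, True, True, False, 128, 8192, "auto")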
@@ -367,6 +392,7 @@ __all__ = [
     "has_nvidia_artifactory",
     "supports_trtllm_attention",
     "use_trtllm_attention",
+    "flashinfer_disable_q_quantization",
     "flashinfer_scaled_fp4_mm",
     "flashinfer_scaled_fp8_mm",
 ]
@@ -282,7 +282,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             assert self.kv_cache_spec.dtype == self.model_config.dtype
             self.kv_cache_dtype = self.kv_cache_spec.dtype

-        if supports_trtllm_attention()[0] and \
+        # Use model dtype as q dtype when TRTLLM attn is not supported, or
+        # VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is set to 1. Otherwise, try to
+        # use fp8 q if kv cache is fp8, and will fall back to model dtype
+        # if TRTLLM attention kernel is not used when building attn metadata
+        if supports_trtllm_attention() and \
             not flashinfer_disable_q_quantization():
             self.q_data_type = self.kv_cache_dtype
         else:
@@ -298,7 +302,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         self.window_left = self.global_hyperparameters.window_left
         self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap
         self.has_sinks = self.global_hyperparameters.has_sinks
-        if self.has_sinks and not supports_trtllm_attention()[0]:
+        if self.has_sinks and not supports_trtllm_attention():
             raise NotImplementedError(
                 "FlashInfer backend currently does not support attention "
                 "sinks, please use trtllm on blackwell or flash attention on "
@@ -477,14 +481,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             paged_kv_last_page_len_np,
         )

-        # Check if any layer uses sinks (requires TRTLLM attention)
         prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                   self.num_kv_heads,
                                                   num_prefill_tokens,
                                                   max_seq_len,
                                                   self.cache_dtype,
                                                   self.q_data_type,
-                                                  is_prefill=True,
                                                   has_sinks=self.has_sinks)
         decode_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                  self.num_kv_heads,
@@ -492,13 +494,18 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                                                  max_seq_len,
                                                  self.cache_dtype,
                                                  self.q_data_type,
-                                                 is_prefill=False,
                                                  has_sinks=self.has_sinks)
         if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm):
             raise NotImplementedError(
                 "FlashInfer backend currently does not support attention "
                 "sinks, please use trtllm on blackwell or flash attention on "
                 "earlier GPUs.")

+        # If TRTLLM attention is not used, the q quantization is not supported.
+        # Fall back to use model dtype.
+        if not (prefill_use_trtllm and decode_use_trtllm):
+            self.q_data_type = self.model_config.dtype
+
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
             q_data_type=self.q_data_type,
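The builder now picks the query dtype optimistically at init time and corrects it per batch once the prefill/decode TRTLLM decisions are known. A small sketch of that two-step choice (dtypes shown as plain strings for brevity; the real code assigns torch dtypes to self.q_data_type):

# Sketch: q-dtype fallback when the TRTLLM kernel ends up unused.
def choose_q_dtype(kv_cache_dtype: str, model_dtype: str, trtllm_supported: bool,
                   q_quant_disabled: bool, prefill_use_trtllm: bool,
                   decode_use_trtllm: bool) -> str:
    # Optimistic choice at init time: reuse the (possibly fp8) kv-cache dtype.
    q_dtype = kv_cache_dtype if trtllm_supported and not q_quant_disabled \
        else model_dtype
    # Per-batch correction: without TRTLLM on both paths, q quantization is
    # unsupported, so fall back to the model dtype.
    if not (prefill_use_trtllm and decode_use_trtllm):
        q_dtype = model_dtype
    return q_dtype

assert choose_q_dtype("fp8_e4m3", "bfloat16", True, False, True, True) == "fp8_e4m3"
assert choose_q_dtype("fp8_e4m3", "bfloat16", True, False, True, False) == "bfloat16"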