[Kernel] Simplify get_kv_cache_layout and cache use_trtllm_attention env-dependent bit (#22735)

Signed-off-by: NickLucche <nlucches@redhat.com>
Nicolò Lucchesi authored 2025-08-16 02:14:08 +02:00, committed by GitHub
parent ad0297d113
commit 070da660c1
2 changed files with 42 additions and 22 deletions

@@ -148,6 +148,31 @@ def has_nvidia_artifactory() -> bool:
     return False
 
 
+@functools.cache
+def supports_trtllm_attention() -> tuple[bool, Optional[str]]:
+    """Cache result which only depends on the environment"""
+    # This is a lambda, call it once
+    env_value = envs.VLLM_USE_TRTLLM_ATTENTION
+
+    # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
+    if not (current_platform.is_device_capability(100)
+            and has_nvidia_artifactory()):
+        return False, env_value
+
+    if env_value is not None:
+        logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
+        # Environment variable is set - respect it
+        # Making the conditional check for zero because
+        # the path is automatically enabled if the batch size condition
+        # is satisfied.
+        use_trtllm = (env_value == "1")
+        if use_trtllm:
+            logger.info_once("Using TRTLLM attention.")
+        return use_trtllm, env_value
+
+    return True, None
+
+
 def use_trtllm_attention(
     num_tokens: int,
     max_seq_len: int,
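
The new helper leans on functools.cache so the environment is inspected exactly once per process. Below is a minimal sketch of the same pattern, using hypothetical names (supports_feature, USE_FEATURE) rather than vLLM's:

    import functools
    import os
    from typing import Optional


    @functools.cache
    def supports_feature() -> tuple[bool, Optional[str]]:
        # The environment is read on the first call only; every later
        # call returns the memoized (supported, raw_env_value) tuple.
        env_value = os.environ.get("USE_FEATURE")  # hypothetical variable
        if env_value is not None:
            return env_value == "1", env_value  # explicitly forced on/off
        return True, None  # unset: supported, leave auto-detection on


    enabled, raw = supports_feature()
    # Tests that mutate the environment must drop the memoized result:
    supports_feature.cache_clear()

Returning the raw environment value alongside the boolean lets the per-call function distinguish "unset" (auto-detect) from "explicitly enabled".
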
@@ -157,9 +182,8 @@ def use_trtllm_attention(
     attn_head_size: Optional[int],
     has_sinks: bool = False,
 ) -> bool:
-    # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
-    if not (current_platform.is_device_capability(100)
-            and has_nvidia_artifactory()):
+    use_trtllm, env_value = supports_trtllm_attention()
+    if not use_trtllm:
         return False
 
     # Check if the dimensions are supported by TRTLLM decode attention
@@ -174,18 +198,7 @@ def use_trtllm_attention(
             "Using TRTLLM attention (required for attention sinks).")
         return True
 
-    env_value = envs.VLLM_USE_TRTLLM_ATTENTION
-    if env_value is not None:
-        logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
-        # Environment variable is set - respect it
-        # Making the conditional check for zero because
-        # the path is automatically enabled if the batch size condition
-        # is satisfied.
-        use_trtllm = (env_value == "1")
-        if use_trtllm:
-            logger.info_once("Using TRTLLM attention.")
-        return use_trtllm
-    else:
+    if env_value is None:
         # Environment variable not set - use auto-detection
         use_trtllm = (num_tokens <= 256 and max_seq_len < 131072
                       and kv_cache_dtype == "auto")
@@ -193,6 +206,9 @@ def use_trtllm_attention(
             logger.warning_once("Using TRTLLM attention (auto-detected).")
         return use_trtllm
 
+    # Environment variable is set to 1 - respect it
+    return True
+
 
 if has_flashinfer():
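
After the refactor, use_trtllm_attention reduces to three tiers: the cached environment gate, per-call auto-detection when the variable is unset, and an unconditional True when it is explicitly set to 1. A minimal sketch of that control flow, with hypothetical stand-ins (supports_backend, use_backend) for the real functions:

    from typing import Optional


    def supports_backend() -> tuple[bool, Optional[str]]:
        return True, None  # stands in for the cached environment check


    def use_backend(num_tokens: int, max_seq_len: int,
                    kv_cache_dtype: str) -> bool:
        supported, env_value = supports_backend()
        if not supported:        # environment can never support it
            return False
        if env_value is None:    # not forced either way: auto-detect per call
            return (num_tokens <= 256 and max_seq_len < 131072
                    and kv_cache_dtype == "auto")
        return env_value == "1"  # explicitly forced on or off


    assert use_backend(128, 4096, "auto") is True
    assert use_backend(512, 4096, "auto") is False

The split keeps the expensive, environment-only checks behind the cache while the cheap per-request dimension checks still run on every call.
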

@@ -248,19 +248,23 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
 
 @functools.lru_cache
 def get_kv_cache_layout():
-    # Format specified by the code.
     global _KV_CACHE_LAYOUT_OVERRIDE
+    # Override with format specified by the user.
+    if _KV_CACHE_LAYOUT_OVERRIDE is not None:
+        cache_layout = _KV_CACHE_LAYOUT_OVERRIDE
+        logger.info_once("`_KV_CACHE_LAYOUT_OVERRIDE` variable detected. " \
+                         "Setting KV cache layout to %s.", cache_layout)
+        return cache_layout
+
+    # Format specified by the user.
     cache_layout = envs.VLLM_KV_CACHE_LAYOUT
+    # When neither the user nor the override specified a layout, get default
     if cache_layout is None:
-        if envs.VLLM_USE_TRTLLM_ATTENTION:
-            cache_layout = "HND"
-        else:
-            cache_layout = get_kv_connector_cache_layout()
+        cache_layout = get_kv_connector_cache_layout()
     else:
         logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
                          "detected. Setting KV cache layout to %s.", cache_layout)
-    if _KV_CACHE_LAYOUT_OVERRIDE is not None:
-        cache_layout = _KV_CACHE_LAYOUT_OVERRIDE
     return cache_layout
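
get_kv_cache_layout now resolves the layout through a single precedence chain: the programmatic override wins, then the user's environment variable, then the connector-derived default. A minimal sketch of that chain under hypothetical names (_LAYOUT_OVERRIDE, KV_CACHE_LAYOUT, get_layout):

    import functools
    import os
    from typing import Optional

    _LAYOUT_OVERRIDE: Optional[str] = None  # hypothetical module-level override


    @functools.lru_cache
    def get_layout() -> str:
        if _LAYOUT_OVERRIDE is not None:  # 1. explicit override wins
            return _LAYOUT_OVERRIDE
        env_layout = os.environ.get("KV_CACHE_LAYOUT")  # hypothetical env var
        if env_layout is not None:        # 2. then the user's env var
            return env_layout
        return "NHD"                      # 3. else a computed default


    # The result is memoized, so in this sketch any later change to the
    # override or the environment is invisible until the cache is cleared:
    get_layout.cache_clear()

Because the real function is wrapped in functools.lru_cache as well, the override has to be in place before the first lookup, which is why the early-return structure above is simpler than mutating the result afterwards.
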