mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 12:45:33 +08:00
[Kernel] Simplify get_kv_cache_layout and cache use_trtllm_attention env-dependent bit (#22735)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent
ad0297d113
commit
070da660c1
@ -148,6 +148,31 @@ def has_nvidia_artifactory() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@functools.cache
def supports_trtllm_attention() -> tuple[bool, Optional[str]]:
    """Resolve, once, the environment-only part of the TRTLLM decision.

    Nothing here depends on per-call arguments, so the result is memoized
    and every later call returns the same tuple.

    Returns:
        A ``(use_trtllm, env_value)`` pair where ``env_value`` is whatever
        ``envs.VLLM_USE_TRTLLM_ATTENTION`` yielded:

        * ``(False, env_value)`` when the device is not capability 100 or
          ``has_nvidia_artifactory()`` is false.
        * ``(True, None)`` when the variable is unset.
        * ``(env_value == "1", env_value)`` when the variable is set.
    """
    # Read the setting a single time; the original note on this line says
    # the attribute is backed by a lambda.
    requested = envs.VLLM_USE_TRTLLM_ATTENTION

    # Hard prerequisites: SM100 and a reachable NVIDIA artifactory, which
    # is needed to download cubins.
    prerequisites_met = (current_platform.is_device_capability(100)
                         and has_nvidia_artifactory())
    if not prerequisites_met:
        return False, requested

    # No explicit preference: report support and leave the final choice
    # to the caller.
    if requested is None:
        return True, None

    logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", requested)
    # An explicit setting is honoured as-is: only the exact string "1"
    # enables the path, every other value turns it off.
    enabled = requested == "1"
    if enabled:
        logger.info_once("Using TRTLLM attention.")
    return enabled, requested
|
||||
|
||||
|
||||
def use_trtllm_attention(
|
||||
num_tokens: int,
|
||||
max_seq_len: int,
|
||||
@ -157,9 +182,8 @@ def use_trtllm_attention(
|
||||
attn_head_size: Optional[int],
|
||||
has_sinks: bool = False,
|
||||
) -> bool:
|
||||
# Requires SM100 and NVIDIA artifactory to be accessible to download cubins
|
||||
if not (current_platform.is_device_capability(100)
|
||||
and has_nvidia_artifactory()):
|
||||
use_trtllm, env_value = supports_trtllm_attention()
|
||||
if not use_trtllm:
|
||||
return False
|
||||
|
||||
# Check if the dimensions are supported by TRTLLM decode attention
|
||||
@ -174,18 +198,7 @@ def use_trtllm_attention(
|
||||
"Using TRTLLM attention (required for attention sinks).")
|
||||
return True
|
||||
|
||||
env_value = envs.VLLM_USE_TRTLLM_ATTENTION
|
||||
if env_value is not None:
|
||||
logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
|
||||
# Environment variable is set - respect it
|
||||
# Making the conditional check for zero because
|
||||
# the path is automatically enabled if the batch size condition
|
||||
# is satisfied.
|
||||
use_trtllm = (env_value == "1")
|
||||
if use_trtllm:
|
||||
logger.info_once("Using TRTLLM attention.")
|
||||
return use_trtllm
|
||||
else:
|
||||
if env_value is None:
|
||||
# Environment variable not set - use auto-detection
|
||||
use_trtllm = (num_tokens <= 256 and max_seq_len < 131072
|
||||
and kv_cache_dtype == "auto")
|
||||
@ -193,6 +206,9 @@ def use_trtllm_attention(
|
||||
logger.warning_once("Using TRTLLM attention (auto-detected).")
|
||||
return use_trtllm
|
||||
|
||||
# Environment variable is set to 1 - respect it
|
||||
return True
|
||||
|
||||
|
||||
if has_flashinfer():
|
||||
|
||||
|
||||
@ -248,19 +248,23 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
|
||||
|
||||
@functools.lru_cache
|
||||
def get_kv_cache_layout():
|
||||
# Format specified by the code.
|
||||
global _KV_CACHE_LAYOUT_OVERRIDE
|
||||
# Override with format specified by the user.
|
||||
|
||||
if _KV_CACHE_LAYOUT_OVERRIDE is not None:
|
||||
cache_layout = _KV_CACHE_LAYOUT_OVERRIDE
|
||||
logger.info_once("`_KV_CACHE_LAYOUT_OVERRIDE` variable detected. " \
|
||||
"Setting KV cache layout to %s.", cache_layout)
|
||||
return cache_layout
|
||||
|
||||
# Format specified by the user.
|
||||
cache_layout = envs.VLLM_KV_CACHE_LAYOUT
|
||||
# When neither the user nor the override specified a layout, get default
|
||||
if cache_layout is None:
|
||||
if envs.VLLM_USE_TRTLLM_ATTENTION:
|
||||
cache_layout = "HND"
|
||||
else:
|
||||
cache_layout = get_kv_connector_cache_layout()
|
||||
cache_layout = get_kv_connector_cache_layout()
|
||||
else:
|
||||
logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
|
||||
"detected. Setting KV cache layout to %s.", cache_layout)
|
||||
if _KV_CACHE_LAYOUT_OVERRIDE is not None:
|
||||
cache_layout = _KV_CACHE_LAYOUT_OVERRIDE
|
||||
return cache_layout
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user