Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-11 01:35:22 +08:00)
[Kernel] Simplify get_kv_cache_layout and cache use_trtllm_attention env-dependent bit (#22735)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent: ad0297d113
commit: 070da660c1
@ -148,6 +148,31 @@ def has_nvidia_artifactory() -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
|
||||||
|
def supports_trtllm_attention() -> tuple[bool, Optional[str]]:
|
||||||
|
"""Cache result which only depends on the environment"""
|
||||||
|
# The env attribute read below is backed by a lambda; evaluate it once and reuse the value
|
||||||
|
env_value = envs.VLLM_USE_TRTLLM_ATTENTION
|
||||||
|
|
||||||
|
# Requires SM100 and NVIDIA artifactory to be accessible to download cubins
|
||||||
|
if not (current_platform.is_device_capability(100)
|
||||||
|
and has_nvidia_artifactory()):
|
||||||
|
return False, env_value
|
||||||
|
|
||||||
|
if env_value is not None:
|
||||||
|
logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
|
||||||
|
# Environment variable is set - respect it
|
||||||
|
# Making the conditional check for zero because
|
||||||
|
# the path is automatically enabled if the batch size condition
|
||||||
|
# is satisfied.
|
||||||
|
use_trtllm = (env_value == "1")
|
||||||
|
if use_trtllm:
|
||||||
|
logger.info_once("Using TRTLLM attention.")
|
||||||
|
return use_trtllm, env_value
|
||||||
|
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
|
||||||
def use_trtllm_attention(
|
def use_trtllm_attention(
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
max_seq_len: int,
|
max_seq_len: int,
|
||||||
@ -157,9 +182,8 @@ def use_trtllm_attention(
|
|||||||
attn_head_size: Optional[int],
|
attn_head_size: Optional[int],
|
||||||
has_sinks: bool = False,
|
has_sinks: bool = False,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
# Requires SM100 and NVIDIA artifactory to be accessible to download cubins
|
use_trtllm, env_value = supports_trtllm_attention()
|
||||||
if not (current_platform.is_device_capability(100)
|
if not use_trtllm:
|
||||||
and has_nvidia_artifactory()):
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check if the dimensions are supported by TRTLLM decode attention
|
# Check if the dimensions are supported by TRTLLM decode attention
|
||||||
@ -174,18 +198,7 @@ def use_trtllm_attention(
|
|||||||
"Using TRTLLM attention (required for attention sinks).")
|
"Using TRTLLM attention (required for attention sinks).")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
env_value = envs.VLLM_USE_TRTLLM_ATTENTION
|
if env_value is None:
|
||||||
if env_value is not None:
|
|
||||||
logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
|
|
||||||
# Environment variable is set - respect it
|
|
||||||
# Making the conditional check for zero because
|
|
||||||
# the path is automatically enabled if the batch size condition
|
|
||||||
# is satisfied.
|
|
||||||
use_trtllm = (env_value == "1")
|
|
||||||
if use_trtllm:
|
|
||||||
logger.info_once("Using TRTLLM attention.")
|
|
||||||
return use_trtllm
|
|
||||||
else:
|
|
||||||
# Environment variable not set - use auto-detection
|
# Environment variable not set - use auto-detection
|
||||||
use_trtllm = (num_tokens <= 256 and max_seq_len < 131072
|
use_trtllm = (num_tokens <= 256 and max_seq_len < 131072
|
||||||
and kv_cache_dtype == "auto")
|
and kv_cache_dtype == "auto")
|
||||||
@ -193,6 +206,9 @@ def use_trtllm_attention(
|
|||||||
logger.warning_once("Using TRTLLM attention (auto-detected).")
|
logger.warning_once("Using TRTLLM attention (auto-detected).")
|
||||||
return use_trtllm
|
return use_trtllm
|
||||||
|
|
||||||
|
# Environment variable is set to 1 - respect it
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
if has_flashinfer():
|
if has_flashinfer():
|
||||||
|
|
||||||
|
|||||||
@ -248,19 +248,23 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
|
|||||||
|
|
||||||
@functools.lru_cache
|
@functools.lru_cache
|
||||||
def get_kv_cache_layout():
|
def get_kv_cache_layout():
|
||||||
|
# Format specified by the code.
|
||||||
global _KV_CACHE_LAYOUT_OVERRIDE
|
global _KV_CACHE_LAYOUT_OVERRIDE
|
||||||
# Override with format specified by the user.
|
|
||||||
|
if _KV_CACHE_LAYOUT_OVERRIDE is not None:
|
||||||
|
cache_layout = _KV_CACHE_LAYOUT_OVERRIDE
|
||||||
|
logger.info_once("`_KV_CACHE_LAYOUT_OVERRIDE` variable detected. " \
|
||||||
|
"Setting KV cache layout to %s.", cache_layout)
|
||||||
|
return cache_layout
|
||||||
|
|
||||||
|
# Format specified by the user.
|
||||||
cache_layout = envs.VLLM_KV_CACHE_LAYOUT
|
cache_layout = envs.VLLM_KV_CACHE_LAYOUT
|
||||||
|
# When neither the user nor the override specified a layout, get default
|
||||||
if cache_layout is None:
|
if cache_layout is None:
|
||||||
if envs.VLLM_USE_TRTLLM_ATTENTION:
|
cache_layout = get_kv_connector_cache_layout()
|
||||||
cache_layout = "HND"
|
|
||||||
else:
|
|
||||||
cache_layout = get_kv_connector_cache_layout()
|
|
||||||
else:
|
else:
|
||||||
logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
|
logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
|
||||||
"detected. Setting KV cache layout to %s.", cache_layout)
|
"detected. Setting KV cache layout to %s.", cache_layout)
|
||||||
if _KV_CACHE_LAYOUT_OVERRIDE is not None:
|
|
||||||
cache_layout = _KV_CACHE_LAYOUT_OVERRIDE
|
|
||||||
return cache_layout
|
return cache_layout
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user