[Bugfix] Fix SM100 gpt-oss regression due to faulty attn sink support (#28561)
Signed-off-by: mgoin <mgoin64@gmail.com>
parent 2dacd57394
commit a543e678b4
@@ -35,9 +35,20 @@ FLASHINFER_CUBINS_REPOSITORY = os.environ.get(
 )
 
 
+@functools.cache
+def has_flashinfer_cubin() -> bool:
+    """Return `True` if flashinfer-cubin package is available."""
+    if envs.VLLM_HAS_FLASHINFER_CUBIN:
+        return True
+    if importlib.util.find_spec("flashinfer_cubin") is not None:
+        return True
+    logger.debug_once("flashinfer-cubin package was not found")
+    return False
+
+
 @functools.cache
 def has_flashinfer() -> bool:
-    """Return `True` if FlashInfer is available."""
+    """Return `True` if flashinfer-python package is available."""
     # Use find_spec to check if the module exists without importing it
     # This avoids potential CUDA initialization side effects
     if importlib.util.find_spec("flashinfer") is None:
@@ -45,7 +56,7 @@ def has_flashinfer() -> bool:
         return False
     # When not using flashinfer cubin,
     # Also check if nvcc is available since it's required to JIT compile flashinfer
-    if not envs.VLLM_HAS_FLASHINFER_CUBIN and shutil.which("nvcc") is None:
+    if not has_flashinfer_cubin() and shutil.which("nvcc") is None:
         logger.debug_once(
             "FlashInfer unavailable since nvcc was not found "
             "and not using pre-downloaded cubins"
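As a standalone illustration (not part of the diff, and not vLLM's actual module), the pattern the new has_flashinfer_cubin() helper follows is: honor the environment override first, then probe for the installed package with importlib.util.find_spec so nothing is imported as a side effect. The env-var parsing below is an assumption for the sketch; vLLM reads the flag through its envs module as envs.VLLM_HAS_FLASHINFER_CUBIN.

import importlib.util
import os

def package_available(package: str, env_override: str) -> bool:
    # Environment override wins (assumed "1"/"0" semantics for illustration).
    if os.environ.get(env_override) == "1":
        return True
    # find_spec probes for the package without importing it, so the check
    # itself cannot trigger CUDA initialization or other import side effects.
    return importlib.util.find_spec(package) is not None

# Example, using the names from the diff:
print(package_available("flashinfer_cubin", "VLLM_HAS_FLASHINFER_CUBIN"))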
@@ -183,9 +194,8 @@ def has_nvidia_artifactory() -> bool:
     This checks connectivity to the kernel inference library artifactory
     which is required for downloading certain cubin kernels like TRTLLM FHMA.
     """
-    # Since FLASHINFER_CUBIN_DIR defines the pre-downloaded cubins path, when
-    # it's true, we could assume the cubins are available.
-    if envs.VLLM_HAS_FLASHINFER_CUBIN:
+    # If we have pre-downloaded cubins, we can assume the cubins are available.
+    if has_flashinfer_cubin():
         return True
 
     try:
@@ -208,9 +218,13 @@ def has_nvidia_artifactory() -> bool:
 @functools.cache
 def supports_trtllm_attention() -> bool:
     """
-    TRTLLM attention is supported if the platform is SM100 and
-    NVIDIA artifactory is accessible
+    TRTLLM attention is supported if the platform is SM100,
+    NVIDIA artifactory is accessible, and batch-invariant mode is not enabled.
     """
+    # Batch-invariant mode disables TRTLLM attention
+    if vllm_is_batch_invariant():
+        return False
+
     # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
     return current_platform.is_device_capability(100) and has_nvidia_artifactory()
 
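The net effect of this hunk and the next one is that the batch-invariant check moves from force_use_trtllm_attention() into supports_trtllm_attention(). A hedged sketch of the resulting gate, written as a pure function with plain boolean inputs standing in for vLLM's real platform probes:

def trtllm_attention_supported(
    is_sm100: bool,               # stands in for current_platform.is_device_capability(100)
    artifactory_reachable: bool,  # stands in for has_nvidia_artifactory()
    batch_invariant: bool,        # stands in for vllm_is_batch_invariant()
) -> bool:
    # Batch-invariant mode disables TRTLLM attention outright.
    if batch_invariant:
        return False
    # Otherwise SM100 plus a reachable artifactory (to download cubins) is required.
    return is_sm100 and artifactory_reachable

assert trtllm_attention_supported(True, True, False)
assert not trtllm_attention_supported(True, True, True)
assert not trtllm_attention_supported(False, True, False)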
@@ -229,9 +243,6 @@ def force_use_trtllm_attention() -> bool | None:
     return `True` if TRTLLM attention is forced to be used,
     return `False` if TRTLLM attention is forced to be not used.
     """
-    if vllm_is_batch_invariant():
-        logger.info_once("VLLM_USE_TRTLLM_ATTENTION is disabled for batch-invariant")
-        return False
     return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)
 
 
@@ -229,6 +229,21 @@ class FlashInferBackend(AttentionBackend):
             12, 1
         )
 
+    @classmethod
+    def supports_sink(cls) -> bool:
+        """FlashInfer supports sinks when TRTLLM attention is available (SM100)."""
+        from vllm.utils.flashinfer import (
+            force_use_trtllm_attention,
+            supports_trtllm_attention,
+        )
+
+        # Respect explicit disable flag (e.g., VLLM_USE_TRTLLM_ATTENTION=0)
+        if force_use_trtllm_attention() is False:
+            return False
+
+        # Check if TRTLLM is supported on this platform
+        return supports_trtllm_attention()
+
     @classmethod
     def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None:
         from vllm.platforms import current_platform
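To see how the new classmethod combines its two inputs, here is a hedged, standalone sketch of the decision (the function name and tri-state argument are illustrative, not vLLM API): an explicit VLLM_USE_TRTLLM_ATTENTION=0 opt-out wins, otherwise sink support tracks TRTLLM attention availability on SM100.

def backend_supports_sink(force_trtllm: bool | None, trtllm_supported: bool) -> bool:
    # Mirrors FlashInferBackend.supports_sink() from the hunk above:
    # force_trtllm stands in for the tri-state force_use_trtllm_attention()
    # (True / False / None), trtllm_supported for supports_trtllm_attention().
    if force_trtllm is False:  # explicit opt-out, e.g. VLLM_USE_TRTLLM_ATTENTION=0
        return False
    return trtllm_supported

print(backend_supports_sink(None, True))   # -> True
print(backend_supports_sink(False, True))  # -> False

Per the commit title, models that rely on attention sinks (such as gpt-oss) are gated on this check, which is why restoring a correct True on SM100 addresses the regression.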