[Bugfix] Fix SM100 gpt-oss regression due to faulty attn sink support (#28561)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin 2025-11-12 21:40:59 -05:00 committed by GitHub
parent 2dacd57394
commit a543e678b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 36 additions and 10 deletions

View File

@ -35,9 +35,20 @@ FLASHINFER_CUBINS_REPOSITORY = os.environ.get(
)
@functools.cache
def has_flashinfer_cubin() -> bool:
"""Return `True` if flashinfer-cubin package is available."""
if envs.VLLM_HAS_FLASHINFER_CUBIN:
return True
if importlib.util.find_spec("flashinfer_cubin") is not None:
return True
logger.debug_once("flashinfer-cubin package was not found")
return False
@functools.cache
def has_flashinfer() -> bool:
"""Return `True` if FlashInfer is available."""
"""Return `True` if flashinfer-python package is available."""
# Use find_spec to check if the module exists without importing it
# This avoids potential CUDA initialization side effects
if importlib.util.find_spec("flashinfer") is None:
@ -45,7 +56,7 @@ def has_flashinfer() -> bool:
return False
# When not using flashinfer cubin,
# Also check if nvcc is available since it's required to JIT compile flashinfer
if not envs.VLLM_HAS_FLASHINFER_CUBIN and shutil.which("nvcc") is None:
if not has_flashinfer_cubin() and shutil.which("nvcc") is None:
logger.debug_once(
"FlashInfer unavailable since nvcc was not found "
"and not using pre-downloaded cubins"
@ -183,9 +194,8 @@ def has_nvidia_artifactory() -> bool:
This checks connectivity to the kernel inference library artifactory
which is required for downloading certain cubin kernels like TRTLLM FHMA.
"""
# Since FLASHINFER_CUBIN_DIR defines the pre-downloaded cubins path, when
# it's true, we could assume the cubins are available.
if envs.VLLM_HAS_FLASHINFER_CUBIN:
# If we have pre-downloaded cubins, we can assume the cubins are available.
if has_flashinfer_cubin():
return True
try:
@ -208,9 +218,13 @@ def has_nvidia_artifactory() -> bool:
@functools.cache
def supports_trtllm_attention() -> bool:
"""
TRTLLM attention is supported if the platform is SM100 and
NVIDIA artifactory is accessible
TRTLLM attention is supported if the platform is SM100,
NVIDIA artifactory is accessible, and batch-invariant mode is not enabled.
"""
# Batch-invariant mode disables TRTLLM attention
if vllm_is_batch_invariant():
return False
# Requires SM100 and NVIDIA artifactory to be accessible to download cubins
return current_platform.is_device_capability(100) and has_nvidia_artifactory()
@ -229,9 +243,6 @@ def force_use_trtllm_attention() -> bool | None:
return `True` if TRTLLM attention is forced to be used,
return `False` if TRTLLM attention is forced to be not used.
"""
if vllm_is_batch_invariant():
logger.info_once("VLLM_USE_TRTLLM_ATTENTION is disabled for batch-invariant")
return False
return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)

View File

@ -229,6 +229,21 @@ class FlashInferBackend(AttentionBackend):
12, 1
)
@classmethod
def supports_sink(cls) -> bool:
"""FlashInfer supports sinks when TRTLLM attention is available (SM100)."""
from vllm.utils.flashinfer import (
force_use_trtllm_attention,
supports_trtllm_attention,
)
# Respect explicit disable flag (e.g., VLLM_USE_TRTLLM_ATTENTION=0)
if force_use_trtllm_attention() is False:
return False
# Check if TRTLLM is supported on this platform
return supports_trtllm_attention()
@classmethod
def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None:
from vllm.platforms import current_platform