[Bugfix] Fix SM100 gpt-oss regression due to faulty attn sink support (#28561)

Signed-off-by: mgoin <mgoin64@gmail.com>
Author: Michael Goin <mgoin64@gmail.com>, 2025-11-12 21:40:59 -05:00 (committed by GitHub)
parent 2dacd57394
commit a543e678b4
2 changed files with 36 additions and 10 deletions


@@ -35,9 +35,20 @@ FLASHINFER_CUBINS_REPOSITORY = os.environ.get(
 )


+@functools.cache
+def has_flashinfer_cubin() -> bool:
+    """Return `True` if flashinfer-cubin package is available."""
+    if envs.VLLM_HAS_FLASHINFER_CUBIN:
+        return True
+    if importlib.util.find_spec("flashinfer_cubin") is not None:
+        return True
+    logger.debug_once("flashinfer-cubin package was not found")
+    return False
+
+
 @functools.cache
 def has_flashinfer() -> bool:
-    """Return `True` if FlashInfer is available."""
+    """Return `True` if flashinfer-python package is available."""
     # Use find_spec to check if the module exists without importing it
     # This avoids potential CUDA initialization side effects
     if importlib.util.find_spec("flashinfer") is None:
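For context, the detection order implemented by the new helper (explicit environment override first, then the installed flashinfer-cubin wheel) can be reproduced in a standalone sketch. This version reads VLLM_HAS_FLASHINFER_CUBIN directly from the environment rather than through vllm.envs, and the helper name is illustrative, not vLLM's actual API:

# Standalone sketch (illustrative only): mirrors the detection order of
# has_flashinfer_cubin() above without depending on vllm.envs or its logger.
import importlib.util
import os


def _cubins_available_sketch() -> bool:
    # 1. An explicit opt-in via the environment variable wins.
    if os.environ.get("VLLM_HAS_FLASHINFER_CUBIN", "0").lower() in ("1", "true"):
        return True
    # 2. Otherwise, check whether the flashinfer-cubin package is installed.
    #    find_spec avoids importing the package and any side effects.
    return importlib.util.find_spec("flashinfer_cubin") is not None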
@@ -45,7 +56,7 @@ def has_flashinfer() -> bool:
         return False
     # When not using flashinfer cubin,
     # Also check if nvcc is available since it's required to JIT compile flashinfer
-    if not envs.VLLM_HAS_FLASHINFER_CUBIN and shutil.which("nvcc") is None:
+    if not has_flashinfer_cubin() and shutil.which("nvcc") is None:
         logger.debug_once(
             "FlashInfer unavailable since nvcc was not found "
             "and not using pre-downloaded cubins"
@@ -183,9 +194,8 @@ def has_nvidia_artifactory() -> bool:
     This checks connectivity to the kernel inference library artifactory
     which is required for downloading certain cubin kernels like TRTLLM FHMA.
     """
-    # Since FLASHINFER_CUBIN_DIR defines the pre-downloaded cubins path, when
-    # it's true, we could assume the cubins are available.
-    if envs.VLLM_HAS_FLASHINFER_CUBIN:
+    # If we have pre-downloaded cubins, we can assume the cubins are available.
+    if has_flashinfer_cubin():
         return True

     try:
@@ -208,9 +218,13 @@ def has_nvidia_artifactory() -> bool:
 @functools.cache
 def supports_trtllm_attention() -> bool:
     """
-    TRTLLM attention is supported if the platform is SM100 and
-    NVIDIA artifactory is accessible
+    TRTLLM attention is supported if the platform is SM100,
+    NVIDIA artifactory is accessible, and batch-invariant mode is not enabled.
     """
+    # Batch-invariant mode disables TRTLLM attention
+    if vllm_is_batch_invariant():
+        return False
+
     # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
     return current_platform.is_device_capability(100) and has_nvidia_artifactory()
@@ -229,9 +243,6 @@ def force_use_trtllm_attention() -> bool | None:
     return `True` if TRTLLM attention is forced to be used,
     return `False` if TRTLLM attention is forced to be not used.
     """
-    if vllm_is_batch_invariant():
-        logger.info_once("VLLM_USE_TRTLLM_ATTENTION is disabled for batch-invariant")
-        return False
     return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)
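With the batch-invariant check moved out of force_use_trtllm_attention(), that helper now only reflects VLLM_USE_TRTLLM_ATTENTION (True = forced on, False = forced off, None = unset), while supports_trtllm_attention() owns the platform and batch-invariance checks. A hedged sketch of how a caller might combine the two, mirroring the supports_sink() pattern added later in this commit (the function name here is illustrative, not a vLLM call site):

# Illustrative sketch only: combine the explicit user override with the
# capability check exposed by vllm.utils.flashinfer.
from vllm.utils.flashinfer import (
    force_use_trtllm_attention,
    supports_trtllm_attention,
)


def _should_use_trtllm_attention_sketch() -> bool:
    if force_use_trtllm_attention() is False:
        # Explicitly disabled (e.g., VLLM_USE_TRTLLM_ATTENTION=0).
        return False
    # Otherwise defer to the platform / batch-invariance capability check.
    return supports_trtllm_attention()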


@@ -229,6 +229,21 @@ class FlashInferBackend(AttentionBackend):
             12, 1
         )

+    @classmethod
+    def supports_sink(cls) -> bool:
+        """FlashInfer supports sinks when TRTLLM attention is available (SM100)."""
+        from vllm.utils.flashinfer import (
+            force_use_trtllm_attention,
+            supports_trtllm_attention,
+        )
+
+        # Respect explicit disable flag (e.g., VLLM_USE_TRTLLM_ATTENTION=0)
+        if force_use_trtllm_attention() is False:
+            return False
+
+        # Check if TRTLLM is supported on this platform
+        return supports_trtllm_attention()
+
     @classmethod
     def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None:
         from vllm.platforms import current_platform
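The effect of the new classmethod is that models using attention sinks (such as gpt-oss) can interrogate the backend class instead of assuming support. Below is a hedged sketch of such a gate; the helper name and signature are illustrative and do not correspond to vLLM's actual backend-selection code:

# Illustrative sketch only, not vLLM's backend-selection logic: a model that
# uses attention sinks should only run on a backend whose supports_sink()
# classmethod returns True.
def _backend_usable_for_sinks_sketch(backend_cls, uses_sinks: bool) -> bool:
    if not uses_sinks:
        # Models without attention sinks are unaffected by this check.
        return True
    # For FlashInferBackend this now returns True only when TRTLLM attention is
    # actually usable (SM100 + reachable artifactory or pre-downloaded cubins)
    # and has not been explicitly disabled via VLLM_USE_TRTLLM_ATTENTION=0.
    return backend_cls.supports_sink()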