From a543e678b45a08c6bd98a4e5ebcc244679003659 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Wed, 12 Nov 2025 21:40:59 -0500
Subject: [PATCH] [Bugfix] Fix SM100 gpt-oss regression due to faulty attn
 sink support (#28561)

Signed-off-by: mgoin
---
 vllm/utils/flashinfer.py                 | 31 ++++++++++++++++----------
 vllm/v1/attention/backends/flashinfer.py | 15 ++++++++++++
 2 files changed, 36 insertions(+), 10 deletions(-)

diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py
index 5101020fda12..62af39513d65 100644
--- a/vllm/utils/flashinfer.py
+++ b/vllm/utils/flashinfer.py
@@ -35,9 +35,20 @@ FLASHINFER_CUBINS_REPOSITORY = os.environ.get(
 )
 
 
+@functools.cache
+def has_flashinfer_cubin() -> bool:
+    """Return `True` if flashinfer-cubin package is available."""
+    if envs.VLLM_HAS_FLASHINFER_CUBIN:
+        return True
+    if importlib.util.find_spec("flashinfer_cubin") is not None:
+        return True
+    logger.debug_once("flashinfer-cubin package was not found")
+    return False
+
+
 @functools.cache
 def has_flashinfer() -> bool:
-    """Return `True` if FlashInfer is available."""
+    """Return `True` if flashinfer-python package is available."""
     # Use find_spec to check if the module exists without importing it
     # This avoids potential CUDA initialization side effects
     if importlib.util.find_spec("flashinfer") is None:
@@ -45,7 +56,7 @@ def has_flashinfer() -> bool:
         return False
     # When not using flashinfer cubin,
     # Also check if nvcc is available since it's required to JIT compile flashinfer
-    if not envs.VLLM_HAS_FLASHINFER_CUBIN and shutil.which("nvcc") is None:
+    if not has_flashinfer_cubin() and shutil.which("nvcc") is None:
         logger.debug_once(
             "FlashInfer unavailable since nvcc was not found "
             "and not using pre-downloaded cubins"
@@ -183,9 +194,8 @@ def has_nvidia_artifactory() -> bool:
     This checks connectivity to the kernel inference library artifactory
     which is required for downloading certain cubin kernels like TRTLLM FHMA.
     """
-    # Since FLASHINFER_CUBIN_DIR defines the pre-downloaded cubins path, when
-    # it's true, we could assume the cubins are available.
-    if envs.VLLM_HAS_FLASHINFER_CUBIN:
+    # If we have pre-downloaded cubins, we can assume the cubins are available.
+    if has_flashinfer_cubin():
         return True
 
     try:
@@ -208,9 +218,13 @@ def has_nvidia_artifactory() -> bool:
 @functools.cache
 def supports_trtllm_attention() -> bool:
     """
-    TRTLLM attention is supported if the platform is SM100 and
-    NVIDIA artifactory is accessible
+    TRTLLM attention is supported if the platform is SM100,
+    NVIDIA artifactory is accessible, and batch-invariant mode is not enabled.
     """
+    # Batch-invariant mode disables TRTLLM attention
+    if vllm_is_batch_invariant():
+        return False
+    # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
     return current_platform.is_device_capability(100) and has_nvidia_artifactory()
 
 
@@ -229,9 +243,6 @@ def force_use_trtllm_attention() -> bool | None:
     return `True` if TRTLLM attention is forced to be used,
     return `False` if TRTLLM attention is forced to be not used.
     """
-    if vllm_is_batch_invariant():
-        logger.info_once("VLLM_USE_TRTLLM_ATTENTION is disabled for batch-invariant")
-        return False
     return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)
 
 
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 1ce8e6f3d89f..0b650e2e0d33 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -229,6 +229,21 @@ class FlashInferBackend(AttentionBackend):
             12, 1
         )
 
+    @classmethod
+    def supports_sink(cls) -> bool:
+        """FlashInfer supports sinks when TRTLLM attention is available (SM100)."""
+        from vllm.utils.flashinfer import (
+            force_use_trtllm_attention,
+            supports_trtllm_attention,
+        )
+
+        # Respect explicit disable flag (e.g., VLLM_USE_TRTLLM_ATTENTION=0)
+        if force_use_trtllm_attention() is False:
+            return False
+
+        # Check if TRTLLM is supported on this platform
+        return supports_trtllm_attention()
+
     @classmethod
     def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None:
         from vllm.platforms import current_platform