From f81c1bb05504672ddd66905161c6ada549fd4b85 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Fri, 1 Aug 2025 08:28:45 -0400
Subject: [PATCH] [Bugfix] Check NVIDIA artifactory is accessible before using
 flashinfer cubin kernels (#21893)

---
 vllm/attention/backends/flashinfer.py    | 46 +-------------
 vllm/utils/flashinfer.py                 | 81 +++++++++++++++++++++++-
 vllm/v1/attention/backends/flashinfer.py | 49 +-------------
 vllm/v1/attention/backends/mla/common.py | 16 ++---
 4 files changed, 93 insertions(+), 99 deletions(-)

diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index 824ff8cca201a..b3372ce2eca8c 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -44,9 +44,9 @@ from vllm.attention.layer import Attention
 from vllm.attention.ops.paged_attn import PagedAttention
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
-from vllm.platforms import current_platform
 from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
                         make_tensor_with_pad)
+from vllm.utils.flashinfer import use_trtllm_decode_attention
 
 logger = init_logger(__name__)
 
@@ -56,7 +56,6 @@ if TYPE_CHECKING:
 
 
 class FlashInferBackend(AttentionBackend):
-    cached_sm100a_supported: Optional[bool] = None
 
     @staticmethod
     def get_name() -> str:
@@ -123,47 +122,6 @@ class FlashInferBackend(AttentionBackend):
         else:
             raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
 
-    @staticmethod
-    def use_trtllm_decode_attention(
-        batch_size: int,
-        max_seq_len: int,
-        kv_cache_dtype: str,
-        num_qo_heads: Optional[int],
-        num_kv_heads: Optional[int],
-        attn_head_size: Optional[int],
-    ) -> bool:
-        if FlashInferBackend.cached_sm100a_supported is None:
-            FlashInferBackend.cached_sm100a_supported = (
-                current_platform.has_device_capability(100))
-        if not FlashInferBackend.cached_sm100a_supported:
-            return False
-        # Check if the dimensions are supported by TRTLLM decode attention
-        if (attn_head_size is None or num_qo_heads is None
-                or num_kv_heads is None or num_qo_heads // num_kv_heads > 8
-                or num_qo_heads % num_kv_heads != 0 or attn_head_size != 128):
-            return False
-        env_value = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION
-        if env_value is not None:
-            logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s",
-                             env_value)
-            # Environment variable is set - respect it
-            # Making the conditional check for zero because
-            # the path is automatically enabled if the batch size condition
-            # is satisfied.
-            no_use_trtllm = (env_value == "0")
-            if not no_use_trtllm:
-                logger.info_once("Using TRTLLM decode attention.")
-            return not no_use_trtllm
-        else:
-            # Environment variable not set - use auto-detection
-            use_trtllm = (FlashInferBackend.cached_sm100a_supported
-                          and batch_size <= 256 and max_seq_len < 131072
-                          and kv_cache_dtype == "auto")
-            if use_trtllm:
-                logger.warning_once(
-                    "Using TRTLLM decode attention (auto-detected).")
-            return use_trtllm
-
 
 @dataclass
 class PerLayerParameters:
@@ -1156,7 +1114,7 @@ class FlashInferImpl(AttentionImpl):
             assert decode_meta.decode_wrapper._sm_scale == softmax_scale
             # TODO: @pavanimajety Remove this once the switch happens
             # inside flashinfer.
-            if not FlashInferBackend.use_trtllm_decode_attention(
+            if not use_trtllm_decode_attention(
                     num_decode_tokens, attn_metadata.max_decode_seq_len,
                     kv_cache_dtype, attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads, attn_metadata.head_dim):
diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py
index 3bfb9808c0a00..29967bc516715 100644
--- a/vllm/utils/flashinfer.py
+++ b/vllm/utils/flashinfer.py
@@ -10,12 +10,25 @@ import contextlib
 import functools
 import importlib
 import importlib.util
-from typing import Any, Callable, NoReturn
+import os
+from typing import Any, Callable, NoReturn, Optional
 
+import requests
+
+import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
 
+# This is the storage path for the cubins, it can be replaced
+# with a local path for testing.
+# Referenced from https://github.com/flashinfer-ai/flashinfer/blob/0c9a92c3d9a7e043ab6f3f7b2273269caf6ab044/flashinfer/jit/cubin_loader.py#L35  # noqa: E501
+FLASHINFER_CUBINS_REPOSITORY = os.environ.get(
+    "FLASHINFER_CUBINS_REPOSITORY",
+    "https://edge.urm.nvidia.com/artifactory/sw-kernelinferencelibrary-public-generic-local/",  # noqa: E501
+)
+
 
 @functools.cache
 def has_flashinfer() -> bool:
@@ -108,6 +121,70 @@ def has_flashinfer_cutlass_fused_moe() -> bool:
     return True
 
 
+@functools.cache
+def has_nvidia_artifactory() -> bool:
+    """Return ``True`` if NVIDIA's artifactory is accessible.
+
+    This checks connectivity to the kernel inference library artifactory
+    which is required for downloading certain cubin kernels like TRTLLM FHMA.
+    """
+    try:
+        # Use a short timeout to avoid blocking for too long
+        response = requests.get(FLASHINFER_CUBINS_REPOSITORY, timeout=5)
+        accessible = response.status_code == 200
+        if accessible:
+            logger.debug_once("NVIDIA artifactory is accessible")
+        else:
+            logger.warning_once(
+                "NVIDIA artifactory returned failed status code: %d",
+                response.status_code)
+        return accessible
+    except Exception as e:
+        logger.warning_once("Failed to connect to NVIDIA artifactory: %s", e)
+        return False
+
+
+def use_trtllm_decode_attention(
+    num_tokens: int,
+    max_seq_len: int,
+    kv_cache_dtype: str,
+    num_qo_heads: Optional[int],
+    num_kv_heads: Optional[int],
+    attn_head_size: Optional[int],
+) -> bool:
+    # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
+    if not (current_platform.is_device_capability(100)
+            and has_nvidia_artifactory()):
+        return False
+
+    # Check if the dimensions are supported by TRTLLM decode attention
+    if (attn_head_size is None or num_qo_heads is None or num_kv_heads is None
+            or num_qo_heads // num_kv_heads > 8
+            or num_qo_heads % num_kv_heads != 0 or attn_head_size != 128):
+        return False
+
+    env_value = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION
+    if env_value is not None:
+        logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s",
+                         env_value)
+        # Environment variable is set - respect it
+        # Making the conditional check for zero because
+        # the path is automatically enabled if the batch size condition
+        # is satisfied.
+        no_use_trtllm = (env_value == "0")
+        if not no_use_trtllm:
+            logger.info_once("Using TRTLLM decode attention.")
+        return not no_use_trtllm
+    else:
+        # Environment variable not set - use auto-detection
+        use_trtllm = (num_tokens <= 256 and max_seq_len < 131072
+                      and kv_cache_dtype == "auto")
+        if use_trtllm:
+            logger.warning_once(
+                "Using TRTLLM decode attention (auto-detected).")
+        return use_trtllm
+
+
 __all__ = [
     "has_flashinfer",
     "flashinfer_trtllm_fp8_block_scale_moe",
@@ -117,4 +194,6 @@ __all__ = [
     "autotune",
     "has_flashinfer_moe",
     "has_flashinfer_cutlass_fused_moe",
+    "has_nvidia_artifactory",
+    "use_trtllm_decode_attention",
 ]
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 27552f0e7c1ef..f8af1d7e41831 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -17,8 +17,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.platforms import current_platform
 from vllm.utils import cdiv
+from vllm.utils.flashinfer import use_trtllm_decode_attention
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
@@ -38,7 +38,6 @@ logger = init_logger(__name__)
 
 class FlashInferBackend(AttentionBackend):
     accept_output_buffer: bool = True
-    cached_sm100a_supported: Optional[bool] = None
 
     @classmethod
     def get_supported_dtypes(cls) -> list[torch.dtype]:
@@ -98,48 +97,6 @@ class FlashInferBackend(AttentionBackend):
             raise ValueError(f"Unknown cache layout format {cache_layout}.")
         return stride_order
 
-    @staticmethod
-    def use_trtllm_decode_attention(
-        batch_size: int,
-        max_seq_len: int,
-        kv_cache_dtype: str,
-        num_qo_heads: int,
-        num_kv_heads: int,
-        attn_head_size: int,
-    ) -> bool:
-        if FlashInferBackend.cached_sm100a_supported is None:
-            FlashInferBackend.cached_sm100a_supported = (
-                current_platform.has_device_capability(100))
-        if not FlashInferBackend.cached_sm100a_supported:
-            return False
-        if (num_qo_heads // num_kv_heads > 8
-                or num_qo_heads % num_kv_heads != 0 or attn_head_size != 128):
-            return False
-        env_value = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION
-        if env_value is not None:
-            logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s",
-                             env_value)
-            # Environment variable is set - respect it
-            # Making the conditional check for zero because
-            # the path is automatically enabled if the batch size condition
-            # is satisfied.
-            no_use_trtllm = env_value == "0"
-            if not no_use_trtllm:
-                logger.info_once(
-                    "VLLM_USE_TRTLLM_DECODE_ATTENTION is set to 1, "
-                    "using TRTLLM decode attention.")
-            return not no_use_trtllm
-        else:
-            # Environment variable not set - use auto-detection
-            # Only supports attention head size of 128
-            use_trtllm = (FlashInferBackend.cached_sm100a_supported
-                          and batch_size <= 256 and max_seq_len < 131072
-                          and kv_cache_dtype == "auto")
-            if use_trtllm:
-                logger.warning_once(
-                    "Using TRTLLM decode attention (auto-detected).")
-            return use_trtllm
-
     @staticmethod
     def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
         if kv_cache_dtype in ("fp8", "fp8_e4m3"):
@@ -352,7 +309,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
 
         if num_decodes > 0:
             attn_metadata.decode_wrapper = self._get_decode_wrapper()
-            if not FlashInferBackend.use_trtllm_decode_attention(
+            if not use_trtllm_decode_attention(
                     num_decodes, attn_metadata.max_seq_len,
                     self.cache_config.cache_dtype, attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads, attn_metadata.head_dim):
@@ -636,7 +593,7 @@ class FlashInferImpl(AttentionImpl):
             decode_query = query[:num_decode_tokens]
             assert decode_query.shape[0] == num_decode_tokens
             assert decode_wrapper is not None
-            if not FlashInferBackend.use_trtllm_decode_attention(
+            if not use_trtllm_decode_attention(
                     attn_metadata.num_decodes, attn_metadata.max_seq_len,
                     self.kv_cache_dtype, attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads, attn_metadata.head_dim):
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 0095d75217856..d112468f1c91d 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -209,6 +209,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                UnquantizedLinearMethod)
 from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
+from vllm.utils.flashinfer import has_nvidia_artifactory
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder, CommonAttentionMetadata,
     get_per_layer_parameters, infer_global_hyperparameters,
@@ -379,17 +380,16 @@ M = TypeVar("M", bound=MLACommonMetadata)
 
 
 def use_flashinfer_prefill() -> bool:
-    if flashinfer_available and not envs.VLLM_USE_CUDNN_PREFILL:
-        # For blackwell default to flashinfer prefill if its available since
-        # its faster than FA2.
-        return current_platform.has_device_capability(100)
-    return False
+    # For blackwell default to flashinfer prefill if its available since
+    # it is faster than FA2.
+    return (flashinfer_available and not envs.VLLM_USE_CUDNN_PREFILL
+            and current_platform.is_device_capability(100))
 
 
 def use_cudnn_prefill() -> bool:
-    if flashinfer_available and envs.VLLM_USE_CUDNN_PREFILL:
-        return current_platform.has_device_capability(100)
-    return False
+    return (flashinfer_available and envs.VLLM_USE_CUDNN_PREFILL
+            and current_platform.is_device_capability(100)
+            and has_nvidia_artifactory())
 
 
 # Currently 394MB, this can be tuned based on GEMM sizes used.
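
The connectivity gate this patch introduces boils down to a cached HTTP probe of FLASHINFER_CUBINS_REPOSITORY. The standalone sketch below mirrors that logic outside of vLLM so it can be exercised in isolation; the function name artifactory_reachable and the timeout_s parameter are illustrative stand-ins, not part of the patch, and only the third-party requests package is assumed.

# Standalone sketch of the reachability probe added by this patch.
# `artifactory_reachable` is an illustrative name; the patched helper
# in vllm/utils/flashinfer.py is `has_nvidia_artifactory`.
import functools
import os

import requests

# Same default repository URL the patch reads, overridable for testing.
FLASHINFER_CUBINS_REPOSITORY = os.environ.get(
    "FLASHINFER_CUBINS_REPOSITORY",
    "https://edge.urm.nvidia.com/artifactory/sw-kernelinferencelibrary-public-generic-local/",
)


@functools.cache
def artifactory_reachable(timeout_s: float = 5.0) -> bool:
    """Return True if the cubin repository answers with HTTP 200."""
    try:
        # Short timeout so an offline machine does not block for long.
        response = requests.get(FLASHINFER_CUBINS_REPOSITORY,
                                timeout=timeout_s)
        return response.status_code == 200
    except requests.RequestException:
        # Any network failure means the cubins cannot be downloaded.
        return False


if __name__ == "__main__":
    print("cubin repository reachable:", artifactory_reachable())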
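
A hypothetical caller-side check of the relocated helper, assuming a vLLM build with this patch applied. The argument values are placeholders chosen to satisfy the constraints the helper enforces (head size 128, GQA ratio at most 8); at runtime the TRTLLM path additionally requires an SM100 GPU and a reachable artifactory, so on other machines this simply reports False.

# Hypothetical usage sketch; requires vLLM with this patch installed.
from vllm.utils.flashinfer import (has_nvidia_artifactory,
                                   use_trtllm_decode_attention)

print("artifactory reachable:", has_nvidia_artifactory())
print("TRTLLM decode attention:",
      use_trtllm_decode_attention(
          num_tokens=64,         # decode batch size
          max_seq_len=8192,
          kv_cache_dtype="auto",
          num_qo_heads=32,
          num_kv_heads=8,        # GQA ratio 4, within the <= 8 limit
          attn_head_size=128))   # only head size 128 is supported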