From 4fa7ce46f31cbd97b4651694caf9991cc395a259 Mon Sep 17 00:00:00 2001 From: "Roberto L. Castro" <38211239+LopezCastroRoberto@users.noreply.github.com> Date: Sat, 13 Dec 2025 04:34:23 +0100 Subject: [PATCH] [Feature] Add SM103 (Blackwell Ultra) Support to vLLM (#30484) Signed-off-by: LopezCastroRoberto Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com> Co-authored-by: youkaichao --- tests/compile/distributed/test_fusions_e2e.py | 2 +- .../kernels/attention/test_cutlass_mla_decode.py | 4 ++-- .../attention/test_flashinfer_trtllm_attention.py | 4 ++-- tests/kernels/moe/test_ocp_mx_moe.py | 4 ++-- tests/quantization/test_blackwell_moe.py | 4 ++-- vllm/model_executor/layers/batch_invariant.py | 2 +- .../layers/fused_moe/batched_deep_gemm_moe.py | 5 ++++- vllm/model_executor/layers/quantization/fp8.py | 6 +++--- vllm/model_executor/layers/quantization/mxfp4.py | 8 ++++---- .../quantization/utils/flashinfer_fp4_moe.py | 2 +- .../layers/quantization/utils/flashinfer_utils.py | 2 +- .../layers/quantization/utils/fp8_utils.py | 2 +- .../layers/quantization/utils/mxfp4_utils.py | 2 +- vllm/model_executor/models/config.py | 2 +- vllm/platforms/cuda.py | 2 +- vllm/platforms/interface.py | 15 +++++++++++++++ vllm/utils/deep_gemm.py | 4 ++-- vllm/utils/flashinfer.py | 4 +++- vllm/v1/attention/backends/flashinfer.py | 2 +- vllm/v1/attention/backends/mla/common.py | 6 +++--- vllm/v1/attention/backends/mla/flashmla_sparse.py | 4 ++-- 21 files changed, 53 insertions(+), 33 deletions(-) diff --git a/tests/compile/distributed/test_fusions_e2e.py b/tests/compile/distributed/test_fusions_e2e.py index 5379b5157b811..1fcafe1840cd3 100644 --- a/tests/compile/distributed/test_fusions_e2e.py +++ b/tests/compile/distributed/test_fusions_e2e.py @@ -20,7 +20,7 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer from ...utils import flat_product, multi_gpu_test -is_blackwell = lambda: current_platform.is_device_capability(100) +is_blackwell = lambda: current_platform.is_device_capability_family(100) """Are we running on Blackwell, a lot of tests depend on it""" diff --git a/tests/kernels/attention/test_cutlass_mla_decode.py b/tests/kernels/attention/test_cutlass_mla_decode.py index a60f4e385a893..784c16304a286 100644 --- a/tests/kernels/attention/test_cutlass_mla_decode.py +++ b/tests/kernels/attention/test_cutlass_mla_decode.py @@ -32,8 +32,8 @@ def cal_diff( CUTLASS_MLA_UNSUPPORTED_REASON = ( - "Cutlass MLA Requires compute capability of 10 or above." - if not current_platform.is_device_capability(100) + "Cutlass MLA Requires compute capability of 100 or above." + if not current_platform.is_device_capability_family(100) else "Cutlass MLA is supported" ) diff --git a/tests/kernels/attention/test_flashinfer_trtllm_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_attention.py index 98ea40608b468..06a7085a82ba0 100644 --- a/tests/kernels/attention/test_flashinfer_trtllm_attention.py +++ b/tests/kernels/attention/test_flashinfer_trtllm_attention.py @@ -11,7 +11,7 @@ from tests.kernels.quantization.nvfp4_utils import ( from vllm.platforms import current_platform from vllm.utils.math_utils import round_up -if not current_platform.is_device_capability(100): +if not current_platform.is_device_capability_family(100): pytest.skip( "This TRTLLM kernel requires NVIDIA Blackwell.", allow_module_level=True ) @@ -443,7 +443,7 @@ def test_flashinfer_trtllm_prefill_with_baseline( output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2]) if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE: - rtol, atol = 1e-1, 2e-1 + rtol, atol = 3e-1, 4e-1 elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: rtol, atol = 4e-2, 6e-2 elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype: diff --git a/tests/kernels/moe/test_ocp_mx_moe.py b/tests/kernels/moe/test_ocp_mx_moe.py index 5a850dda4f6fd..8fe471d124f43 100644 --- a/tests/kernels/moe/test_ocp_mx_moe.py +++ b/tests/kernels/moe/test_ocp_mx_moe.py @@ -17,7 +17,7 @@ QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse( ) >= version.parse("0.8.99") TRTLLM_GEN_MXFP4_AVAILABLE = ( - current_platform.is_cuda() and current_platform.is_device_capability(100) + current_platform.is_cuda() and current_platform.is_device_capability_family(100) ) HOPPER_MXFP4_BF16_AVAILABLE = ( @@ -799,7 +799,7 @@ def test_flashinfer_cutlass_mxfp4_fused_moe( @pytest.mark.skipif( not ( current_platform.is_cuda() - and current_platform.is_device_capability(100) + and current_platform.is_device_capability_family(100) and has_flashinfer() ), reason="NVIDIA GPU sm100 and flashinfer are required for this test", diff --git a/tests/quantization/test_blackwell_moe.py b/tests/quantization/test_blackwell_moe.py index 8dd4551ff4b96..a43d2abfdd8b8 100644 --- a/tests/quantization/test_blackwell_moe.py +++ b/tests/quantization/test_blackwell_moe.py @@ -10,9 +10,9 @@ import pytest from tests.utils import RemoteOpenAIServer from vllm.platforms import current_platform -if not current_platform.is_device_capability(100): +if not current_platform.is_device_capability_family(100): pytest.skip( - "This test only runs on Blackwell GPUs (SM100).", allow_module_level=True + "This test only runs on Blackwell GPUs (SM10x).", allow_module_level=True ) diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py index b14e7dad77f9a..4f31e5afa1ac9 100644 --- a/vllm/model_executor/layers/batch_invariant.py +++ b/vllm/model_executor/layers/batch_invariant.py @@ -936,7 +936,7 @@ def enable_batch_invariant_mode(): # Batch invariant matmuls are no longer needed after cublas overrides if not is_torch_equal_or_newer("2.10.0.dev"): if ( - current_platform.is_device_capability(100) + current_platform.is_device_capability_family(100) or current_platform.is_device_capability(80) or current_platform.is_device_capability(89) ): diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 53362277dae8a..15f6e3a18ed6c 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -287,7 +287,10 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): """ DeepGemm supports packed ue8m0 activation scales format in devices == sm100 """ - return is_deep_gemm_e8m0_used() and current_platform.is_device_capability(100) + return ( + is_deep_gemm_e8m0_used() + and current_platform.is_device_capability_family(100) + ) def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: # Let PrepareAndFinalize::finalize() decide the impl. diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 60dde9eb57e0f..6909bac1efc7c 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -137,7 +137,7 @@ def get_fp8_moe_backend( if ( current_platform.is_cuda() and ( - current_platform.is_device_capability(100) + current_platform.is_device_capability_family(100) or current_platform.is_device_capability(90) ) and envs.VLLM_USE_FLASHINFER_MOE_FP8 @@ -148,7 +148,7 @@ def get_fp8_moe_backend( logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100") return Fp8MoeBackend.FLASHINFER_TRTLLM else: - if block_quant and current_platform.is_device_capability(100): + if block_quant and current_platform.is_device_capability_family(100): raise ValueError( "FlashInfer FP8 MoE throughput backend does not " "support block quantization. Please use " @@ -193,7 +193,7 @@ def get_fp8_moe_backend( # CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights if ( current_platform.is_cuda() - and current_platform.is_device_capability(100) + and current_platform.is_device_capability_family(100) and block_quant ): logger.info_once( diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 6eae4e9e66e1b..e96e87d15787d 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -118,19 +118,19 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend: logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90") return Mxfp4Backend.SM90_FI_MXFP4_BF16 elif ( - current_platform.is_device_capability(100) + current_platform.is_device_capability_family(100) and has_flashinfer() and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS ): logger.info_once("Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100") return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS elif ( - current_platform.is_device_capability(100) + current_platform.is_device_capability_family(100) and has_flashinfer() and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 ): return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM - elif current_platform.is_device_capability(100) and has_flashinfer(): + elif current_platform.is_device_capability_family(100) and has_flashinfer(): logger.info_once( "Using FlashInfer MXFP4 BF16 backend for SM100, " "For faster performance on SM100, consider setting " @@ -139,7 +139,7 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend: ) return Mxfp4Backend.SM100_FI_MXFP4_BF16 elif ( - current_platform.is_device_capability(100) + current_platform.is_device_capability_family(100) or current_platform.is_device_capability(90) ) and not has_flashinfer(): logger.warning_once( diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index 8f96222f19f20..e424cd0e1ac99 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -50,7 +50,7 @@ def is_flashinfer_fp4_cutedsl_moe_available() -> bool: envs.VLLM_USE_FLASHINFER_MOE_FP4 and has_flashinfer_cutedsl_grouped_gemm_nt_masked() and current_platform.is_cuda() - and current_platform.is_device_capability(100) + and current_platform.is_device_capability_family(100) ) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index ba3653e4b5ea7..09d0fe6a2f3ad 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -290,7 +290,7 @@ def get_flashinfer_moe_backend() -> FlashinferMoeBackend: if flashinfer_moe_backend in backend_map: if ( flashinfer_moe_backend == "latency" - and not current_platform.has_device_capability(100) + and not current_platform.is_device_capability_family(100) ): logger.info_once( "Flashinfer TRTLLM MOE backend is only supported on " diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 9eeb6e266c34e..ea68745585160 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -247,7 +247,7 @@ class W8A8BlockFp8LinearOp: self.act_quant_group_shape = act_quant_group_shape self.is_deep_gemm_supported = is_deep_gemm_supported() self.is_hopper = current_platform.is_device_capability(90) - self.is_blackwell = current_platform.is_device_capability(100) + self.is_blackwell = current_platform.is_device_capability_family(100) self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used() # Get the correct blockscale mul and input quant operations. diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index 7a351afb3c415..e9ecf0547033d 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -63,7 +63,7 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps): "split_k": 1, } opt_flags.update_opt_flags_constraints(constraints) - elif current_platform.is_device_capability(100): + elif current_platform.is_device_capability_family(100): constraints = { "is_persistent": True, "epilogue_subtile": 1, diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 06cc92ee88180..4b08472538db4 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -363,7 +363,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): else: kernel_block_alignment_size = 16 if ( - current_platform.is_device_capability(100) + current_platform.is_device_capability_family(100) and model_config.get_head_size() == 256 and ( attention_config.backend is None diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 915392a4125f9..ef33e64bbfdf4 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -182,7 +182,7 @@ class CudaPlatformBase(Platform): if vllm_config.attention_config.backend is None: # Default case - if cls.is_device_capability(100) and not use_sparse: + if cls.is_device_capability_family(100) and not use_sparse: # Blackwell => Force CutlassMLA (unless sparse, i.e. DSv3.2). use_cutlass_mla = True # Set the backend in AttentionConfig so it's used during diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index f04e94e425257..49437c7d56d12 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -301,6 +301,21 @@ class Platform: return current_capability.to_int() == capability + @classmethod + def is_device_capability_family( + cls, + capability: int, + device_id: int = 0, + ) -> bool: + """ + Returns True if the device capability is any .x. + Mirrors CUDA 13 'family' architecture semantics (e.g. 10.x, 11.x, 12.x). + """ + current_capability = cls.get_device_capability(device_id=device_id) + if current_capability is None: + return False + return (current_capability.to_int() // 10) == (capability // 10) + @classmethod def get_device_name(cls, device_id: int = 0) -> str: """Get the name of a device.""" diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index a099fde1bdc45..46be3e2cd5c54 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -38,7 +38,7 @@ class DeepGemmQuantScaleFMT(Enum): return DeepGemmQuantScaleFMT.FLOAT32 return ( DeepGemmQuantScaleFMT.UE8M0 - if current_platform.is_device_capability(100) + if current_platform.is_device_capability_family(100) else DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0 ) @@ -50,7 +50,7 @@ def is_deep_gemm_supported() -> bool: """ is_supported_arch = current_platform.is_cuda() and ( current_platform.is_device_capability(90) - or current_platform.is_device_capability(100) + or current_platform.is_device_capability_family(100) ) return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 9a66049350cd8..5019b771f4a14 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -264,7 +264,9 @@ def supports_trtllm_attention() -> bool: return False # Requires SM100 and NVIDIA artifactory to be accessible to download cubins - return current_platform.is_device_capability(100) and has_nvidia_artifactory() + return ( + current_platform.is_device_capability_family(100) and has_nvidia_artifactory() + ) def force_use_trtllm_attention() -> bool | None: diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 4174b80ee312e..2740a6916fd97 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -564,7 +564,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ) self.paged_kv_last_page_len_np = self.paged_kv_last_page_len_cpu.numpy() - if self.head_dim == 256 and current_platform.is_device_capability(100): + if self.head_dim == 256 and current_platform.is_device_capability_family(100): # https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that # head size 256 and block size 16 is not supported on blackwell. assert kv_cache_spec.block_size != 16, ( diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 8265503c28c35..fea482493635f 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -446,7 +446,7 @@ def use_flashinfer_prefill() -> bool: and flashinfer_available and not vllm_config.attention_config.use_cudnn_prefill and not vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill - and current_platform.is_device_capability(100) + and current_platform.is_device_capability_family(100) ) @@ -457,7 +457,7 @@ def use_cudnn_prefill() -> bool: return ( flashinfer_available and vllm_config.attention_config.use_cudnn_prefill - and current_platform.is_device_capability(100) + and current_platform.is_device_capability_family(100) and has_nvidia_artifactory() ) @@ -470,7 +470,7 @@ def use_trtllm_ragged_deepseek_prefill() -> bool: return ( flashinfer_available and vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill - and current_platform.is_device_capability(100) + and current_platform.is_device_capability_family(100) ) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index f3052fbaf2a65..0818078da0364 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -420,7 +420,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad max_num_sm_parts = int( max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1) ) - if current_platform.is_device_capability(100): + if current_platform.is_device_capability_family(100): max_num_sm_parts *= 2 self.tile_scheduler_metadata_buffer = torch.empty( # TileSchedulerMetaDataSize = 8 @@ -719,7 +719,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): self.softmax_scale = scale assert indexer is not None self.topk_indices_buffer = indexer.topk_indices_buffer - self.padding = 128 if current_platform.is_device_capability(100) else 64 + self.padding = 128 if current_platform.is_device_capability_family(100) else 64 if kv_cache_dtype == "fp8_ds_mla": # Reserve workspace during initialization