[Feature] Add SM103 (Blackwell Ultra) Support to vLLM (#30484)
Signed-off-by: LopezCastroRoberto <robertol.c510@gmail.com>
Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
parent 57e9bf1864
commit 4fa7ce46f3
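Throughout the diff, call sites that previously required an exact SM100 capability match switch to the new family check, so SM103 (Blackwell Ultra) follows the same code paths as SM100. A minimal sketch of the two matching rules on packed capability integers (major * 10 + minor); the helper names below are illustrative, not vLLM APIs:

# Illustrative only: contrast the exact check replaced in this diff with the
# family-style check that replaces it.
def exact_match(device_capability: int, target: int) -> bool:
    # Old behavior: SM103 (103) is not equal to SM100 (100).
    return device_capability == target


def family_match(device_capability: int, target: int) -> bool:
    # New behavior: any capability sharing the major version (10.x) matches.
    return device_capability // 10 == target // 10


for cap in (90, 100, 103):  # Hopper, Blackwell, Blackwell Ultra
    print(cap, exact_match(cap, 100), family_match(cap, 100))
# 90  False False
# 100 True  True
# 103 False True   <- SM103 now takes the SM100 code paths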
@@ -20,7 +20,7 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer
 
 from ...utils import flat_product, multi_gpu_test
 
-is_blackwell = lambda: current_platform.is_device_capability(100)
+is_blackwell = lambda: current_platform.is_device_capability_family(100)
 """Are we running on Blackwell, a lot of tests depend on it"""
@@ -32,8 +32,8 @@ def cal_diff(
 CUTLASS_MLA_UNSUPPORTED_REASON = (
-    "Cutlass MLA Requires compute capability of 10 or above."
-    if not current_platform.is_device_capability(100)
+    "Cutlass MLA Requires compute capability of 100 or above."
+    if not current_platform.is_device_capability_family(100)
     else "Cutlass MLA is supported"
 )
@@ -11,7 +11,7 @@ from tests.kernels.quantization.nvfp4_utils import (
 from vllm.platforms import current_platform
 from vllm.utils.math_utils import round_up
 
-if not current_platform.is_device_capability(100):
+if not current_platform.is_device_capability_family(100):
     pytest.skip(
         "This TRTLLM kernel requires NVIDIA Blackwell.", allow_module_level=True
     )
@@ -443,7 +443,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
     output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])
 
     if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
-        rtol, atol = 1e-1, 2e-1
+        rtol, atol = 3e-1, 4e-1
     elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
         rtol, atol = 4e-2, 6e-2
     elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
@@ -17,7 +17,7 @@ QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
 ) >= version.parse("0.8.99")
 
 TRTLLM_GEN_MXFP4_AVAILABLE = (
-    current_platform.is_cuda() and current_platform.is_device_capability(100)
+    current_platform.is_cuda() and current_platform.is_device_capability_family(100)
 )
 
 HOPPER_MXFP4_BF16_AVAILABLE = (
@@ -799,7 +799,7 @@ def test_flashinfer_cutlass_mxfp4_fused_moe(
 @pytest.mark.skipif(
     not (
         current_platform.is_cuda()
-        and current_platform.is_device_capability(100)
+        and current_platform.is_device_capability_family(100)
         and has_flashinfer()
     ),
     reason="NVIDIA GPU sm100 and flashinfer are required for this test",
@@ -10,9 +10,9 @@ import pytest
 from tests.utils import RemoteOpenAIServer
 from vllm.platforms import current_platform
 
-if not current_platform.is_device_capability(100):
+if not current_platform.is_device_capability_family(100):
     pytest.skip(
-        "This test only runs on Blackwell GPUs (SM100).", allow_module_level=True
+        "This test only runs on Blackwell GPUs (SM10x).", allow_module_level=True
     )
@@ -936,7 +936,7 @@ def enable_batch_invariant_mode():
     # Batch invariant matmuls are no longer needed after cublas overrides
     if not is_torch_equal_or_newer("2.10.0.dev"):
         if (
-            current_platform.is_device_capability(100)
+            current_platform.is_device_capability_family(100)
             or current_platform.is_device_capability(80)
             or current_platform.is_device_capability(89)
         ):
@@ -287,7 +287,10 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         """
         DeepGemm supports packed ue8m0 activation scales format in devices == sm100
         """
-        return is_deep_gemm_e8m0_used() and current_platform.is_device_capability(100)
+        return (
+            is_deep_gemm_e8m0_used()
+            and current_platform.is_device_capability_family(100)
+        )
 
     def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
         # Let PrepareAndFinalize::finalize() decide the impl.
@@ -137,7 +137,7 @@ def get_fp8_moe_backend(
     if (
         current_platform.is_cuda()
         and (
-            current_platform.is_device_capability(100)
+            current_platform.is_device_capability_family(100)
            or current_platform.is_device_capability(90)
         )
         and envs.VLLM_USE_FLASHINFER_MOE_FP8
@@ -148,7 +148,7 @@ def get_fp8_moe_backend(
             logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
             return Fp8MoeBackend.FLASHINFER_TRTLLM
         else:
-            if block_quant and current_platform.is_device_capability(100):
+            if block_quant and current_platform.is_device_capability_family(100):
                 raise ValueError(
                     "FlashInfer FP8 MoE throughput backend does not "
                     "support block quantization. Please use "
@@ -193,7 +193,7 @@ def get_fp8_moe_backend(
     # CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights
     if (
         current_platform.is_cuda()
-        and current_platform.is_device_capability(100)
+        and current_platform.is_device_capability_family(100)
         and block_quant
     ):
         logger.info_once(
@@ -118,19 +118,19 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
         logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90")
         return Mxfp4Backend.SM90_FI_MXFP4_BF16
     elif (
-        current_platform.is_device_capability(100)
+        current_platform.is_device_capability_family(100)
         and has_flashinfer()
         and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS
     ):
         logger.info_once("Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100")
         return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
     elif (
-        current_platform.is_device_capability(100)
+        current_platform.is_device_capability_family(100)
         and has_flashinfer()
         and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
     ):
         return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
-    elif current_platform.is_device_capability(100) and has_flashinfer():
+    elif current_platform.is_device_capability_family(100) and has_flashinfer():
         logger.info_once(
             "Using FlashInfer MXFP4 BF16 backend for SM100, "
             "For faster performance on SM100, consider setting "
@@ -139,7 +139,7 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
         )
         return Mxfp4Backend.SM100_FI_MXFP4_BF16
     elif (
-        current_platform.is_device_capability(100)
+        current_platform.is_device_capability_family(100)
         or current_platform.is_device_capability(90)
     ) and not has_flashinfer():
         logger.warning_once(
@@ -50,7 +50,7 @@ def is_flashinfer_fp4_cutedsl_moe_available() -> bool:
         envs.VLLM_USE_FLASHINFER_MOE_FP4
         and has_flashinfer_cutedsl_grouped_gemm_nt_masked()
         and current_platform.is_cuda()
-        and current_platform.is_device_capability(100)
+        and current_platform.is_device_capability_family(100)
     )
@@ -290,7 +290,7 @@ def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
     if flashinfer_moe_backend in backend_map:
         if (
             flashinfer_moe_backend == "latency"
-            and not current_platform.has_device_capability(100)
+            and not current_platform.is_device_capability_family(100)
         ):
             logger.info_once(
                 "Flashinfer TRTLLM MOE backend is only supported on "
@@ -247,7 +247,7 @@ class W8A8BlockFp8LinearOp:
         self.act_quant_group_shape = act_quant_group_shape
         self.is_deep_gemm_supported = is_deep_gemm_supported()
         self.is_hopper = current_platform.is_device_capability(90)
-        self.is_blackwell = current_platform.is_device_capability(100)
+        self.is_blackwell = current_platform.is_device_capability_family(100)
         self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()
 
         # Get the correct blockscale mul and input quant operations.
@@ -63,7 +63,7 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
             "split_k": 1,
         }
         opt_flags.update_opt_flags_constraints(constraints)
-    elif current_platform.is_device_capability(100):
+    elif current_platform.is_device_capability_family(100):
         constraints = {
             "is_persistent": True,
             "epilogue_subtile": 1,
@@ -363,7 +363,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
         else:
             kernel_block_alignment_size = 16
         if (
-            current_platform.is_device_capability(100)
+            current_platform.is_device_capability_family(100)
             and model_config.get_head_size() == 256
             and (
                 attention_config.backend is None
@@ -182,7 +182,7 @@ class CudaPlatformBase(Platform):
 
         if vllm_config.attention_config.backend is None:
             # Default case
-            if cls.is_device_capability(100) and not use_sparse:
+            if cls.is_device_capability_family(100) and not use_sparse:
                 # Blackwell => Force CutlassMLA (unless sparse, i.e. DSv3.2).
                 use_cutlass_mla = True
                 # Set the backend in AttentionConfig so it's used during
@@ -301,6 +301,21 @@ class Platform:
 
         return current_capability.to_int() == capability
 
+    @classmethod
+    def is_device_capability_family(
+        cls,
+        capability: int,
+        device_id: int = 0,
+    ) -> bool:
+        """
+        Returns True if the device capability is any <major>.x.
+        Mirrors CUDA 13 'family' architecture semantics (e.g. 10.x, 11.x, 12.x).
+        """
+        current_capability = cls.get_device_capability(device_id=device_id)
+        if current_capability is None:
+            return False
+        return (current_capability.to_int() // 10) == (capability // 10)
+
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
         """Get the name of a device."""
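A self-contained sketch of the semantics of the `is_device_capability_family` classmethod added above. `DeviceCapability` here is a stand-in that assumes the same major * 10 + minor packing as `to_int()` in the diff; it is not the real vLLM class:

# Stand-in mirroring the arithmetic of the family check added above.
from dataclasses import dataclass


@dataclass(frozen=True)
class DeviceCapability:
    major: int
    minor: int

    def to_int(self) -> int:
        # Assumed packing: (10, 0) -> 100, (10, 3) -> 103.
        return self.major * 10 + self.minor


def is_device_capability_family(current: DeviceCapability | None, capability: int) -> bool:
    # No detected device means no match, as in the method above.
    if current is None:
        return False
    # Compare only the major version ("family"), per CUDA 13 semantics.
    return (current.to_int() // 10) == (capability // 10)


assert is_device_capability_family(DeviceCapability(10, 0), 100)      # SM100
assert is_device_capability_family(DeviceCapability(10, 3), 100)      # SM103 (Blackwell Ultra)
assert not is_device_capability_family(DeviceCapability(9, 0), 100)   # SM90 (Hopper)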
@@ -38,7 +38,7 @@ class DeepGemmQuantScaleFMT(Enum):
             return DeepGemmQuantScaleFMT.FLOAT32
         return (
             DeepGemmQuantScaleFMT.UE8M0
-            if current_platform.is_device_capability(100)
+            if current_platform.is_device_capability_family(100)
             else DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0
         )
@@ -50,7 +50,7 @@ def is_deep_gemm_supported() -> bool:
     """
     is_supported_arch = current_platform.is_cuda() and (
         current_platform.is_device_capability(90)
-        or current_platform.is_device_capability(100)
+        or current_platform.is_device_capability_family(100)
     )
     return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch
@@ -264,7 +264,9 @@ def supports_trtllm_attention() -> bool:
         return False
 
     # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
-    return current_platform.is_device_capability(100) and has_nvidia_artifactory()
+    return (
+        current_platform.is_device_capability_family(100) and has_nvidia_artifactory()
+    )
 
 
 def force_use_trtllm_attention() -> bool | None:
@@ -564,7 +564,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         )
         self.paged_kv_last_page_len_np = self.paged_kv_last_page_len_cpu.numpy()
 
-        if self.head_dim == 256 and current_platform.is_device_capability(100):
+        if self.head_dim == 256 and current_platform.is_device_capability_family(100):
             # https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
             # head size 256 and block size 16 is not supported on blackwell.
             assert kv_cache_spec.block_size != 16, (
@@ -446,7 +446,7 @@ def use_flashinfer_prefill() -> bool:
         and flashinfer_available
         and not vllm_config.attention_config.use_cudnn_prefill
         and not vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill
-        and current_platform.is_device_capability(100)
+        and current_platform.is_device_capability_family(100)
     )
@@ -457,7 +457,7 @@ def use_cudnn_prefill() -> bool:
     return (
         flashinfer_available
         and vllm_config.attention_config.use_cudnn_prefill
-        and current_platform.is_device_capability(100)
+        and current_platform.is_device_capability_family(100)
         and has_nvidia_artifactory()
     )
@@ -470,7 +470,7 @@ def use_trtllm_ragged_deepseek_prefill() -> bool:
     return (
         flashinfer_available
         and vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill
-        and current_platform.is_device_capability(100)
+        and current_platform.is_device_capability_family(100)
     )
@@ -420,7 +420,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
         max_num_sm_parts = int(
             max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1)
         )
-        if current_platform.is_device_capability(100):
+        if current_platform.is_device_capability_family(100):
             max_num_sm_parts *= 2
         self.tile_scheduler_metadata_buffer = torch.empty(
             # TileSchedulerMetaDataSize = 8
@@ -719,7 +719,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
         self.softmax_scale = scale
         assert indexer is not None
         self.topk_indices_buffer = indexer.topk_indices_buffer
-        self.padding = 128 if current_platform.is_device_capability(100) else 64
+        self.padding = 128 if current_platform.is_device_capability_family(100) else 64
 
         if kv_cache_dtype == "fp8_ds_mla":
             # Reserve workspace during initialization