From 3af47c3cc693f432b59658019891393385aa0e2a Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Wed, 27 Aug 2025 10:09:08 -0400 Subject: [PATCH] [Feature] Add Hopper DeepGEMM E8M0 for DeepSeekV3.1 scale_fmt (#23666) Signed-off-by: yewentao256 Signed-off-by: youkaichao Co-authored-by: youkaichao --- tests/kernels/moe/test_block_fp8.py | 5 +- tests/kernels/moe/test_deepep_deepgemm_moe.py | 7 ++- vllm/envs.py | 8 ++- .../layers/fused_moe/batched_deep_gemm_moe.py | 4 +- .../layers/fused_moe/fused_moe.py | 7 ++- .../layers/fused_moe/triton_deep_gemm_moe.py | 6 +-- .../model_executor/layers/quantization/fp8.py | 9 ++-- .../layers/quantization/utils/fp8_utils.py | 4 +- vllm/transformers_utils/config.py | 18 +++++++ vllm/utils/deep_gemm.py | 53 +++++++++---------- 10 files changed, 68 insertions(+), 53 deletions(-) diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py index 9e4eaf221f24..ecc57acc6796 100644 --- a/tests/kernels/moe/test_block_fp8.py +++ b/tests/kernels/moe/test_block_fp8.py @@ -16,7 +16,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, modular_triton_fused_moe) from vllm.platforms import current_platform from vllm.utils import has_deep_gemm -from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used dg_available = has_deep_gemm() @@ -226,8 +226,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") -@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(), - reason="Not E8M0 scale MOE") +@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Not E8M0 scale MOE") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch): diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 1e922be47f2b..36a98522a658 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -20,8 +20,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) from vllm.platforms import current_platform from vllm.utils import has_deep_ep, has_deep_gemm -from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used, - is_deep_gemm_supported) +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch @@ -374,7 +373,7 @@ NUM_EXPERTS = [32] @multi_gpu_test(num_gpus=2) @requires_deep_ep @requires_deep_gemm -@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(), +@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM") def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, topk: int, world_dp_size: tuple[int, int]): @@ -432,7 +431,7 @@ USE_FP8_DISPATCH = [False] @multi_gpu_test(num_gpus=2) @requires_deep_ep @requires_deep_gemm -@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(), +@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM") def test_ll_deepep_deepgemm_moe( mnk: tuple[int, int, int], diff --git a/vllm/envs.py b/vllm/envs.py index 66c7c2c7f2c4..35735b552575 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -131,6 +131,7 @@ if TYPE_CHECKING: VLLM_TPU_USING_PATHWAYS: bool = False VLLM_USE_DEEP_GEMM: bool = False VLLM_USE_DEEP_GEMM_E8M0: bool = True + VLLM_USE_DEEP_GEMM_E8M0_HOPPER: bool = False VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True VLLM_USE_FLASHINFER_MOE_FP8: bool = False @@ -954,9 +955,12 @@ environment_variables: dict[str, Callable[[], Any]] = { lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))), # Whether to use E8M0 scaling when DeepGEMM is used on Blackwell GPUs. - # E8M0 is faster on B200 but may reduce accuracy. "VLLM_USE_DEEP_GEMM_E8M0": lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1"))), + # TODO(wentao): unify the two E8M0 flags after verifying the correctness. + # Whether to use E8M0 scaling when DeepGEMM is used on Hopper GPUs. + "VLLM_USE_DEEP_GEMM_E8M0_HOPPER": + lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0_HOPPER", "0"))), # DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm # JIT all the required kernels before model execution so there is no # JIT'ing in the hot-path. However, this warmup increases the engine @@ -1244,6 +1248,8 @@ def compute_hash() -> str: "VLLM_USE_FLASHINFER_SAMPLER", "VLLM_DISABLED_KERNELS", "VLLM_USE_DEEP_GEMM", + "VLLM_USE_DEEP_GEMM_E8M0", + "VLLM_USE_DEEP_GEMM_E8M0_HOPPER", "VLLM_USE_TRTLLM_FP4_GEMM", "VLLM_USE_FUSED_MOE_GROUPED_TOPK", "VLLM_USE_FLASHINFER_MOE_FP8", diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index c4d680af932f..a5326dfe84f6 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -12,7 +12,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.triton_utils import tl, triton from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked, - is_blackwell_deep_gemm_e8m0_used) + is_deep_gemm_e8m0_used) logger = init_logger(__name__) @@ -174,7 +174,7 @@ def silu_mul_fp8_quant_deep_gemm( eps, fp8_min, fp8_max, - is_blackwell_deep_gemm_e8m0_used(), + is_deep_gemm_e8m0_used(), BLOCK=group_size, NUM_STAGES=4, num_warps=1, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 84dafcf00d82..17a5c735a57f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -40,7 +40,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer -from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled @@ -1431,9 +1431,8 @@ def fused_experts(hidden_states: torch.Tensor, # E8M0 scale, which means we requantize the weight and input to the specific # scale. Fallen back to cutlass or triton for some cases would cause # accuracy issue. - if (allow_deep_gemm and use_fp8_w8a8 - and (is_blackwell_deep_gemm_e8m0_used() - or _valid_deep_gemm(hidden_states, w1, w2))): + if (allow_deep_gemm and use_fp8_w8a8 and + (is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2))): assert apply_router_weight_on_input is False assert is_act_and_mul, ( "DeepGemm only supports is_act_and_mul=True for now.") diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 486ca881df48..6cd81d97f029 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -10,7 +10,7 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape, deep_gemm_block_shape) from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts -from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): @@ -107,7 +107,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. - if self.allow_deep_gemm and (is_blackwell_deep_gemm_e8m0_used() + if self.allow_deep_gemm and (is_deep_gemm_e8m0_used() or _valid_deep_gemm_shape(M, N, K)): assert self.deep_gemm_expert is not None return self.deep_gemm_expert.workspace_shapes( @@ -143,7 +143,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ): use_deep_gemm = (self.allow_deep_gemm and (_valid_deep_gemm(hidden_states, w1, w2) - or is_blackwell_deep_gemm_e8m0_used())) + or is_deep_gemm_e8m0_used())) experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert assert experts is not None diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index d45d368b582d..be358cfa949f 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -48,8 +48,7 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.scalar_type import scalar_types from vllm.utils import has_deep_gemm -from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used, - is_deep_gemm_supported) +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported from vllm.utils.flashinfer import has_flashinfer_moe if TYPE_CHECKING: @@ -427,7 +426,7 @@ class Fp8LinearMethod(LinearMethodBase): # On B200, if E8M0 for DeepGemm is used, we need to # requantize the weight and input to the specific scale # at the same time. - if is_blackwell_deep_gemm_e8m0_used(): + if is_deep_gemm_e8m0_used(): assert layer.weight_block_size is not None block_sz = tuple(layer.weight_block_size) requant_weight_ue8m0_inplace( @@ -734,7 +733,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): # DeepGemm scales need to be transposed and aligned. We try to do # it ahead of time for performance reasons. - if self.allow_deep_gemm and not is_blackwell_deep_gemm_e8m0_used(): + if self.allow_deep_gemm and not is_deep_gemm_e8m0_used(): # Lazy import to avoid CUDA initialization problems. if _is_col_major(layer.w13_weight_scale_inv): layer.w13_weight_scale_inv = \ @@ -871,7 +870,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): del layer.w13_input_scale del layer.w2_input_scale - if is_blackwell_deep_gemm_e8m0_used(): + if is_deep_gemm_e8m0_used(): assert layer.weight_block_size is not None # Re-quantise the expert weights so their scales are UE8M0. block_sz = tuple(layer.weight_block_size) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index ab1d5383f465..7b324dce3c36 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import cdiv, direct_register_custom_op -from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used, +from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used, should_use_deepgemm_for_fp8_linear) logger = init_logger(__name__) @@ -385,7 +385,7 @@ def per_token_group_quant_fp8( scaling factor. """ if use_ue8m0 is None: - use_ue8m0 = is_blackwell_deep_gemm_e8m0_used() + use_ue8m0 = is_deep_gemm_e8m0_used() dtype = current_platform.fp8_dtype() if dtype is None else dtype assert (x.shape[-1] % group_size == 0), ( f"the last dimension of `x` {x.shape[-1]} must be divisible " diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 2cd799e5eb5a..bec792465bfb 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -501,6 +501,24 @@ def get_config( if quantization_config is not None: config.quantization_config = quantization_config + # auto-enable DeepGEMM UE8M0 on Hopper if model config requests it + scale_fmt = quantization_config.get("scale_fmt", None) + if scale_fmt in ("ue8m0", ): + if not envs.is_set("VLLM_USE_DEEP_GEMM_E8M0_HOPPER"): + os.environ["VLLM_USE_DEEP_GEMM_E8M0_HOPPER"] = "1" + logger.info_once( + ("Detected quantization_config.scale_fmt=%s; " + "enabling Hopper UE8M0."), + scale_fmt, + ) + elif not envs.VLLM_USE_DEEP_GEMM_E8M0_HOPPER: + logger.warning_once( + ("Model config requests UE8M0 " + "(quantization_config.scale_fmt=%s), but " + "VLLM_USE_DEEP_GEMM_E8M0_HOPPER=0 is set; " + "Hopper UE8M0 disabled."), + scale_fmt, + ) if hf_overrides_kw: logger.debug("Overriding HF config with %s", hf_overrides_kw) diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index b0bc3a79eb0a..cd1dbfb813fe 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -31,34 +31,33 @@ def is_deep_gemm_supported() -> bool: @functools.cache -def is_blackwell_deep_gemm_e8m0_used() -> bool: +def is_deep_gemm_e8m0_used() -> bool: """Return ``True`` if vLLM is configured to use DeepGEMM " - "E8M0 scale on a Blackwell-class GPU. + "E8M0 scale on a Hopper or Blackwell-class GPU. """ if not is_deep_gemm_supported(): - logger.debug_once( + logger.info_once( "DeepGEMM E8M0 disabled: DeepGEMM not supported on this system.") return False - if not envs.VLLM_USE_DEEP_GEMM_E8M0: - logger.debug_once("DeepGEMM E8M0 disabled: VLLM_USE_DEEP_GEMM_E8M0=0.") - return False - _lazy_init() if _fp8_gemm_nt_impl is None: - logger.debug_once( - "DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found") + logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found") return False - enabled = (current_platform.is_cuda() - and current_platform.has_device_capability(100)) - if enabled: - logger.debug_once("DeepGEMM E8M0 enabled on Blackwell GPU.") - else: - logger.debug_once( - "DeepGEMM E8M0 disabled: not running on Blackwell GPU.") - return enabled + if current_platform.is_device_capability(100) and \ + envs.VLLM_USE_DEEP_GEMM_E8M0: + logger.info_once("DeepGEMM E8M0 enabled on Blackwell GPU.") + return True + + if current_platform.is_device_capability(90) and \ + envs.VLLM_USE_DEEP_GEMM_E8M0_HOPPER: + logger.info_once("DeepGEMM E8M0 enabled on Hopper GPU.") + return True + + logger.info_once("DeepGEMM E8M0 disabled on current configuration.") + return False def _missing(*_: Any, **__: Any) -> NoReturn: @@ -124,20 +123,18 @@ def fp8_gemm_nt(*args, **kwargs): _lazy_init() if _fp8_gemm_nt_impl is None: return _missing(*args, **kwargs) - return _fp8_gemm_nt_impl( - *args, - disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(), - **kwargs) + return _fp8_gemm_nt_impl(*args, + disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), + **kwargs) def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs): _lazy_init() if _grouped_impl is None: return _missing(*args, **kwargs) - return _grouped_impl( - *args, - disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(), - **kwargs) + return _grouped_impl(*args, + disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), + **kwargs) def fp8_m_grouped_gemm_nt_masked(*args, **kwargs): @@ -145,9 +142,7 @@ def fp8_m_grouped_gemm_nt_masked(*args, **kwargs): if _grouped_masked_impl is None: return _missing(*args, **kwargs) return _grouped_masked_impl( - *args, - disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(), - **kwargs) + *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs) def _ceil_to_ue8m0(x: torch.Tensor): @@ -211,7 +206,7 @@ __all__ = [ "m_grouped_fp8_gemm_nt_contiguous", "fp8_m_grouped_gemm_nt_masked", "per_block_cast_to_fp8", - "is_blackwell_deep_gemm_e8m0_used", + "is_deep_gemm_e8m0_used", "is_deep_gemm_supported", "should_use_deepgemm_for_fp8_linear", ]