diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py
index 7dc6282326b6..75b2e9f79178 100644
--- a/tests/kernels/moe/test_block_fp8.py
+++ b/tests/kernels/moe/test_block_fp8.py
@@ -16,7 +16,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk, modular_triton_fused_moe)
 from vllm.platforms import current_platform
 from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
+from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
 
 dg_available = has_deep_gemm()
 
@@ -224,7 +224,8 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
-@pytest.mark.skipif(is_blackwell_deep_gemm_used(), reason="Not E8M0 scale MOE")
+@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(),
+                    reason="Not E8M0 scale MOE")
 @torch.inference_mode()
 def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
                                             monkeypatch):
diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py
index 266f1161a684..9b064db973dd 100644
--- a/tests/kernels/moe/test_deepep_deepgemm_moe.py
+++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py
@@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEModularKernel)
 from vllm.platforms import current_platform
 from vllm.utils import has_deep_ep, has_deep_gemm
-from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_used,
+from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
                                   is_deep_gemm_supported)
 
 from .parallel_utils import ProcessGroupInfo, parallel_launch
@@ -370,7 +370,7 @@ NUM_EXPERTS = [32]
 @pytest.mark.parametrize("world_dp_size", [(2, 1)])
 @requires_deep_ep
 @requires_deep_gemm
-@pytest.mark.skipif(is_blackwell_deep_gemm_used(),
+@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(),
                     reason="Skipping test for Blackwell DeepGEMM")
 def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
                                 topk: int, world_dp_size: tuple[int, int]):
@@ -427,7 +427,7 @@ USE_FP8_DISPATCH = [False]
 @pytest.mark.parametrize("world_dp_size", [(2, 1)])
 @requires_deep_ep
 @requires_deep_gemm
-@pytest.mark.skipif(is_blackwell_deep_gemm_used(),
+@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(),
                     reason="Skipping test for Blackwell DeepGEMM")
 def test_ll_deepep_deepgemm_moe(
     mnk: tuple[int, int, int],
diff --git a/vllm/envs.py b/vllm/envs.py
index c26c7f215dfe..931edcfa7f1e 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -127,6 +127,7 @@ if TYPE_CHECKING:
     VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
     VLLM_TPU_USING_PATHWAYS: bool = False
     VLLM_USE_DEEP_GEMM: bool = False
+    VLLM_USE_DEEP_GEMM_E8M0: bool = True
     VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
     VLLM_USE_FLASHINFER_MOE_FP8: bool = False
     VLLM_USE_FLASHINFER_MOE_FP4: bool = False
@@ -925,6 +926,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_USE_DEEP_GEMM":
     lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
+    # Whether to use E8M0 scaling when DeepGEMM is used on Blackwell GPUs.
+    # E8M0 is faster on B200 but may reduce accuracy.
+    "VLLM_USE_DEEP_GEMM_E8M0":
+    lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1"))),
 
     # DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm
     # JIT all the required kernels before model execution so there is no
     # JIT'ing in the hot-path. However, this warmup increases the engine
diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
index 3ccddb52998b..c48a0137c306 100644
--- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
@@ -12,7 +12,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.triton_utils import tl, triton
 from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked,
-                                  is_blackwell_deep_gemm_used)
+                                  is_blackwell_deep_gemm_e8m0_used)
 
 logger = init_logger(__name__)
 
@@ -176,7 +176,7 @@ def silu_mul_fp8_quant_deep_gemm(
         eps,
         fp8_min,
         fp8_max,
-        is_blackwell_deep_gemm_used(),
+        is_blackwell_deep_gemm_e8m0_used(),
         BLOCK=group_size,
         NUM_STAGES=8,
         num_warps=1,
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 86cc6e0e5dac..ad094c37f947 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -40,7 +40,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
-from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
+from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
 
 from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
 
@@ -1387,8 +1387,8 @@ def fused_experts(hidden_states: torch.Tensor,
     # E8M0 scale, which means we requantize the weight and input to the specific
     # scale. Fallen back to cutlass or triton for some cases would cause
     # accuracy issue.
-    should_use_deep_gemm = is_blackwell_deep_gemm_used() or _valid_deep_gemm(
-        hidden_states, w1, w2)
+    should_use_deep_gemm = is_blackwell_deep_gemm_e8m0_used(
+    ) or _valid_deep_gemm(hidden_states, w1, w2)
     if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm):
         assert apply_router_weight_on_input is False
         assert is_act_and_mul, (
diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
index c67f7e808301..9d0ff2e06190 100644
--- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
@@ -10,7 +10,7 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
     DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape,
     deep_gemm_block_shape)
 from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
-from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
+from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
 
 
 class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
@@ -107,7 +107,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         # Note: the deep gemm workspaces are strictly larger than the triton
         # workspaces so we can be pessimistic here and allocate for DeepGemm
         # even if we fall back to triton later, e.g. if expert maps are set.
-        if self.allow_deep_gemm and (is_blackwell_deep_gemm_used()
+        if self.allow_deep_gemm and (is_blackwell_deep_gemm_e8m0_used()
                                      or _valid_deep_gemm_shape(M, N, K)):
             assert self.deep_gemm_expert is not None
             return self.deep_gemm_expert.workspace_shapes(
@@ -133,7 +133,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
                 extra_expert_args: Optional[dict[str, Any]]):
         use_deep_gemm = (self.allow_deep_gemm and
                          (_valid_deep_gemm(hidden_states, w1, w2)
-                          or is_blackwell_deep_gemm_used()))
+                          or is_blackwell_deep_gemm_e8m0_used()))
 
         experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert
         assert experts is not None
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 8b6ed154bdbe..9577fa025b70 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -45,7 +45,8 @@ from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
+from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
+                                  is_deep_gemm_supported)
 from vllm.utils.flashinfer import has_flashinfer_moe
 
 if TYPE_CHECKING:
@@ -415,10 +416,10 @@ class Fp8LinearMethod(LinearMethodBase):
             # Activations not quantized for marlin.
             del layer.input_scale
 
-        # On B200, DeepGemm only support E8M0 scale, which means we need to
+        # On B200, if E8M0 scaling is used for DeepGemm, we need to
         # requantize the weight and input to the specific scale
         # at the same time.
-        if is_blackwell_deep_gemm_used():
+        if is_blackwell_deep_gemm_e8m0_used():
             assert layer.weight_block_size is not None
             block_sz = tuple(layer.weight_block_size)
             requant_weight_ue8m0_inplace(
@@ -505,15 +506,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         elif not self.block_quant:
             logger.warning_once("Model is not block quantized. Not using "
                                 "DeepGemm kernels")
-        elif (current_platform.is_cuda()
-              and current_platform.is_device_capability(90)):
+        elif (is_deep_gemm_supported()):
             logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
             self.allow_deep_gemm = True
-        elif (current_platform.is_cuda()
-              and is_blackwell_deep_gemm_used()):
-            logger.info_once("Using DeepGemm SM100 kernels for "
-                             "Fp8MoEMethod.")
-            self.allow_deep_gemm = True
         else:
             logger.warning_once(
                 "DeepGemm not supported on the current platform.")
@@ -725,7 +720,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
         # DeepGemm scales need to be transposed and aligned. We try to do
         # it ahead of time for performance reasons.
-        if self.allow_deep_gemm and not is_blackwell_deep_gemm_used():
+        if self.allow_deep_gemm and not is_blackwell_deep_gemm_e8m0_used():
             # Lazy import to avoid CUDA initialization problems.
             if _is_col_major(layer.w13_weight_scale_inv):
                 layer.w13_weight_scale_inv = \
@@ -851,7 +846,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             del layer.w13_input_scale
             del layer.w2_input_scale
 
-        if is_blackwell_deep_gemm_used():
+        if is_blackwell_deep_gemm_e8m0_used():
             assert layer.weight_block_size is not None
             # Re-quantise the expert weights so their scales are UE8M0.
             block_sz = tuple(layer.weight_block_size)
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 68a061968aa9..2fb7ef29e468 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm
-from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
+from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
 
 logger = init_logger(__name__)
 
@@ -394,10 +394,8 @@ def per_token_group_quant_fp8(
         tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
             scaling factor.
     """
-    # TODO(wentao): refactor this
-    # use_ue8m0 should be a global flag that could be set by user
     if use_ue8m0 is None:
-        use_ue8m0 = is_blackwell_deep_gemm_used()
+        use_ue8m0 = is_blackwell_deep_gemm_e8m0_used()
     dtype = current_platform.fp8_dtype() if dtype is None else dtype
     assert (x.shape[-1] % group_size == 0), (
         f"the last dimension of `x` {x.shape[-1]} must be divisible "
diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py
index 174287b44b76..861d9c0c0005 100644
--- a/vllm/utils/deep_gemm.py
+++ b/vllm/utils/deep_gemm.py
@@ -31,19 +31,37 @@ def is_deep_gemm_supported() -> bool:
 
 
 @functools.cache
-def is_blackwell_deep_gemm_used() -> bool:
-    """Return ``True`` if vLLM is configured to use DeepGEMM on a
-    Blackwell-class GPU.
+def is_blackwell_deep_gemm_e8m0_used() -> bool:
+    """Return ``True`` if vLLM is configured to use DeepGEMM
+    E8M0 scale on a Blackwell-class GPU.
""" - if not (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm()): + if not (envs.VLLM_USE_DEEP_GEMM): + logger.debug_once("DeepGEMM E8M0 disabled: VLLM_USE_DEEP_GEMM=0.") + return False + + if not has_deep_gemm(): + logger.debug_once("DeepGEMM E8M0 disabled: DeepGEMM backend missing.") + return False + + if not envs.VLLM_USE_DEEP_GEMM_E8M0: + logger.debug_once("DeepGEMM E8M0 disabled: VLLM_USE_DEEP_GEMM_E8M0=0.") return False _lazy_init() + if _fp8_gemm_nt_impl is None: + logger.debug_once( + "DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found") return False - return (current_platform.is_cuda() - and current_platform.is_device_capability(100)) + enabled = (current_platform.is_cuda() + and current_platform.has_device_capability(100)) + if enabled: + logger.debug_once("DeepGEMM E8M0 enabled on Blackwell GPU.") + else: + logger.debug_once( + "DeepGEMM E8M0 disabled: not running on Blackwell GPU.") + return enabled def _missing(*_: Any, **__: Any) -> NoReturn: @@ -109,21 +127,30 @@ def fp8_gemm_nt(*args, **kwargs): _lazy_init() if _fp8_gemm_nt_impl is None: return _missing(*args, **kwargs) - return _fp8_gemm_nt_impl(*args, **kwargs) + return _fp8_gemm_nt_impl( + *args, + disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(), + **kwargs) def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs): _lazy_init() if _grouped_impl is None: return _missing(*args, **kwargs) - return _grouped_impl(*args, **kwargs) + return _grouped_impl( + *args, + disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(), + **kwargs) def fp8_m_grouped_gemm_nt_masked(*args, **kwargs): _lazy_init() if _grouped_masked_impl is None: return _missing(*args, **kwargs) - return _grouped_masked_impl(*args, **kwargs) + return _grouped_masked_impl( + *args, + disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(), + **kwargs) def _ceil_to_ue8m0(x: torch.Tensor): @@ -181,6 +208,6 @@ __all__ = [ "m_grouped_fp8_gemm_nt_contiguous", "fp8_m_grouped_gemm_nt_masked", "per_block_cast_to_fp8", - "is_blackwell_deep_gemm_used", + "is_blackwell_deep_gemm_e8m0_used", "is_deep_gemm_supported", ]