[Feature] Add VLLM_USE_DEEP_GEMM_E8M0 Env to Control E8M0 Scale (#21968)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye 2025-08-11 12:39:08 -04:00 committed by GitHub
parent 8e13d9fe6d
commit f7dcce7a4a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 65 additions and 39 deletions

View File

@ -16,7 +16,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, modular_triton_fused_moe) fused_topk, modular_triton_fused_moe)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import has_deep_gemm from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
dg_available = has_deep_gemm() dg_available = has_deep_gemm()
@ -224,7 +224,8 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
@pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@pytest.mark.skipif(is_blackwell_deep_gemm_used(), reason="Not E8M0 scale MOE") @pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(),
reason="Not E8M0 scale MOE")
@torch.inference_mode() @torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
monkeypatch): monkeypatch):

View File

@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel) FusedMoEModularKernel)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm from vllm.utils import has_deep_ep, has_deep_gemm
from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_used, from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
is_deep_gemm_supported) is_deep_gemm_supported)
from .parallel_utils import ProcessGroupInfo, parallel_launch from .parallel_utils import ProcessGroupInfo, parallel_launch
@ -370,7 +370,7 @@ NUM_EXPERTS = [32]
@pytest.mark.parametrize("world_dp_size", [(2, 1)]) @pytest.mark.parametrize("world_dp_size", [(2, 1)])
@requires_deep_ep @requires_deep_ep
@requires_deep_gemm @requires_deep_gemm
@pytest.mark.skipif(is_blackwell_deep_gemm_used(), @pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(),
reason="Skipping test for Blackwell DeepGEMM") reason="Skipping test for Blackwell DeepGEMM")
def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
topk: int, world_dp_size: tuple[int, int]): topk: int, world_dp_size: tuple[int, int]):
@ -427,7 +427,7 @@ USE_FP8_DISPATCH = [False]
@pytest.mark.parametrize("world_dp_size", [(2, 1)]) @pytest.mark.parametrize("world_dp_size", [(2, 1)])
@requires_deep_ep @requires_deep_ep
@requires_deep_gemm @requires_deep_gemm
@pytest.mark.skipif(is_blackwell_deep_gemm_used(), @pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(),
reason="Skipping test for Blackwell DeepGEMM") reason="Skipping test for Blackwell DeepGEMM")
def test_ll_deepep_deepgemm_moe( def test_ll_deepep_deepgemm_moe(
mnk: tuple[int, int, int], mnk: tuple[int, int, int],

View File

@ -127,6 +127,7 @@ if TYPE_CHECKING:
VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
VLLM_TPU_USING_PATHWAYS: bool = False VLLM_TPU_USING_PATHWAYS: bool = False
VLLM_USE_DEEP_GEMM: bool = False VLLM_USE_DEEP_GEMM: bool = False
VLLM_USE_DEEP_GEMM_E8M0: bool = True
VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
VLLM_USE_FLASHINFER_MOE_FP8: bool = False VLLM_USE_FLASHINFER_MOE_FP8: bool = False
VLLM_USE_FLASHINFER_MOE_FP4: bool = False VLLM_USE_FLASHINFER_MOE_FP4: bool = False
@ -925,6 +926,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_DEEP_GEMM": "VLLM_USE_DEEP_GEMM":
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))), lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
# Whether to use E8M0 scaling when DeepGEMM is used on Blackwell GPUs.
# E8M0 is faster on B200 but may reduce accuracy.
"VLLM_USE_DEEP_GEMM_E8M0":
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1"))),
# DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm # DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm
# JIT all the required kernels before model execution so there is no # JIT all the required kernels before model execution so there is no
# JIT'ing in the hot-path. However, this warmup increases the engine # JIT'ing in the hot-path. However, this warmup increases the engine

View File

@ -12,7 +12,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked, from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked,
is_blackwell_deep_gemm_used) is_blackwell_deep_gemm_e8m0_used)
logger = init_logger(__name__) logger = init_logger(__name__)
@ -176,7 +176,7 @@ def silu_mul_fp8_quant_deep_gemm(
eps, eps,
fp8_min, fp8_min,
fp8_max, fp8_max,
is_blackwell_deep_gemm_used(), is_blackwell_deep_gemm_e8m0_used(),
BLOCK=group_size, BLOCK=group_size,
NUM_STAGES=8, NUM_STAGES=8,
num_warps=1, num_warps=1,

View File

@ -40,7 +40,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
@ -1387,8 +1387,8 @@ def fused_experts(hidden_states: torch.Tensor,
# E8M0 scale, which means we requantize the weight and input to the specific # E8M0 scale, which means we requantize the weight and input to the specific
# scale. Falling back to cutlass or triton in some cases would cause # scale. Falling back to cutlass or triton in some cases would cause
# accuracy issues. # accuracy issues.
should_use_deep_gemm = is_blackwell_deep_gemm_used() or _valid_deep_gemm( should_use_deep_gemm = is_blackwell_deep_gemm_e8m0_used(
hidden_states, w1, w2) ) or _valid_deep_gemm(hidden_states, w1, w2)
if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm): if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm):
assert apply_router_weight_on_input is False assert apply_router_weight_on_input is False
assert is_act_and_mul, ( assert is_act_and_mul, (

View File

@ -10,7 +10,7 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape, DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape,
deep_gemm_block_shape) deep_gemm_block_shape)
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
@ -107,7 +107,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# Note: the deep gemm workspaces are strictly larger than the triton # Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm # workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set. # even if we fall back to triton later, e.g. if expert maps are set.
if self.allow_deep_gemm and (is_blackwell_deep_gemm_used() if self.allow_deep_gemm and (is_blackwell_deep_gemm_e8m0_used()
or _valid_deep_gemm_shape(M, N, K)): or _valid_deep_gemm_shape(M, N, K)):
assert self.deep_gemm_expert is not None assert self.deep_gemm_expert is not None
return self.deep_gemm_expert.workspace_shapes( return self.deep_gemm_expert.workspace_shapes(
@ -133,7 +133,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
extra_expert_args: Optional[dict[str, Any]]): extra_expert_args: Optional[dict[str, Any]]):
use_deep_gemm = (self.allow_deep_gemm use_deep_gemm = (self.allow_deep_gemm
and (_valid_deep_gemm(hidden_states, w1, w2) and (_valid_deep_gemm(hidden_states, w1, w2)
or is_blackwell_deep_gemm_used())) or is_blackwell_deep_gemm_e8m0_used()))
experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert
assert experts is not None assert experts is not None

View File

@ -45,7 +45,8 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
from vllm.utils import has_deep_gemm from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
is_deep_gemm_supported)
from vllm.utils.flashinfer import has_flashinfer_moe from vllm.utils.flashinfer import has_flashinfer_moe
if TYPE_CHECKING: if TYPE_CHECKING:
@ -415,10 +416,10 @@ class Fp8LinearMethod(LinearMethodBase):
# Activations not quantized for marlin. # Activations not quantized for marlin.
del layer.input_scale del layer.input_scale
# On B200, DeepGemm only support E8M0 scale, which means we need to # On B200, if E8M0 for DeepGemm is used, we need to
# requantize the weight and input to the specific scale # requantize the weight and input to the specific scale
# at the same time. # at the same time.
if is_blackwell_deep_gemm_used(): if is_blackwell_deep_gemm_e8m0_used():
assert layer.weight_block_size is not None assert layer.weight_block_size is not None
block_sz = tuple(layer.weight_block_size) block_sz = tuple(layer.weight_block_size)
requant_weight_ue8m0_inplace( requant_weight_ue8m0_inplace(
@ -505,15 +506,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
elif not self.block_quant: elif not self.block_quant:
logger.warning_once("Model is not block quantized. Not using " logger.warning_once("Model is not block quantized. Not using "
"DeepGemm kernels") "DeepGemm kernels")
elif (current_platform.is_cuda() elif (is_deep_gemm_supported()):
and current_platform.is_device_capability(90)):
logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.") logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
self.allow_deep_gemm = True self.allow_deep_gemm = True
elif (current_platform.is_cuda()
and is_blackwell_deep_gemm_used()):
logger.info_once("Using DeepGemm SM100 kernels for "
"Fp8MoEMethod.")
self.allow_deep_gemm = True
else: else:
logger.warning_once( logger.warning_once(
"DeepGemm not supported on the current platform.") "DeepGemm not supported on the current platform.")
@ -725,7 +720,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# DeepGemm scales need to be transposed and aligned. We try to do # DeepGemm scales need to be transposed and aligned. We try to do
# it ahead of time for performance reasons. # it ahead of time for performance reasons.
if self.allow_deep_gemm and not is_blackwell_deep_gemm_used(): if self.allow_deep_gemm and not is_blackwell_deep_gemm_e8m0_used():
# Lazy import to avoid CUDA initialization problems. # Lazy import to avoid CUDA initialization problems.
if _is_col_major(layer.w13_weight_scale_inv): if _is_col_major(layer.w13_weight_scale_inv):
layer.w13_weight_scale_inv = \ layer.w13_weight_scale_inv = \
@ -851,7 +846,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del layer.w13_input_scale del layer.w13_input_scale
del layer.w2_input_scale del layer.w2_input_scale
if is_blackwell_deep_gemm_used(): if is_blackwell_deep_gemm_e8m0_used():
assert layer.weight_block_size is not None assert layer.weight_block_size is not None
# Re-quantise the expert weights so their scales are UE8M0. # Re-quantise the expert weights so their scales are UE8M0.
block_sz = tuple(layer.weight_block_size) block_sz = tuple(layer.weight_block_size)

View File

@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
logger = init_logger(__name__) logger = init_logger(__name__)
@ -394,10 +394,8 @@ def per_token_group_quant_fp8(
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor. scaling factor.
""" """
# TODO(wentao): refactor this
# use_ue8m0 should be a global flag that could be set by user
if use_ue8m0 is None: if use_ue8m0 is None:
use_ue8m0 = is_blackwell_deep_gemm_used() use_ue8m0 = is_blackwell_deep_gemm_e8m0_used()
dtype = current_platform.fp8_dtype() if dtype is None else dtype dtype = current_platform.fp8_dtype() if dtype is None else dtype
assert (x.shape[-1] % group_size == 0), ( assert (x.shape[-1] % group_size == 0), (
f"the last dimension of `x` {x.shape[-1]} must be divisible " f"the last dimension of `x` {x.shape[-1]} must be divisible "

View File

@ -31,19 +31,37 @@ def is_deep_gemm_supported() -> bool:
@functools.cache @functools.cache
def is_blackwell_deep_gemm_used() -> bool: def is_blackwell_deep_gemm_e8m0_used() -> bool:
"""Return ``True`` if vLLM is configured to use DeepGEMM on a """Return ``True`` if vLLM is configured to use DeepGEMM with
Blackwell-class GPU. E8M0 scale on a Blackwell-class GPU.
""" """
if not (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm()): if not (envs.VLLM_USE_DEEP_GEMM):
logger.debug_once("DeepGEMM E8M0 disabled: VLLM_USE_DEEP_GEMM=0.")
return False
if not has_deep_gemm():
logger.debug_once("DeepGEMM E8M0 disabled: DeepGEMM backend missing.")
return False
if not envs.VLLM_USE_DEEP_GEMM_E8M0:
logger.debug_once("DeepGEMM E8M0 disabled: VLLM_USE_DEEP_GEMM_E8M0=0.")
return False return False
_lazy_init() _lazy_init()
if _fp8_gemm_nt_impl is None: if _fp8_gemm_nt_impl is None:
logger.debug_once(
"DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
return False return False
return (current_platform.is_cuda() enabled = (current_platform.is_cuda()
and current_platform.is_device_capability(100)) and current_platform.has_device_capability(100))
if enabled:
logger.debug_once("DeepGEMM E8M0 enabled on Blackwell GPU.")
else:
logger.debug_once(
"DeepGEMM E8M0 disabled: not running on Blackwell GPU.")
return enabled
def _missing(*_: Any, **__: Any) -> NoReturn: def _missing(*_: Any, **__: Any) -> NoReturn:
@ -109,21 +127,30 @@ def fp8_gemm_nt(*args, **kwargs):
_lazy_init() _lazy_init()
if _fp8_gemm_nt_impl is None: if _fp8_gemm_nt_impl is None:
return _missing(*args, **kwargs) return _missing(*args, **kwargs)
return _fp8_gemm_nt_impl(*args, **kwargs) return _fp8_gemm_nt_impl(
*args,
disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(),
**kwargs)
def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs): def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
_lazy_init() _lazy_init()
if _grouped_impl is None: if _grouped_impl is None:
return _missing(*args, **kwargs) return _missing(*args, **kwargs)
return _grouped_impl(*args, **kwargs) return _grouped_impl(
*args,
disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(),
**kwargs)
def fp8_m_grouped_gemm_nt_masked(*args, **kwargs): def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
_lazy_init() _lazy_init()
if _grouped_masked_impl is None: if _grouped_masked_impl is None:
return _missing(*args, **kwargs) return _missing(*args, **kwargs)
return _grouped_masked_impl(*args, **kwargs) return _grouped_masked_impl(
*args,
disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(),
**kwargs)
def _ceil_to_ue8m0(x: torch.Tensor): def _ceil_to_ue8m0(x: torch.Tensor):
@ -181,6 +208,6 @@ __all__ = [
"m_grouped_fp8_gemm_nt_contiguous", "m_grouped_fp8_gemm_nt_contiguous",
"fp8_m_grouped_gemm_nt_masked", "fp8_m_grouped_gemm_nt_masked",
"per_block_cast_to_fp8", "per_block_cast_to_fp8",
"is_blackwell_deep_gemm_used", "is_blackwell_deep_gemm_e8m0_used",
"is_deep_gemm_supported", "is_deep_gemm_supported",
] ]