[Feature] Add Hopper DeepGEMM E8M0 for DeepSeekV3.1 scale_fmt (#23666)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
Wentao Ye 2025-08-27 10:09:08 -04:00 committed by GitHub
parent 513c1fe255
commit 3af47c3cc6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 68 additions and 53 deletions

View File

@ -16,7 +16,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, modular_triton_fused_moe)
from vllm.platforms import current_platform
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
dg_available = has_deep_gemm()
@ -226,8 +226,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(),
reason="Not E8M0 scale MOE")
@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Not E8M0 scale MOE")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
monkeypatch):

View File

@ -20,8 +20,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm
from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
is_deep_gemm_supported)
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported
from ...utils import multi_gpu_test
from .parallel_utils import ProcessGroupInfo, parallel_launch
@ -374,7 +373,7 @@ NUM_EXPERTS = [32]
@multi_gpu_test(num_gpus=2)
@requires_deep_ep
@requires_deep_gemm
@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(),
@pytest.mark.skipif(is_deep_gemm_e8m0_used(),
reason="Skipping test for Blackwell DeepGEMM")
def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
topk: int, world_dp_size: tuple[int, int]):
@ -432,7 +431,7 @@ USE_FP8_DISPATCH = [False]
@multi_gpu_test(num_gpus=2)
@requires_deep_ep
@requires_deep_gemm
@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(),
@pytest.mark.skipif(is_deep_gemm_e8m0_used(),
reason="Skipping test for Blackwell DeepGEMM")
def test_ll_deepep_deepgemm_moe(
mnk: tuple[int, int, int],

View File

@ -131,6 +131,7 @@ if TYPE_CHECKING:
VLLM_TPU_USING_PATHWAYS: bool = False
VLLM_USE_DEEP_GEMM: bool = False
VLLM_USE_DEEP_GEMM_E8M0: bool = True
VLLM_USE_DEEP_GEMM_E8M0_HOPPER: bool = False
VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True
VLLM_USE_FLASHINFER_MOE_FP8: bool = False
@ -954,9 +955,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
# Whether to use E8M0 scaling when DeepGEMM is used on Blackwell GPUs.
# E8M0 is faster on B200 but may reduce accuracy.
"VLLM_USE_DEEP_GEMM_E8M0":
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1"))),
# TODO(wentao): unify the two E8M0 flags after verifying the correctness.
# Whether to use E8M0 scaling when DeepGEMM is used on Hopper GPUs.
"VLLM_USE_DEEP_GEMM_E8M0_HOPPER":
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0_HOPPER", "0"))),
# DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm
# JIT all the required kernels before model execution so there is no
# JIT'ing in the hot-path. However, this warmup increases the engine
@ -1244,6 +1248,8 @@ def compute_hash() -> str:
"VLLM_USE_FLASHINFER_SAMPLER",
"VLLM_DISABLED_KERNELS",
"VLLM_USE_DEEP_GEMM",
"VLLM_USE_DEEP_GEMM_E8M0",
"VLLM_USE_DEEP_GEMM_E8M0_HOPPER",
"VLLM_USE_TRTLLM_FP4_GEMM",
"VLLM_USE_FUSED_MOE_GROUPED_TOPK",
"VLLM_USE_FLASHINFER_MOE_FP8",

View File

@ -12,7 +12,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked,
is_blackwell_deep_gemm_e8m0_used)
is_deep_gemm_e8m0_used)
logger = init_logger(__name__)
@ -174,7 +174,7 @@ def silu_mul_fp8_quant_deep_gemm(
eps,
fp8_min,
fp8_max,
is_blackwell_deep_gemm_e8m0_used(),
is_deep_gemm_e8m0_used(),
BLOCK=group_size,
NUM_STAGES=4,
num_warps=1,

View File

@ -40,7 +40,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
@ -1431,9 +1431,8 @@ def fused_experts(hidden_states: torch.Tensor,
# E8M0 scale, which means we requantize the weight and input to the specific
# scale. Fallen back to cutlass or triton for some cases would cause
# accuracy issue.
if (allow_deep_gemm and use_fp8_w8a8
and (is_blackwell_deep_gemm_e8m0_used()
or _valid_deep_gemm(hidden_states, w1, w2))):
if (allow_deep_gemm and use_fp8_w8a8 and
(is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2))):
assert apply_router_weight_on_input is False
assert is_act_and_mul, (
"DeepGemm only supports is_act_and_mul=True for now.")

View File

@ -10,7 +10,7 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape,
deep_gemm_block_shape)
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
@ -107,7 +107,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set.
if self.allow_deep_gemm and (is_blackwell_deep_gemm_e8m0_used()
if self.allow_deep_gemm and (is_deep_gemm_e8m0_used()
or _valid_deep_gemm_shape(M, N, K)):
assert self.deep_gemm_expert is not None
return self.deep_gemm_expert.workspace_shapes(
@ -143,7 +143,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
):
use_deep_gemm = (self.allow_deep_gemm
and (_valid_deep_gemm(hidden_states, w1, w2)
or is_blackwell_deep_gemm_e8m0_used()))
or is_deep_gemm_e8m0_used()))
experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert
assert experts is not None

View File

@ -48,8 +48,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
is_deep_gemm_supported)
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported
from vllm.utils.flashinfer import has_flashinfer_moe
if TYPE_CHECKING:
@ -427,7 +426,7 @@ class Fp8LinearMethod(LinearMethodBase):
# On B200, if E8M0 for DeepGemm is used, we need to
# requantize the weight and input to the specific scale
# at the same time.
if is_blackwell_deep_gemm_e8m0_used():
if is_deep_gemm_e8m0_used():
assert layer.weight_block_size is not None
block_sz = tuple(layer.weight_block_size)
requant_weight_ue8m0_inplace(
@ -734,7 +733,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# DeepGemm scales need to be transposed and aligned. We try to do
# it ahead of time for performance reasons.
if self.allow_deep_gemm and not is_blackwell_deep_gemm_e8m0_used():
if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
# Lazy import to avoid CUDA initialization problems.
if _is_col_major(layer.w13_weight_scale_inv):
layer.w13_weight_scale_inv = \
@ -871,7 +870,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del layer.w13_input_scale
del layer.w2_input_scale
if is_blackwell_deep_gemm_e8m0_used():
if is_deep_gemm_e8m0_used():
assert layer.weight_block_size is not None
# Re-quantise the expert weights so their scales are UE8M0.
block_sz = tuple(layer.weight_block_size)

View File

@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, direct_register_custom_op
from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used,
should_use_deepgemm_for_fp8_linear)
logger = init_logger(__name__)
@ -385,7 +385,7 @@ def per_token_group_quant_fp8(
scaling factor.
"""
if use_ue8m0 is None:
use_ue8m0 = is_blackwell_deep_gemm_e8m0_used()
use_ue8m0 = is_deep_gemm_e8m0_used()
dtype = current_platform.fp8_dtype() if dtype is None else dtype
assert (x.shape[-1] % group_size == 0), (
f"the last dimension of `x` {x.shape[-1]} must be divisible "

View File

@ -501,6 +501,24 @@ def get_config(
if quantization_config is not None:
config.quantization_config = quantization_config
# auto-enable DeepGEMM UE8M0 on Hopper if model config requests it
scale_fmt = quantization_config.get("scale_fmt", None)
if scale_fmt in ("ue8m0", ):
if not envs.is_set("VLLM_USE_DEEP_GEMM_E8M0_HOPPER"):
os.environ["VLLM_USE_DEEP_GEMM_E8M0_HOPPER"] = "1"
logger.info_once(
("Detected quantization_config.scale_fmt=%s; "
"enabling Hopper UE8M0."),
scale_fmt,
)
elif not envs.VLLM_USE_DEEP_GEMM_E8M0_HOPPER:
logger.warning_once(
("Model config requests UE8M0 "
"(quantization_config.scale_fmt=%s), but "
"VLLM_USE_DEEP_GEMM_E8M0_HOPPER=0 is set; "
"Hopper UE8M0 disabled."),
scale_fmt,
)
if hf_overrides_kw:
logger.debug("Overriding HF config with %s", hf_overrides_kw)

View File

@ -31,34 +31,33 @@ def is_deep_gemm_supported() -> bool:
@functools.cache
def is_blackwell_deep_gemm_e8m0_used() -> bool:
def is_deep_gemm_e8m0_used() -> bool:
"""Return ``True`` if vLLM is configured to use DeepGEMM "
"E8M0 scale on a Blackwell-class GPU.
"E8M0 scale on a Hopper or Blackwell-class GPU.
"""
if not is_deep_gemm_supported():
logger.debug_once(
logger.info_once(
"DeepGEMM E8M0 disabled: DeepGEMM not supported on this system.")
return False
if not envs.VLLM_USE_DEEP_GEMM_E8M0:
logger.debug_once("DeepGEMM E8M0 disabled: VLLM_USE_DEEP_GEMM_E8M0=0.")
return False
_lazy_init()
if _fp8_gemm_nt_impl is None:
logger.debug_once(
"DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
return False
enabled = (current_platform.is_cuda()
and current_platform.has_device_capability(100))
if enabled:
logger.debug_once("DeepGEMM E8M0 enabled on Blackwell GPU.")
else:
logger.debug_once(
"DeepGEMM E8M0 disabled: not running on Blackwell GPU.")
return enabled
if current_platform.is_device_capability(100) and \
envs.VLLM_USE_DEEP_GEMM_E8M0:
logger.info_once("DeepGEMM E8M0 enabled on Blackwell GPU.")
return True
if current_platform.is_device_capability(90) and \
envs.VLLM_USE_DEEP_GEMM_E8M0_HOPPER:
logger.info_once("DeepGEMM E8M0 enabled on Hopper GPU.")
return True
logger.info_once("DeepGEMM E8M0 disabled on current configuration.")
return False
def _missing(*_: Any, **__: Any) -> NoReturn:
@ -124,20 +123,18 @@ def fp8_gemm_nt(*args, **kwargs):
_lazy_init()
if _fp8_gemm_nt_impl is None:
return _missing(*args, **kwargs)
return _fp8_gemm_nt_impl(
*args,
disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(),
**kwargs)
return _fp8_gemm_nt_impl(*args,
disable_ue8m0_cast=not is_deep_gemm_e8m0_used(),
**kwargs)
def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
_lazy_init()
if _grouped_impl is None:
return _missing(*args, **kwargs)
return _grouped_impl(
*args,
disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(),
**kwargs)
return _grouped_impl(*args,
disable_ue8m0_cast=not is_deep_gemm_e8m0_used(),
**kwargs)
def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
@ -145,9 +142,7 @@ def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
if _grouped_masked_impl is None:
return _missing(*args, **kwargs)
return _grouped_masked_impl(
*args,
disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(),
**kwargs)
*args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs)
def _ceil_to_ue8m0(x: torch.Tensor):
@ -211,7 +206,7 @@ __all__ = [
"m_grouped_fp8_gemm_nt_contiguous",
"fp8_m_grouped_gemm_nt_masked",
"per_block_cast_to_fp8",
"is_blackwell_deep_gemm_e8m0_used",
"is_deep_gemm_e8m0_used",
"is_deep_gemm_supported",
"should_use_deepgemm_for_fp8_linear",
]