[Feature] Add env var VLLM_MOE_USE_DEEP_GEMM (#28422)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Wentao Ye 2025-11-10 21:29:48 -05:00 committed by GitHub
parent 39029d5192
commit de540c0354
GPG Key ID: B5690EEEBB952194
4 changed files with 19 additions and 2 deletions


@@ -147,6 +147,7 @@ if TYPE_CHECKING:
     VLLM_TPU_MOST_MODEL_LEN: int | None = None
     VLLM_TPU_USING_PATHWAYS: bool = False
     VLLM_USE_DEEP_GEMM: bool = True
+    VLLM_MOE_USE_DEEP_GEMM: bool = True
     VLLM_USE_DEEP_GEMM_E8M0: bool = True
     VLLM_DEEP_GEMM_WARMUP: Literal[
         "skip",
@@ -1116,6 +1117,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     ),
     # Allow use of DeepGemm kernels for fused moe ops.
     "VLLM_USE_DEEP_GEMM": lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "1"))),
+    # Allow use of DeepGemm specifically for MoE fused ops (overrides only MoE).
+    "VLLM_MOE_USE_DEEP_GEMM": lambda: bool(
+        int(os.getenv("VLLM_MOE_USE_DEEP_GEMM", "1"))
+    ),
     # Whether to use E8M0 scaling when DeepGEMM is used on Blackwell GPUs.
     "VLLM_USE_DEEP_GEMM_E8M0": lambda: bool(
         int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1"))
@@ -1569,6 +1574,7 @@ def compute_hash() -> str:
         "VLLM_USE_FLASHINFER_SAMPLER",
         "VLLM_DISABLED_KERNELS",
         "VLLM_USE_DEEP_GEMM",
+        "VLLM_MOE_USE_DEEP_GEMM",
         "VLLM_USE_DEEP_GEMM_E8M0",
         "VLLM_USE_FUSED_MOE_GROUPED_TOPK",
         "VLLM_USE_FLASHINFER_MOE_FP16",

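The new variable follows the same parsing convention as the existing DeepGEMM flags: the raw string is read with os.getenv, converted through int and then bool, so "0" disables and "1" (the default) enables. A minimal standalone sketch of that convention, with _env_flag as a hypothetical helper name:

import os

def _env_flag(name: str, default: str = "1") -> bool:
    # Mirrors the envs.py lambdas above: bool(int(os.getenv(name, "1"))).
    # Non-integer strings (e.g. "true") raise ValueError, as in the originals.
    return bool(int(os.getenv(name, default)))

# Both flags default to enabled when unset.
use_deep_gemm = _env_flag("VLLM_USE_DEEP_GEMM")
moe_use_deep_gemm = _env_flag("VLLM_MOE_USE_DEEP_GEMM")

The flag is also listed in compute_hash() alongside VLLM_USE_DEEP_GEMM, so it contributes to the same environment hash as the existing DeepGEMM switches.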

@@ -966,10 +966,18 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                 max_num_tokens=max_num_tokens_per_rank,
                 num_dispatchers=prepare_finalize.num_dispatchers(),
                 quant_config=self.moe_quant_config,
+                allow_deep_gemm=(
+                    envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM
+                ),
             )
         else:
             logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__)
-            return TritonOrDeepGemmExperts(self.moe_quant_config, allow_deep_gemm=True)
+            return TritonOrDeepGemmExperts(
+                self.moe_quant_config,
+                allow_deep_gemm=(
+                    envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM
+                ),
+            )

     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module

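In both branches above, allow_deep_gemm is now the conjunction of the global flag and the new MoE-specific flag, so either variable can force the Triton fallback for the MoE experts. A quick standalone check of how the two flags combine (illustrative only, not part of the commit):

# Enumerate the four combinations; only 1/1 keeps DeepGEMM enabled for MoE.
for use_dg in (0, 1):
    for moe_use_dg in (0, 1):
        allow_deep_gemm = bool(use_dg) and bool(moe_use_dg)
        print(
            f"VLLM_USE_DEEP_GEMM={use_dg} "
            f"VLLM_MOE_USE_DEEP_GEMM={moe_use_dg} "
            f"-> allow_deep_gemm={allow_deep_gemm}"
        )

In other words, VLLM_MOE_USE_DEEP_GEMM=0 disables DeepGEMM only for the MoE experts paths, while VLLM_MOE_USE_DEEP_GEMM=1 cannot re-enable it when the global VLLM_USE_DEEP_GEMM is off.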

@@ -158,7 +158,7 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
         return Fp8MoeBackend.MARLIN

     # deepGEMM on supported platforms with block-quantized weights
-    if envs.VLLM_USE_DEEP_GEMM and block_quant:
+    if envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM and block_quant:
         if not has_deep_gemm():
             logger.warning_once("DeepGEMM backend requested but not available.")
         elif is_deep_gemm_supported():

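For the FP8 MoE backend, the changed line means DeepGEMM is only considered when both flags are enabled, the weights are block-quantized, the deep_gemm package is importable, and the platform passes is_deep_gemm_supported(). A hedged sketch of that decision flow, with the availability checks stubbed out as plain booleans rather than the real vLLM helpers:

def deep_gemm_moe_selected(
    use_deep_gemm: bool,
    moe_use_deep_gemm: bool,
    block_quant: bool,
    deep_gemm_installed: bool,  # stands in for has_deep_gemm()
    platform_supported: bool,   # stands in for is_deep_gemm_supported()
) -> bool:
    # Mirrors the guard in the hunk above: both env flags plus block
    # quantization are required before availability is even checked.
    if not (use_deep_gemm and moe_use_deep_gemm and block_quant):
        return False
    if not deep_gemm_installed:
        # The real code logs "DeepGEMM backend requested but not available."
        return False
    return platform_supported

assert deep_gemm_moe_selected(True, True, True, True, True)
assert not deep_gemm_moe_selected(True, False, True, True, True)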

@@ -148,6 +148,9 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool:
 def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
+    if not (envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM):
+        return False
+
     if not isinstance(module, FusedMoE):
         return False
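
The last hunk adds an early-out to the DeepGEMM warmup path: if either flag is disabled, fused-MoE grouped-GEMM warmup is skipped entirely, and otherwise the module must still be a FusedMoE instance. A minimal sketch of the same guard shape, using a placeholder class instead of vLLM's FusedMoE:

import torch

class FusedMoEStub(torch.nn.Module):  # hypothetical stand-in for vLLM's FusedMoE
    pass

def may_warm_up_deep_gemm(
    module: torch.nn.Module, use_deep_gemm: bool, moe_use_deep_gemm: bool
) -> bool:
    # Same early-return shape as the added lines: env gate first, then the
    # module-type check; the real helper continues with further checks
    # beyond the truncated hunk above.
    if not (use_deep_gemm and moe_use_deep_gemm):
        return False
    if not isinstance(module, FusedMoEStub):
        return False
    return True

print(may_warm_up_deep_gemm(FusedMoEStub(), True, False))  # -> False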