From a65f46be5ea9a92dde48df2b951c1915aa1d9595 Mon Sep 17 00:00:00 2001
From: Varun Sundar Rabindranath
Date: Sat, 2 Aug 2025 08:12:03 +0530
Subject: [PATCH] [Misc] DeepGemmExperts : Avoid JIT generation in the hot-path (#21955)

Signed-off-by: Varun Sundar Rabindranath
Co-authored-by: Varun Sundar Rabindranath
---
 vllm/envs.py                                  |  9 +++
 .../layers/fused_moe/deep_gemm_moe.py         | 77 ++++++++++++++++++-
 vllm/utils/deep_gemm.py                       |  7 ++
 3 files changed, 92 insertions(+), 1 deletion(-)

diff --git a/vllm/envs.py b/vllm/envs.py
index c161fa0dff6ba..2d470c6dccbfd 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -126,6 +126,7 @@ if TYPE_CHECKING:
     VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
     VLLM_TPU_USING_PATHWAYS: bool = False
     VLLM_USE_DEEP_GEMM: bool = False
+    VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
     VLLM_USE_FLASHINFER_MOE_FP8: bool = False
     VLLM_USE_FLASHINFER_MOE_FP4: bool = False
     VLLM_XGRAMMAR_CACHE_MB: int = 0
@@ -910,6 +911,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_USE_DEEP_GEMM":
     lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
 
+    # DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm
+    # JIT all the required kernels before model execution so there is no
+    # JIT'ing in the hot-path. However, this warmup increases the engine
+    # startup time by a couple of minutes.
+    # Set `VLLM_SKIP_DEEP_GEMM_WARMUP` to disable the warmup.
+    "VLLM_SKIP_DEEP_GEMM_WARMUP":
+    lambda: bool(int(os.getenv("VLLM_SKIP_DEEP_GEMM_WARMUP", "0"))),
+
     # Allow use of FlashInfer MoE kernels for fused moe ops.
     "VLLM_USE_FLASHINFER_MOE_FP8":
     lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0"))),
diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
index b89e5ac6f093e..bd3605378b6dc 100644
--- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
@@ -4,7 +4,9 @@ import functools
 from typing import Any, Optional
 
 import torch
+from tqdm import tqdm
 
+import vllm.envs as env
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
@@ -17,7 +19,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8)
-from vllm.utils import has_deep_gemm
+from vllm.utils import has_deep_gemm, run_once
 from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous
 
 logger = init_logger(__name__)
@@ -82,6 +84,65 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
     return True
 
 
+@run_once
+def warmup_deepgemm_gg_contiguous_kernels(w1: torch.Tensor, w2: torch.Tensor,
+                                          w1_scale: torch.Tensor,
+                                          w2_scale: torch.Tensor,
+                                          num_topk: int):
+    """
+    DeepGemm JITs the grouped-gemm kernels. The JIT'ing happens based on the
+    input tensor shapes. In this function, we construct all possible input
+    tensor shapes so all the kernels are JIT'ed and cached.
+    Note that this warmup is expected to happen during the model profile
+    call and not during actual model inference.
+    """
+
+    assert w1.size(0) == w2.size(0), (
+        "w1 and w2 must have the same number of experts")
+
+    block_m = deep_gemm_block_shape()[0]
+    num_experts = w1.size(0)
+    device = w1.device
+
+    # This is the maximum GroupedGemm M size that we expect to run
+    # the grouped_gemm with.
+    MAX_M = compute_aligned_M(env.VLLM_FUSED_MOE_CHUNK_SIZE,
+                              num_topk,
+                              num_experts,
+                              block_m,
+                              expert_tokens_meta=None)
+    # Assign a random expert-id to each block of block_m rows.
+    MAX_BLOCKS = MAX_M // block_m
+    expert_ids_block = torch.randint(low=0,
+                                     high=num_experts,
+                                     size=(MAX_BLOCKS, ),
+                                     device=device,
+                                     dtype=torch.int32)
+    expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0)
+
+    def _warmup(w: torch.Tensor, w_scale: torch.Tensor):
+
+        _, n, k = w.size()
+        a1q = torch.empty((MAX_M, k), device=device).to(torch.float8_e4m3fn)
+        a1q_scales = torch.empty((MAX_M, k // block_m),
+                                 device=device,
+                                 dtype=torch.float32)
+        out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)
+
+        pbar = tqdm(total=MAX_BLOCKS,
+                    desc=f"DeepGemmExperts GEMM warmup (MAX_M={MAX_M})")
+        num_tokens = MAX_M
+        while num_tokens > 0:
+            m_grouped_fp8_gemm_nt_contiguous(
+                (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, w_scale),
+                out[:num_tokens], expert_ids[:num_tokens])
+            pbar.update(1)
+            num_tokens = num_tokens - block_m
+
+    _warmup(w1, w1_scale)
+    _warmup(w2, w2_scale)
+
+
 class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
 
     def __init__(self):
@@ -156,6 +217,20 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
     ):
         assert self.block_shape is not None
         assert a1q_scale is not None
+        assert w1_scale is not None
+        assert w2_scale is not None
+
+        if not env.VLLM_SKIP_DEEP_GEMM_WARMUP:
+            # DeepGemm JITs the grouped-gemm kernels. We don't want the
+            # JIT'ing to happen during actual model-inference. The
+            # `warmup_deepgemm_gg_contiguous_kernels` function is `run_once`
+            # decorated and executes during the model profile run. This warmup
+            # should create all the required JITs for the current model.
+            warmup_deepgemm_gg_contiguous_kernels(w1,
+                                                  w2,
+                                                  w1_scale,
+                                                  w2_scale,
+                                                  num_topk=topk_ids.size(1))
 
         a1q = hidden_states
         _, N, K = w1.size()
diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py
index 4dedee2a3f862..8ab34e7505ee2 100644
--- a/vllm/utils/deep_gemm.py
+++ b/vllm/utils/deep_gemm.py
@@ -8,6 +8,7 @@ from __future__ import annotations
 
 import functools
 import importlib
+import os
 from typing import Any, Callable, NoReturn
 
 import torch
@@ -77,6 +78,12 @@ def _lazy_init() -> None:
     if not has_deep_gemm():
         return
 
+    # Set up deep_gemm cache path
+    DEEP_GEMM_JIT_CACHE_ENV_NAME = 'DG_JIT_CACHE_DIR'
+    if not os.environ.get(DEEP_GEMM_JIT_CACHE_ENV_NAME, None):
+        os.environ[DEEP_GEMM_JIT_CACHE_ENV_NAME] = os.path.join(
+            envs.VLLM_CACHE_ROOT, "deep_gemm")
+
     _dg = importlib.import_module("deep_gemm")
 
     _fp8_gemm_nt_impl = _resolve_symbol(_dg, "fp8_gemm_nt",
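
A minimal end-to-end sketch of how the knobs touched by this patch fit together, assuming a vLLM build with DeepGemm installed; the model name below is illustrative and stands in for any FP8 MoE checkpoint that takes the DeepGemm path:

    import os

    # Opt in to the DeepGemm grouped-gemm kernels. With the warmup above in
    # place, grouped-gemm shapes up to the M implied by
    # VLLM_FUSED_MOE_CHUNK_SIZE are JIT'ed during the profile run rather than
    # while serving requests.
    os.environ["VLLM_USE_DEEP_GEMM"] = "1"

    # Optional: skip the warmup and accept JIT'ing in the hot-path instead.
    # os.environ["VLLM_SKIP_DEEP_GEMM_WARMUP"] = "1"

    # Optional: DG_JIT_CACHE_DIR now defaults to <VLLM_CACHE_ROOT>/deep_gemm,
    # so compiled kernels can be reused across engine restarts; override it to
    # point at a shared location if desired.
    # os.environ["DG_JIT_CACHE_DIR"] = "/path/to/deep_gemm_cache"

    from vllm import LLM

    llm = LLM(model="deepseek-ai/DeepSeek-V3")  # illustrative FP8 MoE model
    print(llm.generate("Hello, DeepGemm!")[0].outputs[0].text)

Because `warmup_deepgemm_gg_contiguous_kernels` is `run_once`-decorated, only the first `apply()` call (expected during the profile run) pays the warmup cost; later calls fall through immediately.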