From 0d21b9b51eccabfa1f8114eab2df61d75459bee7 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Mon, 13 Oct 2025 10:59:27 -0400
Subject: [PATCH] [UX] Speedup DeepGEMM warmup with heuristics (#25619)

Signed-off-by: mgoin
Signed-off-by: Michael Goin
Signed-off-by: Varun Sundar Rabindranath
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Varun Sundar Rabindranath
---
 vllm/envs.py                                  | 24 +++++-
 .../model_executor/warmup/deep_gemm_warmup.py | 83 +++++++++++++++++--
 vllm/model_executor/warmup/kernel_warmup.py   |  2 +-
 3 files changed, 95 insertions(+), 14 deletions(-)

diff --git a/vllm/envs.py b/vllm/envs.py
index 97076bec11b81..c3686477d88d1 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -146,7 +146,11 @@ if TYPE_CHECKING:
     VLLM_TPU_USING_PATHWAYS: bool = False
     VLLM_USE_DEEP_GEMM: bool = True
     VLLM_USE_DEEP_GEMM_E8M0: bool = True
-    VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
+    VLLM_DEEP_GEMM_WARMUP: Literal[
+        "skip",
+        "full",
+        "relax",
+    ] = "relax"
     VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True
     VLLM_USE_FLASHINFER_MOE_FP16: bool = False
     VLLM_USE_FLASHINFER_MOE_FP8: bool = False
@@ -1088,9 +1092,21 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # JIT all the required kernels before model execution so there is no
     # JIT'ing in the hot-path. However, this warmup increases the engine
     # startup time by a couple of minutes.
-    # Set `VLLM_SKIP_DEEP_GEMM_WARMUP` to disable the warmup.
-    "VLLM_SKIP_DEEP_GEMM_WARMUP": lambda: bool(
-        int(os.getenv("VLLM_SKIP_DEEP_GEMM_WARMUP", "0"))
+    # Available options:
+    # - "skip" : Skip warmup.
+    # - "full" : Warmup deepgemm by running all possible gemm shapes the
+    #   engine could encounter.
+    # - "relax" : Select gemm shapes to run based on some heuristics. The
+    #   heuristic aims to have the same effect as running all possible gemm
+    #   shapes, but provides no guarantees.
+    "VLLM_DEEP_GEMM_WARMUP": env_with_choices(
+        "VLLM_DEEP_GEMM_WARMUP",
+        "relax",
+        [
+            "skip",
+            "full",
+            "relax",
+        ],
     ),
     # Whether to use fused grouped_topk used for MoE expert selection.
     "VLLM_USE_FUSED_MOE_GROUPED_TOPK": lambda: bool(
diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py
index 1747caf26cef9..f1ed2696a0967 100644
--- a/vllm/model_executor/warmup/deep_gemm_warmup.py
+++ b/vllm/model_executor/warmup/deep_gemm_warmup.py
@@ -26,6 +26,55 @@
 from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
 from vllm.utils.deep_gemm import fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous
 
+def _generate_optimal_warmup_m_values(
+    max_tokens: int, n: int, device: torch.device
+) -> list[int]:
+    """
+    Generate M values that cover all possible DeepGEMM kernel configurations.
+    Reference: https://github.com/deepseek-ai/DeepGEMM/blob/79f48ee15a82dd5fad5cd9beaa393c1f755e6b55/csrc/jit_kernels/heuristics/common.hpp
+
+    Args:
+        max_tokens: Maximum number of tokens to warmup for
+        n: The actual N dimension from the weight tensor
+        device: The torch device to get properties from.
+    """
+
+    def ceil_div(a: int, b: int) -> int:
+        return (a + b - 1) // b
+
+    # DeepGEMM's possible block sizes
+    block_ms = [64, 128, 256]
+    block_ns = list(range(16, min(257, n + 1), 16))
+    num_sms = torch.cuda.get_device_properties(device).multi_processor_count
+
+    m_values = set()
+
+    # Always include small cases
+    m_values.update([1, 2, 4] + [i for i in range(8, 65, 8)])
+
+    # Collect M values where different wave patterns occur
+    for block_m in block_ms:
+        for block_n in block_ns:
+            if block_n > n:
+                continue
+
+            # Add key M boundaries for this block combination
+            for wave in range(1, 11):  # Up to 10 waves
+                # M where this block config transitions to next wave
+                target_blocks = wave * num_sms
+                m = target_blocks * block_m // ceil_div(n, block_n)
+                if 1 <= m <= max_tokens:
+                    m_values.add(m)
+
+        # Add block_m boundaries
+        for multiple in range(1, max_tokens // block_m + 1):
+            m = multiple * block_m
+            if m <= max_tokens:
+                m_values.add(m)
+
+    return sorted(m_values)
+
+
 def _extract_data_from_linear_base_module(
     m: torch.nn.Module,
 ) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
@@ -136,14 +185,27 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens:
     )
     out = torch.empty((max_tokens, n), device=device, dtype=torch.bfloat16)
 
-    pbar = tqdm(total=max_tokens, desc=f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()})")
-    num_tokens = max_tokens
-    while num_tokens > 0:
+    # Use optimal M values only if VLLM_DEEP_GEMM_WARMUP is set to "relax".
+    # Otherwise warmup all token sizes to avoid JIT compilation in hotpath
+    if envs.VLLM_DEEP_GEMM_WARMUP == "relax":
+        m_values = _generate_optimal_warmup_m_values(max_tokens, n, device)
+        desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [relaxed]"
+    else:
+        assert envs.VLLM_DEEP_GEMM_WARMUP == "full", (
+            "Expected "
+            'VLLM_DEEP_GEMM_WARMUP env to be set to "full" but got '
+            f"{envs.VLLM_DEEP_GEMM_WARMUP}"
+        )
+        m_values = list(range(1, max_tokens + 1))
+        desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [all tokens]"
+
+    pbar = tqdm(total=len(m_values), desc=desc)
+
+    for num_tokens in m_values:
         fp8_gemm_nt(
             (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, ws), out[:num_tokens]
         )
         pbar.update(1)
-        num_tokens -= 1
 
     FP8_GEMM_NT_WARMUP_CACHE.add(w.size())
 
@@ -195,12 +257,16 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
     )
     out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)
 
+    # Generate M values in block_m increments (already optimized for MoE)
+    m_values = list(range(block_m, MAX_M + 1, block_m))
+
     pbar = tqdm(
-        total=MAX_BLOCKS,
-        desc=f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()})",
+        total=len(m_values),
+        desc=f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()}) "
+        f"[{len(m_values)} values, block_m={block_m}]",
     )
-    num_tokens = MAX_M
-    while num_tokens > 0:
+
+    for num_tokens in m_values:
         m_grouped_fp8_gemm_nt_contiguous(
             (a1q[:num_tokens], a1q_scales[:num_tokens]),
             (w, w_scale),
@@ -208,7 +274,6 @@
             expert_ids[:num_tokens],
         )
         pbar.update(1)
-        num_tokens = num_tokens - block_m
 
     for w, ws in [(w1, w1_scale), (w2, w2_scale)]:
         if w.size() not in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE:
diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py
index 23227065ee950..28792338f036f 100644
--- a/vllm/model_executor/warmup/kernel_warmup.py
+++ b/vllm/model_executor/warmup/kernel_warmup.py
@@ -29,7 +29,7 @@ def kernel_warmup(worker: "Worker"):
     do_deep_gemm_warmup = (
         envs.VLLM_USE_DEEP_GEMM
         and is_deep_gemm_supported()
-        and not envs.VLLM_SKIP_DEEP_GEMM_WARMUP
+        and envs.VLLM_DEEP_GEMM_WARMUP != "skip"
     )
     if do_deep_gemm_warmup:
         model = worker.get_model()
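
Note (not part of the patch): below is a self-contained sketch that mirrors the
"relax" selection logic from _generate_optimal_warmup_m_values above, useful for
eyeballing how many GEMM shapes the heuristic warms up compared with the
exhaustive sweep used by VLLM_DEEP_GEMM_WARMUP=full. The SM count is hardcoded
(132, an assumed H100-like value) instead of being read from torch, and the
max_tokens / N shapes are hypothetical examples, so the printed numbers are
illustrative only; the helper name relaxed_m_values is not part of vLLM.

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def relaxed_m_values(max_tokens: int, n: int, num_sms: int = 132) -> list[int]:
    # num_sms=132 is an assumption standing in for
    # torch.cuda.get_device_properties(device).multi_processor_count
    block_ms = [64, 128, 256]  # DeepGEMM BLOCK_M candidates
    block_ns = list(range(16, min(257, n + 1), 16))  # DeepGEMM BLOCK_N candidates
    m_values = set([1, 2, 4] + list(range(8, 65, 8)))  # always warm small M
    for block_m in block_ms:
        for block_n in block_ns:
            for wave in range(1, 11):  # up to 10 waves of SMs
                # M near which this (block_m, block_n) grid fills `wave` waves
                m = wave * num_sms * block_m // ceil_div(n, block_n)
                if 1 <= m <= max_tokens:
                    m_values.add(m)
        # multiples of block_m, where the number of M-tiles steps up
        m_values.update(range(block_m, max_tokens + 1, block_m))
    return sorted(m_values)


if __name__ == "__main__":
    max_tokens, n = 8192, 7168  # hypothetical warmup token limit and weight N
    print(f"relax warms {len(relaxed_m_values(max_tokens, n))} M values "
          f"instead of {max_tokens}")

Running the script prints how many M values "relax" keeps versus the full sweep
of every M in [1, max_tokens], which is where the startup-time saving comes from.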