From f703b923f3885157cf02b951c42f967c25329b01 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 8 Aug 2025 19:09:59 -0400 Subject: [PATCH] [Misc] DeepGEMM : Avoid JIT generation in the hot-path (#22215) Signed-off-by: Varun Sundar Rabindranath Co-authored-by: Varun Sundar Rabindranath --- .../layers/fused_moe/deep_gemm_moe.py | 12 - .../layers/fused_moe/fused_moe.py | 55 +++-- .../model_executor/warmup/deep_gemm_warmup.py | 219 ++++++++++++++++++ vllm/model_executor/warmup/kernel_warmup.py | 20 ++ vllm/v1/worker/gpu_worker.py | 5 + 5 files changed, 274 insertions(+), 37 deletions(-) create mode 100644 vllm/model_executor/warmup/deep_gemm_warmup.py create mode 100644 vllm/model_executor/warmup/kernel_warmup.py diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index ba7105c83a92..9b8175f42a9d 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -237,18 +237,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): assert w1_scale is not None assert w2_scale is not None - if not env.VLLM_SKIP_DEEP_GEMM_WARMUP: - # DeepGemm JITs the grouped-gemm kernels. We don't want the JIT'ing - # to happen during actual model-inference. The - # `warmup_deepgemm_kernels` function is a `run_once` decorated - # function that executes during the model profile run. This warmup - # should create all the required JITs for the current model. - warmup_deepgemm_gg_contiguous_kernels(w1, - w2, - w1_scale, - w2_scale, - num_topk=topk_ids.size(1)) - a1q = hidden_states _, N, K = w1.size() diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 597af08c3c9f..f4f5457ebcd0 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -4,6 +4,9 @@ import functools import json import os +# torch.compile needs typing.List. 
It will fail torch.library.infer_schema +# otherwise +from typing import List # noqa: UP035 from typing import Any, Callable, Optional import torch @@ -998,29 +1001,30 @@ def get_config_dtype_str( return None -def inplace_fused_experts(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str = "silu", - is_act_and_mul: bool = True, - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None) -> None: +def inplace_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + is_act_and_mul: bool = True, + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None) -> None: #noqa: UP006 fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, activation, is_act_and_mul, apply_router_weight_on_input, use_fp8_w8a8, @@ -1082,7 +1086,7 @@ def flashinfer_fused_moe_blockscale_fp8( intermediate_size: int, expert_offset: int, local_num_experts: int, - block_shape: list[int], + block_shape: List[int], #noqa: UP006 routed_scaling: float = 1.0) -> torch.Tensor: from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe assert top_k <= global_num_experts @@ -1264,7 +1268,8 @@ def outplace_fused_experts( w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None) -> torch.Tensor: + block_shape: Optional[List[int]] = None, #noqa: UP006 +) -> torch.Tensor: return fused_experts_impl( hidden_states, w1, w2, topk_weights, topk_ids, False, activation, is_act_and_mul, apply_router_weight_on_input, use_fp8_w8a8, diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py new file mode 100644 index 000000000000..74599fa44c88 --- /dev/null +++ b/vllm/model_executor/warmup/deep_gemm_warmup.py @@ -0,0 +1,219 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Warmup deep_gemm kernels. +DeepGEMM JIT's the kernels. The warmup aims to JIT all the kernels that would +be used during model execution beforehand. 
+""" + +import torch +from tqdm import tqdm + +import vllm.envs as envs +from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts +from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( + compute_aligned_M, deep_gemm_block_shape) +from vllm.model_executor.layers.fused_moe.layer import FusedMoE +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEModularKernel) +from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( + TritonOrDeepGemmExperts) +from vllm.model_executor.layers.linear import LinearBase +from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod +from vllm.utils.deep_gemm import fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous + + +def _extract_data_from_linear_base_module( + m: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor, list[int]]: + """ + Extract weights, weight scales and quantization block sizes from the given + LinearBase module. + """ + assert isinstance(m, LinearBase) + assert isinstance(m.quant_method, Fp8LinearMethod) + assert m.quant_method.block_quant + assert m.quant_method.quant_config is not None + + w = m.weight + ws = m.weight_scale_inv + quant_block_size = m.quant_method.quant_config.weight_block_size + + assert isinstance(w, torch.Tensor) + assert isinstance(ws, torch.Tensor) + assert quant_block_size is not None + return (w, ws, quant_block_size) + + +def _extract_data_from_fused_moe_module( + m: torch.nn.Module +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]: + """ + Extract weights, weight scales and num_topk from FusedMoE module. + """ + assert isinstance(m, FusedMoE) + w13 = m.w13_weight + w13_s = m.w13_weight_scale_inv + w2 = m.w2_weight + w2_s = m.w2_weight_scale_inv + num_topk = m.top_k + + assert isinstance(w13, torch.Tensor) + assert isinstance(w13_s, torch.Tensor) + assert isinstance(w2, torch.Tensor) + assert isinstance(w2_s, torch.Tensor) + return w13, w13_s, w2, w2_s, num_topk + + +def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool: + """ + Return True if the input module/layer could be processed with DeepGEMM. 
+ """ + block_size = deep_gemm_block_shape()[0] + if not (isinstance(module, LinearBase) + and isinstance(module.quant_method, Fp8LinearMethod) + and module.quant_method.block_quant): + return False + + w, _, block_sizes = _extract_data_from_linear_base_module(module) + return (block_sizes == deep_gemm_block_shape() and w.ndim == 2 + and w.shape[0] % block_size == 0 and w.shape[1] % block_size == 0) + + +def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool: + if not (isinstance(module, FusedMoE) + and module.moe_config.quant_dtype == torch.float8_e4m3fn + and module.moe_config.block_shape == deep_gemm_block_shape()): + return False + + if not isinstance(module.quant_method.fused_experts, + FusedMoEModularKernel): + # fused_experts could invoke deep_gemm_moe_fp8 + return True + + mk: FusedMoEModularKernel = module.quant_method.fused_experts + # Further check if the ModularKernel implementation uses the DeepGemmExperts + return isinstance(mk.fused_experts, + (DeepGemmExperts, TritonOrDeepGemmExperts)) + + +FP8_GEMM_NT_WARMUP_CACHE: set[torch.Size] = set() + + +def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, + max_tokens: int): + if w.size() in FP8_GEMM_NT_WARMUP_CACHE: + return + + n, k = w.size() + block_m = deep_gemm_block_shape()[0] + + device = w.device + a1q = torch.empty((max_tokens, k), + device=device, + dtype=torch.float8_e4m3fn) + a1q_scales = torch.empty((max_tokens, k // block_m), + device=device, + dtype=torch.float32) + out = torch.empty((max_tokens, n), device=device, dtype=torch.bfloat16) + + pbar = tqdm(total=max_tokens, + desc=f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()})") + num_tokens = max_tokens + while num_tokens > 0: + fp8_gemm_nt((a1q[:num_tokens], a1q_scales[:num_tokens]), (w, ws), + out[:num_tokens]) + pbar.update(1) + num_tokens -= 1 + + FP8_GEMM_NT_WARMUP_CACHE.add(w.size()) + + +GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: set[torch.Size] = set() + + +def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + num_topk: int): + if (w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE + and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE): + return + + assert w1.size(0) == w2.size(0), ( + "w1 and w2 must have the same number of experts") + + block_m = deep_gemm_block_shape()[0] + num_experts = w1.size(0) + device = w1.device + + # This is the maximum GroupedGemm M size that we expect to run + # the grouped_gemm with. + MAX_M = compute_aligned_M(envs.VLLM_FUSED_MOE_CHUNK_SIZE, + num_topk, + num_experts, + block_m, + expert_tokens_meta=None) + # Distribute expert-ids evenly. 
+ MAX_BLOCKS = MAX_M // block_m + expert_ids_block = torch.randint(low=0, + high=num_experts, + size=(MAX_BLOCKS, ), + device=device, + dtype=torch.int32) + expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0) + + def _warmup(w: torch.Tensor, w_scale: torch.Tensor): + + _, n, k = w.size() + a1q = torch.empty((MAX_M, k), device=device, dtype=torch.float8_e4m3fn) + a1q_scales = torch.empty((MAX_M, k // block_m), + device=device, + dtype=torch.float32) + out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16) + + pbar = tqdm( + total=MAX_BLOCKS, + desc= + f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()})" + ) + num_tokens = MAX_M + while num_tokens > 0: + m_grouped_fp8_gemm_nt_contiguous( + (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, w_scale), + out[:num_tokens], expert_ids[:num_tokens]) + pbar.update(1) + num_tokens = num_tokens - block_m + + for w, ws in [(w1, w1_scale), (w2, w2_scale)]: + if w.size() not in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: + _warmup(w, ws) + GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE.add(w.size()) + + +def deepgemm_fp8_gemm_nt_warmup(model: torch.nn.Module, max_tokens: int): + dg_modules = [ + m for m in model.modules() if _fp8_linear_may_use_deep_gemm(m) + ] + + for dgm in dg_modules: + w, ws, _ = _extract_data_from_linear_base_module(dgm) + _deepgemm_fp8_gemm_nt_warmup(w=w, ws=ws, max_tokens=max_tokens) + + +def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module): + dg_modules = [ + m for m in model.modules() + if _fused_moe_grouped_gemm_may_use_deep_gemm(m) + ] + + for dgm in dg_modules: + w13, w13_scale, w2, w2_scale, num_topk = ( + _extract_data_from_fused_moe_module(dgm)) + _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup( + w13, w2, w13_scale, w2_scale, num_topk) + + +def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int): + deepgemm_fp8_gemm_nt_warmup(model, max_tokens) + deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model) diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py new file mode 100644 index 000000000000..10f2dc0252a1 --- /dev/null +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Warmup kernels used during model execution. +This is useful specifically for JIT'ed kernels as we don't want JIT'ing to +happen during model execution. 
+""" +import torch + +import vllm.envs as envs +from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup +from vllm.utils.deep_gemm import is_deep_gemm_supported + + +def kernel_warmup(model: torch.nn.Module, max_tokens: int): + do_deep_gemm_warmup = (envs.VLLM_USE_DEEP_GEMM + and is_deep_gemm_supported() + and not envs.VLLM_SKIP_DEEP_GEMM_WARMUP) + if do_deep_gemm_warmup: + deep_gemm_warmup(model, max_tokens) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 7fca245c1bef..0ea23921a080 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -21,6 +21,7 @@ from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed +from vllm.model_executor.warmup.kernel_warmup import kernel_warmup from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.tasks import SupportedTask @@ -338,6 +339,10 @@ class Worker(WorkerBase): self.model_runner._dummy_sampler_run( hidden_states=last_hidden_states) + # Warmup kernels used during model execution + kernel_warmup(self.get_model(), + max_tokens=self.scheduler_config.max_num_batched_tokens) + # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling. set_random_seed(self.model_config.seed)