[Misc] DeepGEMM: Avoid JIT generation in the hot-path (#22215)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Varun Sundar Rabindranath 2025-08-08 19:09:59 -04:00 committed by GitHub
parent cd9b9de1fb
commit f703b923f3
5 changed files with 274 additions and 37 deletions


@@ -237,18 +237,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
        assert w1_scale is not None
        assert w2_scale is not None
        if not env.VLLM_SKIP_DEEP_GEMM_WARMUP:
            # DeepGemm JITs the grouped-gemm kernels. We don't want the
            # JIT'ing to happen during actual model inference. The
            # `warmup_deepgemm_gg_contiguous_kernels` function is a
            # `run_once` decorated function that executes during the model
            # profile run. This warmup should create all the required JITs
            # for the current model.
            warmup_deepgemm_gg_contiguous_kernels(w1,
                                                  w2,
                                                  w1_scale,
                                                  w2_scale,
                                                  num_topk=topk_ids.size(1))
        a1q = hidden_states
        _, N, K = w1.size()
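The removed comment refers to a `run_once`-decorated warmup helper. As a rough sketch of that general pattern (illustrative only; `run_once` here is a generic name, not vLLM's actual decorator):

import functools


def run_once(fn):
    """Run `fn` on the first call only; later calls become no-ops."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if not wrapper.has_run:
            wrapper.has_run = True
            return fn(*args, **kwargs)
        return None

    wrapper.has_run = False
    return wrapper

A decorator like this makes the warmup safe to call from the forward path, but the JIT cost of the very first call still lands on whichever request triggers it; this commit avoids that by moving the warmup out of the hot path entirely.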


@@ -4,6 +4,9 @@
import functools
import json
import os
# torch.compile needs typing.List; torch.library.infer_schema will fail
# otherwise.
from typing import List  # noqa: UP035
from typing import Any, Callable, Optional

import torch
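For context on the `typing.List` comment: custom-op schema inference in some PyTorch versions only recognizes `typing.List`-style annotations, so PEP 585 built-ins such as `list[int]` can break `torch.library.infer_schema`. A minimal sketch of the distinction follows (`_proto` is a hypothetical prototype function; exact behavior depends on the installed PyTorch version):

from typing import List, Optional

import torch


def _proto(x: torch.Tensor,
           block_shape: Optional[List[int]] = None) -> torch.Tensor:
    # Hypothetical prototype op. Annotating block_shape as Optional[list[int]]
    # instead can break schema inference on affected PyTorch versions, hence
    # the noqa markers in this diff.
    return x


print(torch.library.infer_schema(_proto, mutates_args=()))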
@@ -998,29 +1001,30 @@ def get_config_dtype_str(
    return None


def inplace_fused_experts(hidden_states: torch.Tensor,
                          w1: torch.Tensor,
                          w2: torch.Tensor,
                          topk_weights: torch.Tensor,
                          topk_ids: torch.Tensor,
                          activation: str = "silu",
                          is_act_and_mul: bool = True,
                          apply_router_weight_on_input: bool = False,
                          use_fp8_w8a8: bool = False,
                          use_int8_w8a8: bool = False,
                          use_int8_w8a16: bool = False,
                          use_int4_w4a16: bool = False,
                          use_mxfp4_w4a4: bool = False,
                          per_channel_quant: bool = False,
                          global_num_experts: int = -1,
                          expert_map: Optional[torch.Tensor] = None,
                          w1_scale: Optional[torch.Tensor] = None,
                          w2_scale: Optional[torch.Tensor] = None,
                          w1_zp: Optional[torch.Tensor] = None,
                          w2_zp: Optional[torch.Tensor] = None,
                          a1_scale: Optional[torch.Tensor] = None,
                          a2_scale: Optional[torch.Tensor] = None,
                          block_shape: Optional[list[int]] = None) -> None:
def inplace_fused_experts(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str = "silu",
        is_act_and_mul: bool = True,
        apply_router_weight_on_input: bool = False,
        use_fp8_w8a8: bool = False,
        use_int8_w8a8: bool = False,
        use_int8_w8a16: bool = False,
        use_int4_w4a16: bool = False,
        use_mxfp4_w4a4: bool = False,
        per_channel_quant: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        w1_scale: Optional[torch.Tensor] = None,
        w2_scale: Optional[torch.Tensor] = None,
        w1_zp: Optional[torch.Tensor] = None,
        w2_zp: Optional[torch.Tensor] = None,
        a1_scale: Optional[torch.Tensor] = None,
        a2_scale: Optional[torch.Tensor] = None,
        block_shape: Optional[List[int]] = None) -> None:  #noqa: UP006
    fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
                       activation, is_act_and_mul,
                       apply_router_weight_on_input, use_fp8_w8a8,
@@ -1082,7 +1086,7 @@ def flashinfer_fused_moe_blockscale_fp8(
        intermediate_size: int,
        expert_offset: int,
        local_num_experts: int,
        block_shape: list[int],
        block_shape: List[int],  #noqa: UP006
        routed_scaling: float = 1.0) -> torch.Tensor:
    from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
    assert top_k <= global_num_experts
@@ -1264,7 +1268,8 @@ def outplace_fused_experts(
        w2_zp: Optional[torch.Tensor] = None,
        a1_scale: Optional[torch.Tensor] = None,
        a2_scale: Optional[torch.Tensor] = None,
        block_shape: Optional[list[int]] = None) -> torch.Tensor:
        block_shape: Optional[List[int]] = None,  #noqa: UP006
) -> torch.Tensor:
    return fused_experts_impl(
        hidden_states, w1, w2, topk_weights, topk_ids, False, activation,
        is_act_and_mul, apply_router_weight_on_input, use_fp8_w8a8,


@@ -0,0 +1,219 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Warmup deep_gemm kernels.
DeepGEMM JITs its kernels. The warmup aims to JIT, ahead of time, all the
kernels that would be used during model execution.
"""

import torch
from tqdm import tqdm

import vllm.envs as envs
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
    compute_aligned_M, deep_gemm_block_shape)
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
    TritonOrDeepGemmExperts)
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.utils.deep_gemm import fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous


def _extract_data_from_linear_base_module(
        m: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
    """
    Extract weights, weight scales and quantization block sizes from the given
    LinearBase module.
    """
    assert isinstance(m, LinearBase)
    assert isinstance(m.quant_method, Fp8LinearMethod)
    assert m.quant_method.block_quant
    assert m.quant_method.quant_config is not None

    w = m.weight
    ws = m.weight_scale_inv
    quant_block_size = m.quant_method.quant_config.weight_block_size

    assert isinstance(w, torch.Tensor)
    assert isinstance(ws, torch.Tensor)
    assert quant_block_size is not None
    return (w, ws, quant_block_size)


def _extract_data_from_fused_moe_module(
    m: torch.nn.Module
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]:
    """
    Extract weights, weight scales and num_topk from FusedMoE module.
    """
    assert isinstance(m, FusedMoE)
    w13 = m.w13_weight
    w13_s = m.w13_weight_scale_inv
    w2 = m.w2_weight
    w2_s = m.w2_weight_scale_inv
    num_topk = m.top_k

    assert isinstance(w13, torch.Tensor)
    assert isinstance(w13_s, torch.Tensor)
    assert isinstance(w2, torch.Tensor)
    assert isinstance(w2_s, torch.Tensor)
    return w13, w13_s, w2, w2_s, num_topk


def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool:
    """
    Return True if the input module/layer could be processed with DeepGEMM.
    """
    block_size = deep_gemm_block_shape()[0]
    if not (isinstance(module, LinearBase)
            and isinstance(module.quant_method, Fp8LinearMethod)
            and module.quant_method.block_quant):
        return False

    w, _, block_sizes = _extract_data_from_linear_base_module(module)
    return (block_sizes == deep_gemm_block_shape() and w.ndim == 2
            and w.shape[0] % block_size == 0 and w.shape[1] % block_size == 0)


def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
    if not (isinstance(module, FusedMoE)
            and module.moe_config.quant_dtype == torch.float8_e4m3fn
            and module.moe_config.block_shape == deep_gemm_block_shape()):
        return False

    if not isinstance(module.quant_method.fused_experts,
                      FusedMoEModularKernel):
        # fused_experts could invoke deep_gemm_moe_fp8
        return True

    mk: FusedMoEModularKernel = module.quant_method.fused_experts
    # Further check if the ModularKernel implementation uses DeepGemmExperts
    return isinstance(mk.fused_experts,
                      (DeepGemmExperts, TritonOrDeepGemmExperts))


FP8_GEMM_NT_WARMUP_CACHE: set[torch.Size] = set()


def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor,
                                 max_tokens: int):
    if w.size() in FP8_GEMM_NT_WARMUP_CACHE:
        return

    n, k = w.size()
    block_m = deep_gemm_block_shape()[0]
    device = w.device

    a1q = torch.empty((max_tokens, k),
                      device=device,
                      dtype=torch.float8_e4m3fn)
    a1q_scales = torch.empty((max_tokens, k // block_m),
                             device=device,
                             dtype=torch.float32)
    out = torch.empty((max_tokens, n), device=device, dtype=torch.bfloat16)

    pbar = tqdm(total=max_tokens,
                desc=f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()})")
    num_tokens = max_tokens
    while num_tokens > 0:
        fp8_gemm_nt((a1q[:num_tokens], a1q_scales[:num_tokens]), (w, ws),
                    out[:num_tokens])
        pbar.update(1)
        num_tokens -= 1

    FP8_GEMM_NT_WARMUP_CACHE.add(w.size())


GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: set[torch.Size] = set()


def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(w1: torch.Tensor,
                                                    w2: torch.Tensor,
                                                    w1_scale: torch.Tensor,
                                                    w2_scale: torch.Tensor,
                                                    num_topk: int):
    if (w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
            and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE):
        return

    assert w1.size(0) == w2.size(0), (
        "w1 and w2 must have the same number of experts")

    block_m = deep_gemm_block_shape()[0]
    num_experts = w1.size(0)
    device = w1.device

    # This is the maximum GroupedGemm M size that we expect to run
    # the grouped_gemm with.
    MAX_M = compute_aligned_M(envs.VLLM_FUSED_MOE_CHUNK_SIZE,
                              num_topk,
                              num_experts,
                              block_m,
                              expert_tokens_meta=None)
    # Distribute expert-ids evenly.
    MAX_BLOCKS = MAX_M // block_m

    expert_ids_block = torch.randint(low=0,
                                     high=num_experts,
                                     size=(MAX_BLOCKS, ),
                                     device=device,
                                     dtype=torch.int32)
    expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0)

    def _warmup(w: torch.Tensor, w_scale: torch.Tensor):
        _, n, k = w.size()

        a1q = torch.empty((MAX_M, k), device=device, dtype=torch.float8_e4m3fn)
        a1q_scales = torch.empty((MAX_M, k // block_m),
                                 device=device,
                                 dtype=torch.float32)
        out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)

        pbar = tqdm(
            total=MAX_BLOCKS,
            desc=
            f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()})"
        )
        num_tokens = MAX_M
        while num_tokens > 0:
            m_grouped_fp8_gemm_nt_contiguous(
                (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, w_scale),
                out[:num_tokens], expert_ids[:num_tokens])
            pbar.update(1)
            num_tokens = num_tokens - block_m

    for w, ws in [(w1, w1_scale), (w2, w2_scale)]:
        if w.size() not in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE:
            _warmup(w, ws)
            GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE.add(w.size())


def deepgemm_fp8_gemm_nt_warmup(model: torch.nn.Module, max_tokens: int):
    dg_modules = [
        m for m in model.modules() if _fp8_linear_may_use_deep_gemm(m)
    ]

    for dgm in dg_modules:
        w, ws, _ = _extract_data_from_linear_base_module(dgm)
        _deepgemm_fp8_gemm_nt_warmup(w=w, ws=ws, max_tokens=max_tokens)


def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module):
    dg_modules = [
        m for m in model.modules()
        if _fused_moe_grouped_gemm_may_use_deep_gemm(m)
    ]

    for dgm in dg_modules:
        w13, w13_scale, w2, w2_scale, num_topk = (
            _extract_data_from_fused_moe_module(dgm))
        _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
            w13, w2, w13_scale, w2_scale, num_topk)


def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int):
    deepgemm_fp8_gemm_nt_warmup(model, max_tokens)
    deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model)
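A brief usage sketch of the new module (hypothetical driver code, not part of the commit; `warm_model` is an illustrative name), assuming an fp8 block-quantized model has already been loaded:

import torch

from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup


def warm_model(model: torch.nn.Module, max_num_batched_tokens: int) -> None:
    # Compile every DeepGEMM kernel shape the model can hit so that no JIT
    # compilation lands on the request-serving hot path. Repeat calls are
    # cheap: the per-shape warmup caches skip already-compiled weight sizes.
    deep_gemm_warmup(model, max_tokens=max_num_batched_tokens)

The `kernel_warmup` helper in the next file wires this into the worker behind the `VLLM_USE_DEEP_GEMM` and `VLLM_SKIP_DEEP_GEMM_WARMUP` environment variables.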


@@ -0,0 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Warmup kernels used during model execution.
This is useful specifically for JIT-compiled kernels, as we don't want the
compilation to happen during model execution.
"""

import torch

import vllm.envs as envs
from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup
from vllm.utils.deep_gemm import is_deep_gemm_supported


def kernel_warmup(model: torch.nn.Module, max_tokens: int):
    do_deep_gemm_warmup = (envs.VLLM_USE_DEEP_GEMM
                           and is_deep_gemm_supported()
                           and not envs.VLLM_SKIP_DEEP_GEMM_WARMUP)
    if do_deep_gemm_warmup:
        deep_gemm_warmup(model, max_tokens)
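Since the warmup is gated entirely on environment variables, it can be toggled without code changes. A small sketch (the variables are the ones referenced in the code above; this assumes they are set before vLLM consults `vllm.envs`):

import os

# Keep DeepGEMM enabled but skip the ahead-of-time warmup, e.g. to shorten
# engine start-up when JIT'ing at first use is acceptable.
os.environ["VLLM_USE_DEEP_GEMM"] = "1"
os.environ["VLLM_SKIP_DEEP_GEMM_WARMUP"] = "1"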


@@ -21,6 +21,7 @@ from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask

@@ -338,6 +339,10 @@ class Worker(WorkerBase):
            self.model_runner._dummy_sampler_run(
                hidden_states=last_hidden_states)

        # Warmup kernels used during model execution
        kernel_warmup(self.get_model(),
                      max_tokens=self.scheduler_config.max_num_batched_tokens)

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)