[Misc] DeepGEMM: Avoid JIT generation in the hot-path (#22215)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>

parent: cd9b9de1fb
commit: f703b923f3
@@ -237,18 +237,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         assert w1_scale is not None
         assert w2_scale is not None
 
-        if not env.VLLM_SKIP_DEEP_GEMM_WARMUP:
-            # DeepGemm JITs the grouped-gemm kernels. We don't want the JIT'ing
-            # to happen during actual model-inference. The
-            # `warmup_deepgemm_kernels` function is a `run_once` decorated
-            # function that executes during the model profile run. This warmup
-            # should create all the required JITs for the current model.
-            warmup_deepgemm_gg_contiguous_kernels(w1,
-                                                  w2,
-                                                  w1_scale,
-                                                  w2_scale,
-                                                  num_topk=topk_ids.size(1))
-
         a1q = hidden_states
         _, N, K = w1.size()
 
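The removed block relied on a run-once warmup triggered lazily from the hot path; the rest of this commit replaces it with an explicit warmup pass at startup (see the new warmup modules below). For reference, a minimal sketch of a run_once-style gate like the one the removed comment mentions — the actual vLLM decorator may differ, this is only an illustration:

import functools


def run_once(fn):
    """Call `fn` on the first invocation only; later calls are no-ops."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if not wrapper.has_run:
            wrapper.has_run = True
            return fn(*args, **kwargs)
        return None

    wrapper.has_run = False
    return wrapper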
@@ -4,6 +4,9 @@
 import functools
 import json
 import os
+# torch.compile needs typing.List. It will fail torch.library.infer_schema
+# otherwise
+from typing import List  # noqa: UP035
 from typing import Any, Callable, Optional
 
 import torch
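The added comment explains the switch from the builtin list[int] generic back to typing.List[int]: schema inference for custom ops needs the typing form. A minimal sketch of that constraint, assuming torch >= 2.5 where torch.library.infer_schema is public API, with a hypothetical _toy_op (the printed schema text may differ by torch version):

from typing import List, Optional  # noqa: UP035

import torch


def _toy_op(x: torch.Tensor,
            block_shape: Optional[List[int]] = None) -> torch.Tensor:
    return x.clone()


# Expected to print something like:
#   (Tensor x, SymInt[]? block_shape=None) -> Tensor
print(torch.library.infer_schema(_toy_op, mutates_args=()))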
@@ -998,29 +1001,30 @@ def get_config_dtype_str(
     return None
 
 
-def inplace_fused_experts(hidden_states: torch.Tensor,
-                          w1: torch.Tensor,
-                          w2: torch.Tensor,
-                          topk_weights: torch.Tensor,
-                          topk_ids: torch.Tensor,
-                          activation: str = "silu",
-                          is_act_and_mul: bool = True,
-                          apply_router_weight_on_input: bool = False,
-                          use_fp8_w8a8: bool = False,
-                          use_int8_w8a8: bool = False,
-                          use_int8_w8a16: bool = False,
-                          use_int4_w4a16: bool = False,
-                          use_mxfp4_w4a4: bool = False,
-                          per_channel_quant: bool = False,
-                          global_num_experts: int = -1,
-                          expert_map: Optional[torch.Tensor] = None,
-                          w1_scale: Optional[torch.Tensor] = None,
-                          w2_scale: Optional[torch.Tensor] = None,
-                          w1_zp: Optional[torch.Tensor] = None,
-                          w2_zp: Optional[torch.Tensor] = None,
-                          a1_scale: Optional[torch.Tensor] = None,
-                          a2_scale: Optional[torch.Tensor] = None,
-                          block_shape: Optional[list[int]] = None) -> None:
+def inplace_fused_experts(
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str = "silu",
+        is_act_and_mul: bool = True,
+        apply_router_weight_on_input: bool = False,
+        use_fp8_w8a8: bool = False,
+        use_int8_w8a8: bool = False,
+        use_int8_w8a16: bool = False,
+        use_int4_w4a16: bool = False,
+        use_mxfp4_w4a4: bool = False,
+        per_channel_quant: bool = False,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        w1_scale: Optional[torch.Tensor] = None,
+        w2_scale: Optional[torch.Tensor] = None,
+        w1_zp: Optional[torch.Tensor] = None,
+        w2_zp: Optional[torch.Tensor] = None,
+        a1_scale: Optional[torch.Tensor] = None,
+        a2_scale: Optional[torch.Tensor] = None,
+        block_shape: Optional[List[int]] = None) -> None:  #noqa: UP006
     fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
                        activation, is_act_and_mul,
                        apply_router_weight_on_input, use_fp8_w8a8,
@@ -1082,7 +1086,7 @@ def flashinfer_fused_moe_blockscale_fp8(
         intermediate_size: int,
         expert_offset: int,
         local_num_experts: int,
-        block_shape: list[int],
+        block_shape: List[int],  #noqa: UP006
         routed_scaling: float = 1.0) -> torch.Tensor:
     from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
     assert top_k <= global_num_experts
@@ -1264,7 +1268,8 @@ def outplace_fused_experts(
         w2_zp: Optional[torch.Tensor] = None,
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
-        block_shape: Optional[list[int]] = None) -> torch.Tensor:
+        block_shape: Optional[List[int]] = None,  #noqa: UP006
+) -> torch.Tensor:
     return fused_experts_impl(
         hidden_states, w1, w2, topk_weights, topk_ids, False, activation,
         is_act_and_mul, apply_router_weight_on_input, use_fp8_w8a8,
vllm/model_executor/warmup/deep_gemm_warmup.py (new file, 219 lines)
@@ -0,0 +1,219 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Warmup deep_gemm kernels.
DeepGEMM JIT's the kernels. The warmup aims to JIT all the kernels that would
be used during model execution beforehand.
"""

import torch
from tqdm import tqdm

import vllm.envs as envs
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
    compute_aligned_M, deep_gemm_block_shape)
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
    TritonOrDeepGemmExperts)
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.utils.deep_gemm import fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous


def _extract_data_from_linear_base_module(
        m: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
    """
    Extract weights, weight scales and quantization block sizes from the given
    LinearBase module.
    """
    assert isinstance(m, LinearBase)
    assert isinstance(m.quant_method, Fp8LinearMethod)
    assert m.quant_method.block_quant
    assert m.quant_method.quant_config is not None

    w = m.weight
    ws = m.weight_scale_inv
    quant_block_size = m.quant_method.quant_config.weight_block_size

    assert isinstance(w, torch.Tensor)
    assert isinstance(ws, torch.Tensor)
    assert quant_block_size is not None
    return (w, ws, quant_block_size)


def _extract_data_from_fused_moe_module(
    m: torch.nn.Module
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]:
    """
    Extract weights, weight scales and num_topk from FusedMoE module.
    """
    assert isinstance(m, FusedMoE)
    w13 = m.w13_weight
    w13_s = m.w13_weight_scale_inv
    w2 = m.w2_weight
    w2_s = m.w2_weight_scale_inv
    num_topk = m.top_k

    assert isinstance(w13, torch.Tensor)
    assert isinstance(w13_s, torch.Tensor)
    assert isinstance(w2, torch.Tensor)
    assert isinstance(w2_s, torch.Tensor)
    return w13, w13_s, w2, w2_s, num_topk


def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool:
    """
    Return True if the input module/layer could be processed with DeepGEMM.
    """
    block_size = deep_gemm_block_shape()[0]
    if not (isinstance(module, LinearBase)
            and isinstance(module.quant_method, Fp8LinearMethod)
            and module.quant_method.block_quant):
        return False

    w, _, block_sizes = _extract_data_from_linear_base_module(module)
    return (block_sizes == deep_gemm_block_shape() and w.ndim == 2
            and w.shape[0] % block_size == 0 and w.shape[1] % block_size == 0)


def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
    if not (isinstance(module, FusedMoE)
            and module.moe_config.quant_dtype == torch.float8_e4m3fn
            and module.moe_config.block_shape == deep_gemm_block_shape()):
        return False

    if not isinstance(module.quant_method.fused_experts,
                      FusedMoEModularKernel):
        # fused_experts could invoke deep_gemm_moe_fp8
        return True

    mk: FusedMoEModularKernel = module.quant_method.fused_experts
    # Further check if the ModularKernel implementation uses the DeepGemmExperts
    return isinstance(mk.fused_experts,
                      (DeepGemmExperts, TritonOrDeepGemmExperts))


FP8_GEMM_NT_WARMUP_CACHE: set[torch.Size] = set()


def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor,
                                 max_tokens: int):
    if w.size() in FP8_GEMM_NT_WARMUP_CACHE:
        return

    n, k = w.size()
    block_m = deep_gemm_block_shape()[0]

    device = w.device
    a1q = torch.empty((max_tokens, k),
                      device=device,
                      dtype=torch.float8_e4m3fn)
    a1q_scales = torch.empty((max_tokens, k // block_m),
                             device=device,
                             dtype=torch.float32)
    out = torch.empty((max_tokens, n), device=device, dtype=torch.bfloat16)

    pbar = tqdm(total=max_tokens,
                desc=f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()})")
    num_tokens = max_tokens
    while num_tokens > 0:
        fp8_gemm_nt((a1q[:num_tokens], a1q_scales[:num_tokens]), (w, ws),
                    out[:num_tokens])
        pbar.update(1)
        num_tokens -= 1

    FP8_GEMM_NT_WARMUP_CACHE.add(w.size())


GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: set[torch.Size] = set()


def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(w1: torch.Tensor,
                                                    w2: torch.Tensor,
                                                    w1_scale: torch.Tensor,
                                                    w2_scale: torch.Tensor,
                                                    num_topk: int):
    if (w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
            and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE):
        return

    assert w1.size(0) == w2.size(0), (
        "w1 and w2 must have the same number of experts")

    block_m = deep_gemm_block_shape()[0]
    num_experts = w1.size(0)
    device = w1.device

    # This is the maximum GroupedGemm M size that we expect to run
    # the grouped_gemm with.
    MAX_M = compute_aligned_M(envs.VLLM_FUSED_MOE_CHUNK_SIZE,
                              num_topk,
                              num_experts,
                              block_m,
                              expert_tokens_meta=None)
    # Distribute expert-ids evenly.
    MAX_BLOCKS = MAX_M // block_m
    expert_ids_block = torch.randint(low=0,
                                     high=num_experts,
                                     size=(MAX_BLOCKS, ),
                                     device=device,
                                     dtype=torch.int32)
    expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0)

    def _warmup(w: torch.Tensor, w_scale: torch.Tensor):

        _, n, k = w.size()
        a1q = torch.empty((MAX_M, k), device=device, dtype=torch.float8_e4m3fn)
        a1q_scales = torch.empty((MAX_M, k // block_m),
                                 device=device,
                                 dtype=torch.float32)
        out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)

        pbar = tqdm(
            total=MAX_BLOCKS,
            desc=
            f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()})"
        )
        num_tokens = MAX_M
        while num_tokens > 0:
            m_grouped_fp8_gemm_nt_contiguous(
                (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, w_scale),
                out[:num_tokens], expert_ids[:num_tokens])
            pbar.update(1)
            num_tokens = num_tokens - block_m

    for w, ws in [(w1, w1_scale), (w2, w2_scale)]:
        if w.size() not in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE:
            _warmup(w, ws)
            GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE.add(w.size())


def deepgemm_fp8_gemm_nt_warmup(model: torch.nn.Module, max_tokens: int):
    dg_modules = [
        m for m in model.modules() if _fp8_linear_may_use_deep_gemm(m)
    ]

    for dgm in dg_modules:
        w, ws, _ = _extract_data_from_linear_base_module(dgm)
        _deepgemm_fp8_gemm_nt_warmup(w=w, ws=ws, max_tokens=max_tokens)


def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module):
    dg_modules = [
        m for m in model.modules()
        if _fused_moe_grouped_gemm_may_use_deep_gemm(m)
    ]

    for dgm in dg_modules:
        w13, w13_scale, w2, w2_scale, num_topk = (
            _extract_data_from_fused_moe_module(dgm))
        _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
            w13, w2, w13_scale, w2_scale, num_topk)


def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int):
    deepgemm_fp8_gemm_nt_warmup(model, max_tokens)
    deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model)
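A hypothetical standalone usage sketch of the new module. In vLLM itself the entry point is kernel_warmup() (next file), invoked from the worker after profiling; the helper below and its argument names are placeholders for illustration only:

import torch

from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup


def warm_up_before_serving(model: torch.nn.Module,
                           max_num_batched_tokens: int) -> None:
    # JIT-compile every DeepGEMM kernel the model can hit, for all weight
    # shapes found in the model, before the first real batch arrives.
    deep_gemm_warmup(model, max_tokens=max_num_batched_tokens)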
vllm/model_executor/warmup/kernel_warmup.py (new file, 20 lines)
@@ -0,0 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Warmup kernels used during model execution.
This is useful specifically for JIT'ed kernels as we don't want JIT'ing to
happen during model execution.
"""
import torch

import vllm.envs as envs
from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup
from vllm.utils.deep_gemm import is_deep_gemm_supported


def kernel_warmup(model: torch.nn.Module, max_tokens: int):
    do_deep_gemm_warmup = (envs.VLLM_USE_DEEP_GEMM
                           and is_deep_gemm_supported()
                           and not envs.VLLM_SKIP_DEEP_GEMM_WARMUP)
    if do_deep_gemm_warmup:
        deep_gemm_warmup(model, max_tokens)
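Operator-side, the warmup is gated by the environment flags referenced above; a small sketch assuming the usual 0/1 convention for vLLM boolean env vars, set before the engine starts:

import os

# Opt in to DeepGEMM kernels; kernel_warmup() then JIT-compiles them at startup.
os.environ["VLLM_USE_DEEP_GEMM"] = "1"

# Uncomment to skip the ahead-of-time warmup (kernels would then JIT lazily in
# the hot path, which is what this commit avoids by default).
# os.environ["VLLM_SKIP_DEEP_GEMM_WARMUP"] = "1"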
@@ -21,6 +21,7 @@ from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
+from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import SupportedTask
@@ -338,6 +339,10 @@ class Worker(WorkerBase):
                     self.model_runner._dummy_sampler_run(
                         hidden_states=last_hidden_states)
 
+        # Warmup kernels used during model execution
+        kernel_warmup(self.get_model(),
+                      max_tokens=self.scheduler_config.max_num_batched_tokens)
+
         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.
         set_random_seed(self.model_config.seed)