mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 02:05:01 +08:00
[Startup] Make DeepGEMM warmup scale with max-num-batched-tokens (#24693)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
fcba05c435
commit
2e6bc46821
@ -10,6 +10,7 @@ import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
|
||||
compute_aligned_M, deep_gemm_block_shape)
|
||||
@ -131,11 +132,9 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor,
|
||||
GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: set[torch.Size] = set()
|
||||
|
||||
|
||||
def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
num_topk: int):
|
||||
def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
|
||||
w1: torch.Tensor, w2: torch.Tensor, w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor, num_topk: int, max_tokens: int):
|
||||
if (w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
|
||||
and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE):
|
||||
return
|
||||
@ -147,9 +146,13 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(w1: torch.Tensor,
|
||||
num_experts = w1.size(0)
|
||||
device = w1.device
|
||||
|
||||
# Assumes all ranks have the same max_num_batched_tokens
|
||||
max_tokens_across_dp = get_dp_group().world_size * max_tokens
|
||||
max_tokens = min(max_tokens_across_dp, envs.VLLM_FUSED_MOE_CHUNK_SIZE)
|
||||
|
||||
# This is the maximum GroupedGemm M size that we expect to run
|
||||
# the grouped_gemm with.
|
||||
MAX_M = compute_aligned_M(envs.VLLM_FUSED_MOE_CHUNK_SIZE,
|
||||
MAX_M = compute_aligned_M(max_tokens,
|
||||
num_topk,
|
||||
num_experts,
|
||||
block_m,
|
||||
@ -201,7 +204,8 @@ def deepgemm_fp8_gemm_nt_warmup(model: torch.nn.Module, max_tokens: int):
|
||||
_deepgemm_fp8_gemm_nt_warmup(w=w, ws=ws, max_tokens=max_tokens)
|
||||
|
||||
|
||||
def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module):
|
||||
def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module,
|
||||
max_tokens: int):
|
||||
dg_modules = [
|
||||
m for m in model.modules()
|
||||
if _fused_moe_grouped_gemm_may_use_deep_gemm(m)
|
||||
@ -211,9 +215,9 @@ def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module):
|
||||
w13, w13_scale, w2, w2_scale, num_topk = (
|
||||
_extract_data_from_fused_moe_module(dgm))
|
||||
_deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
|
||||
w13, w2, w13_scale, w2_scale, num_topk)
|
||||
w13, w2, w13_scale, w2_scale, num_topk, max_tokens)
|
||||
|
||||
|
||||
def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int):
|
||||
deepgemm_fp8_gemm_nt_warmup(model, max_tokens)
|
||||
deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model)
|
||||
deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model, max_tokens)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user