diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py index 74599fa44c88..a25ef86a989d 100644 --- a/vllm/model_executor/warmup/deep_gemm_warmup.py +++ b/vllm/model_executor/warmup/deep_gemm_warmup.py @@ -10,6 +10,7 @@ import torch from tqdm import tqdm import vllm.envs as envs +from vllm.distributed.parallel_state import get_dp_group from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( compute_aligned_M, deep_gemm_block_shape) @@ -131,11 +132,9 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: set[torch.Size] = set() -def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(w1: torch.Tensor, - w2: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - num_topk: int): +def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup( + w1: torch.Tensor, w2: torch.Tensor, w1_scale: torch.Tensor, + w2_scale: torch.Tensor, num_topk: int, max_tokens: int): if (w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE): return @@ -147,9 +146,13 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(w1: torch.Tensor, num_experts = w1.size(0) device = w1.device + # Assumes all ranks have the same max_num_batched_tokens + max_tokens_across_dp = get_dp_group().world_size * max_tokens + max_tokens = min(max_tokens_across_dp, envs.VLLM_FUSED_MOE_CHUNK_SIZE) + # This is the maximum GroupedGemm M size that we expect to run # the grouped_gemm with. - MAX_M = compute_aligned_M(envs.VLLM_FUSED_MOE_CHUNK_SIZE, + MAX_M = compute_aligned_M(max_tokens, num_topk, num_experts, block_m, @@ -201,7 +204,8 @@ def deepgemm_fp8_gemm_nt_warmup(model: torch.nn.Module, max_tokens: int): _deepgemm_fp8_gemm_nt_warmup(w=w, ws=ws, max_tokens=max_tokens) -def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module): +def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module, + max_tokens: int): dg_modules = [ m for m in model.modules() if _fused_moe_grouped_gemm_may_use_deep_gemm(m) @@ -211,9 +215,9 @@ def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module): w13, w13_scale, w2, w2_scale, num_topk = ( _extract_data_from_fused_moe_module(dgm)) _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup( - w13, w2, w13_scale, w2_scale, num_topk) + w13, w2, w13_scale, w2_scale, num_topk, max_tokens) def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int): deepgemm_fp8_gemm_nt_warmup(model, max_tokens) - deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model) + deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model, max_tokens)