diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py
index 936f6b1e28ce1..2bbc655bd935f 100644
--- a/vllm/model_executor/warmup/deep_gemm_warmup.py
+++ b/vllm/model_executor/warmup/deep_gemm_warmup.py
@@ -10,7 +10,7 @@ import torch
 from tqdm import tqdm
 
 import vllm.envs as envs
-from vllm.distributed.parallel_state import get_dp_group
+from vllm.distributed.parallel_state import get_dp_group, is_global_first_rank
 from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
 from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEModularMethod
@@ -175,7 +175,30 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
 
 FP8_GEMM_NT_WARMUP_CACHE: set[torch.Size] = set()
 
 
-def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens: int):
+def _get_fp8_gemm_nt_m_values(w: torch.Tensor, max_tokens: int) -> list[int]:
+    """Get the M values to warmup for a given weight tensor."""
+    n, _ = w.size()
+    device = w.device
+
+    # Use optimal M values only if VLLM_DEEP_GEMM_WARMUP is set to "relax".
+    # Otherwise warmup all token sizes to avoid JIT compilation in hotpath
+    if envs.VLLM_DEEP_GEMM_WARMUP == "relax":
+        return _generate_optimal_warmup_m_values(max_tokens, n, device)
+    else:
+        assert envs.VLLM_DEEP_GEMM_WARMUP == "full", (
+            "Expected "
+            'VLLM_DEEP_GEMM_WARMUP env to be set to "full" but got '
+            f"{envs.VLLM_DEEP_GEMM_WARMUP}"
+        )
+        return list(range(1, max_tokens + 1))
+
+
+def _deepgemm_fp8_gemm_nt_warmup(
+    w: torch.Tensor,
+    ws: torch.Tensor,
+    max_tokens: int,
+    pbar: tqdm | None = None,
+):
     if w.size() in FP8_GEMM_NT_WARMUP_CACHE:
         return
 
@@ -189,27 +212,14 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens:
     )
     out = torch.empty((max_tokens, n), device=device, dtype=torch.bfloat16)
 
-    # Use optimal M values only if VLLM_DEEP_GEMM_WARMUP is set to "relax".
-    # Otherwise warmup all token sizes to avoid JIT compilation in hotpath
-    if envs.VLLM_DEEP_GEMM_WARMUP == "relax":
-        m_values = _generate_optimal_warmup_m_values(max_tokens, n, device)
-        desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [relaxed]"
-    else:
-        assert envs.VLLM_DEEP_GEMM_WARMUP == "full", (
-            "Expected "
-            'VLLM_DEEP_GEMM_WARMUP env to be set to "full" but got '
-            f"{envs.VLLM_DEEP_GEMM_WARMUP}"
-        )
-        m_values = list(range(1, max_tokens + 1))
-        desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [all tokens]"
-
-    pbar = tqdm(total=len(m_values), desc=desc)
+    m_values = _get_fp8_gemm_nt_m_values(w, max_tokens)
 
     for num_tokens in m_values:
         fp8_gemm_nt(
             (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, ws), out[:num_tokens]
         )
-        pbar.update(1)
+        if pbar is not None:
+            pbar.update(1)
 
     FP8_GEMM_NT_WARMUP_CACHE.add(w.size())
 
@@ -217,20 +227,12 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens:
 GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: set[torch.Size] = set()
 
 
-def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
+def _get_grouped_gemm_params(
     w1: torch.Tensor,
     w2: torch.Tensor,
-    w1_scale: torch.Tensor,
-    w2_scale: torch.Tensor,
     num_topk: int,
     max_tokens: int,
-):
-    if (
-        w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
-        and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
-    ):
-        return
-
+) -> tuple[int, int, torch.Tensor]:
     assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts"
 
     block_m = get_mk_alignment_for_contiguous_layout()[0]
@@ -253,6 +255,27 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
     )
     expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0)
 
+    return MAX_M, block_m, expert_ids
+
+
+def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    num_topk: int,
+    max_tokens: int,
+    pbar: tqdm | None = None,
+):
+    if (
+        w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
+        and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
+    ):
+        return
+
+    MAX_M, block_m, expert_ids = _get_grouped_gemm_params(w1, w2, num_topk, max_tokens)
+    device = w1.device
+
     def _warmup(w: torch.Tensor, w_scale: torch.Tensor):
         _, n, k = w.size()
         a1q = torch.empty((MAX_M, k), device=device, dtype=torch.float8_e4m3fn)
@@ -261,15 +284,8 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
         )
         out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)
 
-        # Generate M values in block_m increments (already optimized for MoE)
         m_values = list(range(block_m, MAX_M + 1, block_m))
-        pbar = tqdm(
-            total=len(m_values),
-            desc=f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()}) "
-            f"[{len(m_values)} values, block_m={block_m}]",
-        )
-
         for num_tokens in m_values:
             m_grouped_fp8_gemm_nt_contiguous(
                 (a1q[:num_tokens], a1q_scales[:num_tokens]),
@@ -277,7 +293,8 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
                 out[:num_tokens],
                 expert_ids[:num_tokens],
             )
-            pbar.update(1)
+            if pbar is not None:
+                pbar.update(1)
 
     for w, ws in [(w1, w1_scale), (w2, w2_scale)]:
         if w.size() not in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE:
@@ -285,16 +302,18 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
             GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE.add(w.size())
 
 
-def deepgemm_fp8_gemm_nt_warmup(model: torch.nn.Module, max_tokens: int):
+def deepgemm_fp8_gemm_nt_warmup(
+    model: torch.nn.Module, max_tokens: int, pbar: tqdm | None = None
+):
     dg_modules = [m for m in model.modules() if _fp8_linear_may_use_deep_gemm(m)]
 
     for dgm in dg_modules:
         w, ws, _ = _extract_data_from_linear_base_module(dgm)
-        _deepgemm_fp8_gemm_nt_warmup(w=w, ws=ws, max_tokens=max_tokens)
+        _deepgemm_fp8_gemm_nt_warmup(w=w, ws=ws, max_tokens=max_tokens, pbar=pbar)
 
 
 def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
-    model: torch.nn.Module, max_tokens: int
+    model: torch.nn.Module, max_tokens: int, pbar: tqdm | None = None
 ):
     dg_modules = [
         m for m in model.modules() if _fused_moe_grouped_gemm_may_use_deep_gemm(m)
     ]
@@ -305,10 +324,48 @@ def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
             dgm
         )
         _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
-            w13, w2, w13_scale, w2_scale, num_topk, max_tokens
+            w13, w2, w13_scale, w2_scale, num_topk, max_tokens, pbar=pbar
         )
 
 
+def _count_warmup_iterations(model: torch.nn.Module, max_tokens: int) -> int:
+    seen_fp8_sizes: set[torch.Size] = set(FP8_GEMM_NT_WARMUP_CACHE)
+    seen_grouped_sizes: set[torch.Size] = set(
+        GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
+    )
+
+    total = 0
+    for m in model.modules():
+        if _fp8_linear_may_use_deep_gemm(m):
+            w, _, _ = _extract_data_from_linear_base_module(m)
+            if w.size() not in seen_fp8_sizes:
+                total += len(_get_fp8_gemm_nt_m_values(w, max_tokens))
+                seen_fp8_sizes.add(w.size())
+        elif _fused_moe_grouped_gemm_may_use_deep_gemm(m):
+            w13, _, w2, _, num_topk = _extract_data_from_fused_moe_module(m)
+            if w13.size() in seen_grouped_sizes and w2.size() in seen_grouped_sizes:
+                continue
+            MAX_M, block_m, _ = _get_grouped_gemm_params(w13, w2, num_topk, max_tokens)
+            n_values = (MAX_M - block_m) // block_m + 1
+            if w13.size() not in seen_grouped_sizes:
+                total += n_values
+                seen_grouped_sizes.add(w13.size())
+            if w2.size() not in seen_grouped_sizes:
+                total += n_values
+                seen_grouped_sizes.add(w2.size())
+    return total
+
+
 def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int):
-    deepgemm_fp8_gemm_nt_warmup(model, max_tokens)
-    deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model, max_tokens)
+    total = _count_warmup_iterations(model, max_tokens)
+    if total == 0:
+        return
+
+    # Only show progress bar on rank 0 to avoid cluttered output
+    if is_global_first_rank():
+        with tqdm(total=total, desc="DeepGEMM warmup") as pbar:
+            deepgemm_fp8_gemm_nt_warmup(model, max_tokens, pbar)
+            deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model, max_tokens, pbar)
+    else:
+        deepgemm_fp8_gemm_nt_warmup(model, max_tokens, None)
+        deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model, max_tokens, None)
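For reference, here is a minimal standalone sketch of the pattern this diff introduces: pre-count the total number of warmup steps, open a single `tqdm` bar only on the first rank, and thread the optional bar through the worker functions. The helper names (`_count_steps`, `_warmup_one`, `warmup`) and the `first_rank` flag are hypothetical stand-ins; in the actual change the rank check is vLLM's `is_global_first_rank()` and the workers are the `deepgemm_*_warmup` functions above.

```python
from tqdm import tqdm


def _count_steps(weight_sizes: list[int], max_tokens: int) -> int:
    # Analogous to _count_warmup_iterations: one progress step per (weight, M) pair.
    return len(weight_sizes) * max_tokens


def _warmup_one(size: int, max_tokens: int, pbar: tqdm | None = None) -> None:
    # Stand-in for _deepgemm_fp8_gemm_nt_warmup: the real function runs one GEMM
    # per M value; here we only advance the shared bar if one was passed in.
    for _ in range(max_tokens):
        if pbar is not None:
            pbar.update(1)


def warmup(weight_sizes: list[int], max_tokens: int, first_rank: bool) -> None:
    total = _count_steps(weight_sizes, max_tokens)
    if total == 0:
        return
    if first_rank:
        # One consolidated bar for all warmup work on this rank.
        with tqdm(total=total, desc="warmup") as pbar:
            for size in weight_sizes:
                _warmup_one(size, max_tokens, pbar)
    else:
        # Other ranks do the same work without printing a bar.
        for size in weight_sizes:
            _warmup_one(size, max_tokens, None)


if __name__ == "__main__":
    warmup(weight_sizes=[1024, 2048], max_tokens=8, first_rank=True)
```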