[UX] Reduce DeepGEMM warmup log output to single progress bar (#30903)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Author: Matthew Bonanni
Date: 2025-12-17 23:21:51 -05:00 (committed by GitHub)
Commit: 4a8412f773
Parent: 0c738b58bc

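In outline, the change drops the per-shape tqdm bars that each warmup helper used to create and instead threads a single, pre-sized bar (shown only on the first global rank) through both warmup phases. A minimal standalone sketch of that pattern follows; count_work, run_phase, warmup, and show_bar are illustrative names, not part of the commit:

from tqdm import tqdm

def count_work(phases: list[list[int]]) -> int:
    # Pre-compute the total number of inner iterations across all phases
    # so a single bar can be sized up front.
    return sum(len(p) for p in phases)

def run_phase(items: list[int], pbar: tqdm | None = None) -> None:
    for _ in items:
        # (a real warmup would launch one kernel here)
        if pbar is not None:
            pbar.update(1)

def warmup(phases: list[list[int]], show_bar: bool) -> None:
    total = count_work(phases)
    if total == 0:
        return
    if show_bar:  # e.g. only on the first global rank
        with tqdm(total=total, desc="DeepGEMM warmup") as pbar:
            for p in phases:
                run_phase(p, pbar)
    else:
        for p in phases:
            run_phase(p, None)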

@@ -10,7 +10,7 @@ import torch
 from tqdm import tqdm
 
 import vllm.envs as envs
-from vllm.distributed.parallel_state import get_dp_group
+from vllm.distributed.parallel_state import get_dp_group, is_global_first_rank
 from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
 from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEModularMethod
@@ -175,7 +175,30 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
 FP8_GEMM_NT_WARMUP_CACHE: set[torch.Size] = set()
 
 
-def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens: int):
+def _get_fp8_gemm_nt_m_values(w: torch.Tensor, max_tokens: int) -> list[int]:
+    """Get the M values to warmup for a given weight tensor."""
+    n, _ = w.size()
+    device = w.device
+
+    # Use optimal M values only if VLLM_DEEP_GEMM_WARMUP is set to "relax".
+    # Otherwise warmup all token sizes to avoid JIT compilation in hotpath
+    if envs.VLLM_DEEP_GEMM_WARMUP == "relax":
+        return _generate_optimal_warmup_m_values(max_tokens, n, device)
+    else:
+        assert envs.VLLM_DEEP_GEMM_WARMUP == "full", (
+            "Expected "
+            'VLLM_DEEP_GEMM_WARMUP env to be set to "full" but got '
+            f"{envs.VLLM_DEEP_GEMM_WARMUP}"
+        )
+        return list(range(1, max_tokens + 1))
+
+
+def _deepgemm_fp8_gemm_nt_warmup(
+    w: torch.Tensor,
+    ws: torch.Tensor,
+    max_tokens: int,
+    pbar: tqdm | None = None,
+):
     if w.size() in FP8_GEMM_NT_WARMUP_CACHE:
         return
@@ -189,27 +212,14 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens:
     )
     out = torch.empty((max_tokens, n), device=device, dtype=torch.bfloat16)
 
-    # Use optimal M values only if VLLM_DEEP_GEMM_WARMUP is set to "relax".
-    # Otherwise warmup all token sizes to avoid JIT compilation in hotpath
-    if envs.VLLM_DEEP_GEMM_WARMUP == "relax":
-        m_values = _generate_optimal_warmup_m_values(max_tokens, n, device)
-        desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [relaxed]"
-    else:
-        assert envs.VLLM_DEEP_GEMM_WARMUP == "full", (
-            "Expected "
-            'VLLM_DEEP_GEMM_WARMUP env to be set to "full" but got '
-            f"{envs.VLLM_DEEP_GEMM_WARMUP}"
-        )
-        m_values = list(range(1, max_tokens + 1))
-        desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [all tokens]"
-    pbar = tqdm(total=len(m_values), desc=desc)
+    m_values = _get_fp8_gemm_nt_m_values(w, max_tokens)
     for num_tokens in m_values:
         fp8_gemm_nt(
             (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, ws), out[:num_tokens]
         )
-        pbar.update(1)
+        if pbar is not None:
+            pbar.update(1)
 
     FP8_GEMM_NT_WARMUP_CACHE.add(w.size())
@@ -217,20 +227,12 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens:
 GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: set[torch.Size] = set()
 
 
-def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
+def _get_grouped_gemm_params(
     w1: torch.Tensor,
     w2: torch.Tensor,
-    w1_scale: torch.Tensor,
-    w2_scale: torch.Tensor,
     num_topk: int,
     max_tokens: int,
-):
-    if (
-        w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
-        and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
-    ):
-        return
-
+) -> tuple[int, int, torch.Tensor]:
     assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts"
 
     block_m = get_mk_alignment_for_contiguous_layout()[0]
@@ -253,6 +255,27 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
     )
     expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0)
+    return MAX_M, block_m, expert_ids
+
+
+def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    num_topk: int,
+    max_tokens: int,
+    pbar: tqdm | None = None,
+):
+    if (
+        w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
+        and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
+    ):
+        return
+
+    MAX_M, block_m, expert_ids = _get_grouped_gemm_params(w1, w2, num_topk, max_tokens)
 
     device = w1.device
 
     def _warmup(w: torch.Tensor, w_scale: torch.Tensor):
         _, n, k = w.size()
         a1q = torch.empty((MAX_M, k), device=device, dtype=torch.float8_e4m3fn)
@@ -261,15 +284,8 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
         )
         out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)
 
         # Generate M values in block_m increments (already optimized for MoE)
         m_values = list(range(block_m, MAX_M + 1, block_m))
-        pbar = tqdm(
-            total=len(m_values),
-            desc=f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()}) "
-            f"[{len(m_values)} values, block_m={block_m}]",
-        )
         for num_tokens in m_values:
             m_grouped_fp8_gemm_nt_contiguous(
                 (a1q[:num_tokens], a1q_scales[:num_tokens]),
@@ -277,7 +293,8 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
                 out[:num_tokens],
                 expert_ids[:num_tokens],
             )
-            pbar.update(1)
+            if pbar is not None:
+                pbar.update(1)
 
     for w, ws in [(w1, w1_scale), (w2, w2_scale)]:
         if w.size() not in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE:
@@ -285,16 +302,18 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
             GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE.add(w.size())
 
 
-def deepgemm_fp8_gemm_nt_warmup(model: torch.nn.Module, max_tokens: int):
+def deepgemm_fp8_gemm_nt_warmup(
+    model: torch.nn.Module, max_tokens: int, pbar: tqdm | None = None
+):
     dg_modules = [m for m in model.modules() if _fp8_linear_may_use_deep_gemm(m)]
 
     for dgm in dg_modules:
         w, ws, _ = _extract_data_from_linear_base_module(dgm)
-        _deepgemm_fp8_gemm_nt_warmup(w=w, ws=ws, max_tokens=max_tokens)
+        _deepgemm_fp8_gemm_nt_warmup(w=w, ws=ws, max_tokens=max_tokens, pbar=pbar)
 
 
 def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
-    model: torch.nn.Module, max_tokens: int
+    model: torch.nn.Module, max_tokens: int, pbar: tqdm | None = None
 ):
     dg_modules = [
         m for m in model.modules() if _fused_moe_grouped_gemm_may_use_deep_gemm(m)
@@ -305,10 +324,48 @@ def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
             dgm
         )
         _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
-            w13, w2, w13_scale, w2_scale, num_topk, max_tokens
+            w13, w2, w13_scale, w2_scale, num_topk, max_tokens, pbar=pbar
         )
 
 
+def _count_warmup_iterations(model: torch.nn.Module, max_tokens: int) -> int:
+    seen_fp8_sizes: set[torch.Size] = set(FP8_GEMM_NT_WARMUP_CACHE)
+    seen_grouped_sizes: set[torch.Size] = set(
+        GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
+    )
+    total = 0
+    for m in model.modules():
+        if _fp8_linear_may_use_deep_gemm(m):
+            w, _, _ = _extract_data_from_linear_base_module(m)
+            if w.size() not in seen_fp8_sizes:
+                total += len(_get_fp8_gemm_nt_m_values(w, max_tokens))
+                seen_fp8_sizes.add(w.size())
+        elif _fused_moe_grouped_gemm_may_use_deep_gemm(m):
+            w13, _, w2, _, num_topk = _extract_data_from_fused_moe_module(m)
+            if w13.size() in seen_grouped_sizes and w2.size() in seen_grouped_sizes:
+                continue
+            MAX_M, block_m, _ = _get_grouped_gemm_params(w13, w2, num_topk, max_tokens)
+            n_values = (MAX_M - block_m) // block_m + 1
+            if w13.size() not in seen_grouped_sizes:
+                total += n_values
+                seen_grouped_sizes.add(w13.size())
+            if w2.size() not in seen_grouped_sizes:
+                total += n_values
+                seen_grouped_sizes.add(w2.size())
+    return total
+
+
 def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int):
-    deepgemm_fp8_gemm_nt_warmup(model, max_tokens)
-    deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model, max_tokens)
+    total = _count_warmup_iterations(model, max_tokens)
+    if total == 0:
+        return
+
+    # Only show progress bar on rank 0 to avoid cluttered output
+    if is_global_first_rank():
+        with tqdm(total=total, desc="DeepGEMM warmup") as pbar:
+            deepgemm_fp8_gemm_nt_warmup(model, max_tokens, pbar)
+            deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model, max_tokens, pbar)
+    else:
+        deepgemm_fp8_gemm_nt_warmup(model, max_tokens, None)
+        deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model, max_tokens, None)
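
For the grouped path, the bar total relies on the warmed-up M values running from block_m to MAX_M in block_m increments, so their count is (MAX_M - block_m) // block_m + 1. A quick check with made-up numbers (128 and 1024 are illustrative; the real values come from get_mk_alignment_for_contiguous_layout() and _get_grouped_gemm_params):

# Illustrative values only.
block_m, MAX_M = 128, 1024
m_values = list(range(block_m, MAX_M + 1, block_m))  # [128, 256, ..., 1024]
n_values = (MAX_M - block_m) // block_m + 1          # (1024 - 128) // 128 + 1 == 8
assert n_values == len(m_values)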