[UX] Reduce DeepGEMM warmup log output to single progress bar (#30903)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
parent 0c738b58bc
commit 4a8412f773
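In short: instead of each warmup helper opening its own tqdm bar per weight shape, deep_gemm_warmup now counts the total number of warmup GEMM calls up front, opens a single bar on the global first rank only, and threads it through the helpers via an optional pbar argument. Below is a minimal, self-contained sketch of that pattern; the helper names and sizes are illustrative only, not the vLLM code itself.

from tqdm import tqdm


def _warmup_phase(n_steps: int, pbar: tqdm | None = None) -> None:
    """Stand-in for one warmup helper: run n_steps calls, tick the shared bar."""
    for _ in range(n_steps):
        # ... one warmup GEMM call would go here ...
        if pbar is not None:
            pbar.update(1)


def warmup_all(phase_sizes: list[int], is_first_rank: bool) -> None:
    total = sum(phase_sizes)  # counted up front, like _count_warmup_iterations below
    if total == 0:
        return
    if is_first_rank:
        # Only one rank draws the bar, so multi-rank logs stay clean.
        with tqdm(total=total, desc="DeepGEMM warmup") as pbar:
            for n in phase_sizes:
                _warmup_phase(n, pbar)
    else:
        for n in phase_sizes:
            _warmup_phase(n, None)


warmup_all([8, 4], is_first_rank=True)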
@@ -10,7 +10,7 @@ import torch
 from tqdm import tqdm

 import vllm.envs as envs
-from vllm.distributed.parallel_state import get_dp_group
+from vllm.distributed.parallel_state import get_dp_group, is_global_first_rank
 from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
 from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEModularMethod
@@ -175,7 +175,30 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
 FP8_GEMM_NT_WARMUP_CACHE: set[torch.Size] = set()


-def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens: int):
+def _get_fp8_gemm_nt_m_values(w: torch.Tensor, max_tokens: int) -> list[int]:
+    """Get the M values to warmup for a given weight tensor."""
+    n, _ = w.size()
+    device = w.device
+
+    # Use optimal M values only if VLLM_DEEP_GEMM_WARMUP is set to "relax".
+    # Otherwise warmup all token sizes to avoid JIT compilation in hotpath
+    if envs.VLLM_DEEP_GEMM_WARMUP == "relax":
+        return _generate_optimal_warmup_m_values(max_tokens, n, device)
+    else:
+        assert envs.VLLM_DEEP_GEMM_WARMUP == "full", (
+            "Expected "
+            'VLLM_DEEP_GEMM_WARMUP env to be set to "full" but got '
+            f"{envs.VLLM_DEEP_GEMM_WARMUP}"
+        )
+        return list(range(1, max_tokens + 1))
+
+
+def _deepgemm_fp8_gemm_nt_warmup(
+    w: torch.Tensor,
+    ws: torch.Tensor,
+    max_tokens: int,
+    pbar: tqdm | None = None,
+):
     if w.size() in FP8_GEMM_NT_WARMUP_CACHE:
         return
@@ -189,27 +212,14 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens:
     )
     out = torch.empty((max_tokens, n), device=device, dtype=torch.bfloat16)

-    # Use optimal M values only if VLLM_DEEP_GEMM_WARMUP is set to "relax".
-    # Otherwise warmup all token sizes to avoid JIT compilation in hotpath
-    if envs.VLLM_DEEP_GEMM_WARMUP == "relax":
-        m_values = _generate_optimal_warmup_m_values(max_tokens, n, device)
-        desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [relaxed]"
-    else:
-        assert envs.VLLM_DEEP_GEMM_WARMUP == "full", (
-            "Expected "
-            'VLLM_DEEP_GEMM_WARMUP env to be set to "full" but got '
-            f"{envs.VLLM_DEEP_GEMM_WARMUP}"
-        )
-        m_values = list(range(1, max_tokens + 1))
-        desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [all tokens]"
-
-    pbar = tqdm(total=len(m_values), desc=desc)
+    m_values = _get_fp8_gemm_nt_m_values(w, max_tokens)

     for num_tokens in m_values:
         fp8_gemm_nt(
             (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, ws), out[:num_tokens]
         )
-        pbar.update(1)
+        if pbar is not None:
+            pbar.update(1)

     FP8_GEMM_NT_WARMUP_CACHE.add(w.size())
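A side note on the refactor above: factoring the M-value selection into _get_fp8_gemm_nt_m_values lets the progress-bar total be computed without running any GEMMs. A tiny illustration with a made-up max_tokens (the "relax" branch depends on _generate_optimal_warmup_m_values and is not reproduced here):

# Hypothetical value, for illustration only.
max_tokens = 32

# The "full" branch warms every M in 1..max_tokens, so one unseen weight shape
# contributes exactly max_tokens steps to the shared bar total.
m_values = list(range(1, max_tokens + 1))
assert len(m_values) == max_tokens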
@@ -217,20 +227,12 @@
 GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: set[torch.Size] = set()


-def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
+def _get_grouped_gemm_params(
     w1: torch.Tensor,
     w2: torch.Tensor,
-    w1_scale: torch.Tensor,
-    w2_scale: torch.Tensor,
     num_topk: int,
     max_tokens: int,
-):
-    if (
-        w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
-        and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
-    ):
-        return
-
+) -> tuple[int, int, torch.Tensor]:
     assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts"

     block_m = get_mk_alignment_for_contiguous_layout()[0]
@@ -253,6 +255,27 @@
     )
     expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0)

+    return MAX_M, block_m, expert_ids
+
+
+def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    num_topk: int,
+    max_tokens: int,
+    pbar: tqdm | None = None,
+):
+    if (
+        w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
+        and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
+    ):
+        return
+
+    MAX_M, block_m, expert_ids = _get_grouped_gemm_params(w1, w2, num_topk, max_tokens)
     device = w1.device

     def _warmup(w: torch.Tensor, w_scale: torch.Tensor):
         _, n, k = w.size()
         a1q = torch.empty((MAX_M, k), device=device, dtype=torch.float8_e4m3fn)
@@ -261,15 +284,8 @@
         )
         out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)

         # Generate M values in block_m increments (already optimized for MoE)
         m_values = list(range(block_m, MAX_M + 1, block_m))
-
-        pbar = tqdm(
-            total=len(m_values),
-            desc=f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()}) "
-            f"[{len(m_values)} values, block_m={block_m}]",
-        )
-
         for num_tokens in m_values:
             m_grouped_fp8_gemm_nt_contiguous(
                 (a1q[:num_tokens], a1q_scales[:num_tokens]),
@@ -277,7 +293,8 @@
                 out[:num_tokens],
                 expert_ids[:num_tokens],
             )
-            pbar.update(1)
+            if pbar is not None:
+                pbar.update(1)

     for w, ws in [(w1, w1_scale), (w2, w2_scale)]:
         if w.size() not in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE:
@@ -285,16 +302,18 @@
             GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE.add(w.size())


-def deepgemm_fp8_gemm_nt_warmup(model: torch.nn.Module, max_tokens: int):
+def deepgemm_fp8_gemm_nt_warmup(
+    model: torch.nn.Module, max_tokens: int, pbar: tqdm | None = None
+):
     dg_modules = [m for m in model.modules() if _fp8_linear_may_use_deep_gemm(m)]

     for dgm in dg_modules:
         w, ws, _ = _extract_data_from_linear_base_module(dgm)
-        _deepgemm_fp8_gemm_nt_warmup(w=w, ws=ws, max_tokens=max_tokens)
+        _deepgemm_fp8_gemm_nt_warmup(w=w, ws=ws, max_tokens=max_tokens, pbar=pbar)


 def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
-    model: torch.nn.Module, max_tokens: int
+    model: torch.nn.Module, max_tokens: int, pbar: tqdm | None = None
 ):
     dg_modules = [
         m for m in model.modules() if _fused_moe_grouped_gemm_may_use_deep_gemm(m)
@@ -305,10 +324,48 @@ def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
             dgm
         )
         _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
-            w13, w2, w13_scale, w2_scale, num_topk, max_tokens
+            w13, w2, w13_scale, w2_scale, num_topk, max_tokens, pbar=pbar
         )


+def _count_warmup_iterations(model: torch.nn.Module, max_tokens: int) -> int:
+    seen_fp8_sizes: set[torch.Size] = set(FP8_GEMM_NT_WARMUP_CACHE)
+    seen_grouped_sizes: set[torch.Size] = set(
+        GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
+    )
+
+    total = 0
+    for m in model.modules():
+        if _fp8_linear_may_use_deep_gemm(m):
+            w, _, _ = _extract_data_from_linear_base_module(m)
+            if w.size() not in seen_fp8_sizes:
+                total += len(_get_fp8_gemm_nt_m_values(w, max_tokens))
+                seen_fp8_sizes.add(w.size())
+        elif _fused_moe_grouped_gemm_may_use_deep_gemm(m):
+            w13, _, w2, _, num_topk = _extract_data_from_fused_moe_module(m)
+            if w13.size() in seen_grouped_sizes and w2.size() in seen_grouped_sizes:
+                continue
+            MAX_M, block_m, _ = _get_grouped_gemm_params(w13, w2, num_topk, max_tokens)
+            n_values = (MAX_M - block_m) // block_m + 1
+            if w13.size() not in seen_grouped_sizes:
+                total += n_values
+                seen_grouped_sizes.add(w13.size())
+            if w2.size() not in seen_grouped_sizes:
+                total += n_values
+                seen_grouped_sizes.add(w2.size())
+    return total
+
+
 def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int):
-    deepgemm_fp8_gemm_nt_warmup(model, max_tokens)
-    deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model, max_tokens)
+    total = _count_warmup_iterations(model, max_tokens)
+    if total == 0:
+        return
+
+    # Only show progress bar on rank 0 to avoid cluttered output
+    if is_global_first_rank():
+        with tqdm(total=total, desc="DeepGEMM warmup") as pbar:
+            deepgemm_fp8_gemm_nt_warmup(model, max_tokens, pbar)
+            deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model, max_tokens, pbar)
+    else:
+        deepgemm_fp8_gemm_nt_warmup(model, max_tokens, None)
+        deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model, max_tokens, None)
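One detail worth noting in _count_warmup_iterations: the grouped warmup sweeps M in block_m increments, and the closed form (MAX_M - block_m) // block_m + 1 is exactly the length of that sweep, so the precomputed total matches the number of pbar.update calls. A quick check with made-up sizes:

# Hypothetical alignment and padded token count, for illustration only.
block_m, MAX_M = 128, 1024

m_values = list(range(block_m, MAX_M + 1, block_m))  # the sweep inside _warmup
n_values = (MAX_M - block_m) // block_m + 1          # the closed form in the counter
assert n_values == len(m_values) == 8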