[Perf] Optimize deepgemm experts initialization, 3.9% TTFT improvement (#30494)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: li-jinpeng <3332126450@qq.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
Wentao Ye 2025-12-11 17:28:34 -05:00 committed by GitHub
parent 3efdc3feae
commit c817b14151
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -84,10 +84,16 @@ def _fwd_kernel_ep_scatter_1(
m_indices_start_ptr = m_indices + cur_expert_start
off_expert = tl.arange(0, BLOCK_E)
# any rows in the per-expert aligned region that do not correspond to
# real tokens are left untouched here and should remain initialized to
# -1 so DeepGEMM can skip them
for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):
offs = start_m + off_expert
mask = offs < cur_expert_token_num
tl.store(
m_indices_start_ptr + start_m + off_expert,
m_indices_start_ptr + offs,
cur_expert,
mask=mask,
)
@ -366,12 +372,17 @@ def deepgemm_moe_permute(
(M_sum, H // block_k), device=device, dtype=torch.float32
)
maybe_has_empty_blocks = (expert_tokens_meta is None) or (
expert_tokens_meta.expert_num_tokens_cpu is None
# DeepGEMM uses negative values in m_indices (here expert_ids) to mark
# completely invalid / padded blocks that should be skipped. We always
# initialize expert_ids to -1 so any row that is not explicitly written
# by the scatter kernel will be treated as invalid and skipped by
# DeepGEMM's scheduler.
expert_ids = torch.full(
(M_sum,),
fill_value=-1,
device=device,
dtype=torch.int32,
)
expert_ids_init = torch.zeros if maybe_has_empty_blocks else torch.empty
expert_ids = expert_ids_init((M_sum), device=device, dtype=torch.int32)
inv_perm = torch.empty(topk_ids.shape, device=device, dtype=torch.int32)
expert_num_tokens = None