From 56e544f24b62d647ba05cccf51756c2ebebae66f Mon Sep 17 00:00:00 2001
From: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date: Sat, 26 Jul 2025 10:08:29 -0400
Subject: [PATCH] [Refactor] Remove `moe_align_block_size_triton` (#21335)

Signed-off-by: yewentao256
---
 .../kernels/benchmark_moe_align_block_size.py |  90 +----------
 .../layers/fused_moe/moe_align_block_size.py  | 140 +------------------
 2 files changed, 6 insertions(+), 224 deletions(-)

diff --git a/benchmarks/kernels/benchmark_moe_align_block_size.py b/benchmarks/kernels/benchmark_moe_align_block_size.py
index 1af5a21caf465..f540cff6261a8 100644
--- a/benchmarks/kernels/benchmark_moe_align_block_size.py
+++ b/benchmarks/kernels/benchmark_moe_align_block_size.py
@@ -5,9 +5,8 @@
 import itertools
 
 import torch
 
-from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
-    moe_align_block_size_triton,
+    moe_align_block_size,
 )
 from vllm.triton_utils import triton
@@ -21,60 +20,6 @@ def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
     )
 
 
-def check_correctness(num_tokens, num_experts=256, block_size=256, topk=8):
-    """
-    Verifies vllm vs. Triton
-    """
-    topk_ids = get_topk_ids(num_tokens, num_experts, topk)
-
-    # 1. malloc space for triton and vllm
-    # malloc enough space (max_num_tokens_padded) for the sorted ids
-    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids_triton = torch.empty(
-        (max_num_tokens_padded,), dtype=torch.int32, device="cuda"
-    )
-    expert_ids_triton = torch.empty(
-        (max_num_tokens_padded // block_size,), dtype=torch.int32, device="cuda"
-    )
-    num_tokens_post_pad_triton = torch.empty((1,), dtype=torch.int32, device="cuda")
-
-    sorted_ids_vllm = torch.empty_like(sorted_ids_triton)
-    expert_ids_vllm = torch.empty_like(expert_ids_triton)
-    num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_triton)
-
-    # 2. run implementations
-    moe_align_block_size_triton(
-        topk_ids,
-        num_experts,
-        block_size,
-        sorted_ids_triton,
-        expert_ids_triton,
-        num_tokens_post_pad_triton,
-    )
-
-    ops.moe_align_block_size(
-        topk_ids,
-        num_experts,
-        block_size,
-        sorted_ids_vllm,
-        expert_ids_vllm,
-        num_tokens_post_pad_vllm,
-    )
-    print(f"✅ VLLM implementation works with {num_experts} experts!")
-
-    # 3. compare results
-    if torch.allclose(expert_ids_triton, expert_ids_vllm) and torch.allclose(
-        num_tokens_post_pad_triton, num_tokens_post_pad_vllm
-    ):
-        print("✅ Triton and VLLM implementations match.")
-    else:
-        print("❌ Triton and VLLM implementations DO NOT match.")
-        print("Triton expert_ids:", expert_ids_triton)
-        print("VLLM expert_ids:", expert_ids_vllm)
-        print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton)
-        print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm)
-
-
 # test configurations
 num_tokens_range = [1, 16, 256, 4096]
 num_experts_range = [16, 64, 224, 256, 280, 512]
@@ -87,8 +32,8 @@ configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range
         x_names=["num_tokens", "num_experts", "topk"],
         x_vals=configs,
         line_arg="provider",
-        line_vals=["vllm", "triton"],  # "triton"
-        line_names=["VLLM", "Triton"],  # "Triton"
+        line_vals=["vllm"],
+        line_names=["vLLM"],
         plot_name="moe-align-block-size-performance",
         args={},
     )
@@ -98,36 +43,11 @@ def benchmark(num_tokens, num_experts, topk, provider):
     block_size = 256
     topk_ids = get_topk_ids(num_tokens, num_experts, topk)
 
-    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda")
-    max_num_m_blocks = max_num_tokens_padded // block_size
-    expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda")
-    num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda")
-
     quantiles = [0.5, 0.2, 0.8]
 
     if provider == "vllm":
         ms, min_ms, max_ms = triton.testing.do_bench(
-            lambda: ops.moe_align_block_size(
-                topk_ids,
-                num_experts,
-                block_size,
-                sorted_ids.clone(),
-                expert_ids.clone(),
-                num_tokens_post_pad.clone(),
-            ),
-            quantiles=quantiles,
-        )
-    elif provider == "triton":
-        ms, min_ms, max_ms = triton.testing.do_bench(
-            lambda: moe_align_block_size_triton(
-                topk_ids,
-                num_experts,
-                block_size,
-                sorted_ids.clone(),
-                expert_ids.clone(),
-                num_tokens_post_pad.clone(),
-            ),
+            lambda: moe_align_block_size(topk_ids, block_size, num_experts),
             quantiles=quantiles,
         )
 
@@ -151,6 +71,4 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
 
-    print("Running correctness check...")
-    check_correctness(num_tokens=1024, num_experts=args.num_experts, topk=args.topk)
     benchmark.run(print_data=True, show_plots=True)
diff --git a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py
index 2c9ad509fa98e..c7d7126bab3ad 100644
--- a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py
+++ b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py
@@ -5,144 +5,8 @@
 from typing import Optional
 
 import torch
 
 from vllm import _custom_ops as ops
-from vllm.triton_utils import tl, triton
-from vllm.utils import cdiv, round_up
-
-
-@triton.jit
-def moe_align_block_size_stage1(
-    topk_ids_ptr,
-    tokens_cnts_ptr,
-    num_experts: tl.constexpr,
-    numel: tl.constexpr,
-    tokens_per_thread: tl.constexpr,
-):
-    pid = tl.program_id(0)
-
-    start_idx = pid * tokens_per_thread
-
-    off_c = (pid + 1) * num_experts
-
-    for i in range(tokens_per_thread):
-        if start_idx + i < numel:
-            idx = tl.load(topk_ids_ptr + start_idx + i)
-            token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
-            tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
-
-
-@triton.jit
-def moe_align_block_size_stage2(
-    tokens_cnts_ptr,
-    num_experts: tl.constexpr,
-):
-    pid = tl.program_id(0)
-
-    last_cnt = 0
-    for i in range(1, num_experts + 1):
-        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
-        last_cnt = last_cnt + token_cnt
-        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
-
-
-@triton.jit
-def moe_align_block_size_stage3(
-    total_tokens_post_pad_ptr,
-    tokens_cnts_ptr,
-    cumsum_ptr,
-    num_experts: tl.constexpr,
-    block_size: tl.constexpr,
-):
-    last_cumsum = 0
-    off_cnt = num_experts * num_experts
-    for i in range(1, num_experts + 1):
-        token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
-        last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
-        tl.store(cumsum_ptr + i, last_cumsum)
-    tl.store(total_tokens_post_pad_ptr, last_cumsum)
-
-
-@triton.jit
-def moe_align_block_size_stage4(
-    topk_ids_ptr,
-    sorted_token_ids_ptr,
-    expert_ids_ptr,
-    tokens_cnts_ptr,
-    cumsum_ptr,
-    num_experts: tl.constexpr,
-    block_size: tl.constexpr,
-    numel: tl.constexpr,
-    tokens_per_thread: tl.constexpr,
-):
-    pid = tl.program_id(0)
-    start_idx = tl.load(cumsum_ptr + pid)
-    end_idx = tl.load(cumsum_ptr + pid + 1)
-
-    for i in range(start_idx, end_idx, block_size):
-        tl.store(expert_ids_ptr + i // block_size, pid)
-
-    start_idx = pid * tokens_per_thread
-    off_t = pid * num_experts
-
-    for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread,
-                                         numel)):
-        expert_id = tl.load(topk_ids_ptr + i)
-        token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
-        rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
-        tl.store(sorted_token_ids_ptr + rank_post_pad, i)
-        tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
-
-
-# Triton implementation based on:
-# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
-def moe_align_block_size_triton(
-    topk_ids: torch.Tensor,
-    num_experts: int,
-    block_size: int,
-    sorted_token_ids: torch.Tensor,
-    expert_ids: torch.Tensor,
-    num_tokens_post_pad: torch.Tensor,
-) -> None:
-    numel = topk_ids.numel()
-    grid = (num_experts, )
-    tokens_cnts = torch.zeros((num_experts + 1, num_experts),
-                              dtype=torch.int32,
-                              device=topk_ids.device)
-    cumsum = torch.zeros((num_experts + 1, ),
-                         dtype=torch.int32,
-                         device=topk_ids.device)
-    tokens_per_thread = cdiv(numel, num_experts)
-    sorted_token_ids.fill_(numel)
-    expert_ids.zero_()
-
-    moe_align_block_size_stage1[grid](
-        topk_ids,
-        tokens_cnts,
-        num_experts,
-        numel,
-        tokens_per_thread,
-    )
-    moe_align_block_size_stage2[grid](
-        tokens_cnts,
-        num_experts,
-    )
-    moe_align_block_size_stage3[(1, )](
-        num_tokens_post_pad,
-        tokens_cnts,
-        cumsum,
-        num_experts,
-        block_size,
-    )
-    moe_align_block_size_stage4[grid](
-        topk_ids,
-        sorted_token_ids,
-        expert_ids,
-        tokens_cnts,
-        cumsum,
-        num_experts,
-        block_size,
-        numel,
-        tokens_per_thread,
-    )
+from vllm.triton_utils import triton
+from vllm.utils import round_up
 
 
 def moe_align_block_size(
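
For readers who want the algorithm the deleted kernels implemented, here is a
pure-PyTorch reference sketch (illustration only, not part of the patch). It
mirrors the four removed Triton stages: count token slots per expert, pad each
count to a multiple of block_size, take the running sum, then scatter token
indices into each expert's region. The `_ref` name is hypothetical; the
production path remains the fused CUDA op exposed via vllm._custom_ops.

import torch


def moe_align_block_size_ref(
    topk_ids: torch.Tensor,  # (num_tokens, topk), expert index per slot
    block_size: int,
    num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    numel = topk_ids.numel()
    flat = topk_ids.flatten().long()

    # Stage 1: how many token slots are routed to each expert.
    counts = torch.bincount(flat, minlength=num_experts)

    # Stages 2-3: pad each expert's count up to a multiple of block_size;
    # a running sum then gives every expert's start offset in the sorted layout.
    padded = (counts + block_size - 1) // block_size * block_size
    cumsum = torch.zeros(num_experts + 1, dtype=torch.long)
    cumsum[1:] = torch.cumsum(padded, dim=0)
    total = int(cumsum[-1])

    # Padding slots keep the sentinel value `numel`, matching the
    # sorted_token_ids.fill_(numel) call in the removed helper.
    sorted_token_ids = torch.full((total,), numel, dtype=torch.int32)
    expert_ids = torch.zeros(total // block_size, dtype=torch.int32)

    # Stage 4a: every block of block_size slots belongs to exactly one expert.
    for e in range(num_experts):
        for blk in range(int(cumsum[e]) // block_size,
                         int(cumsum[e + 1]) // block_size):
            expert_ids[blk] = e

    # Stage 4b: scatter flattened token indices into their expert's region.
    offsets = cumsum[:-1].clone()
    for i in range(numel):
        e = int(flat[i])
        sorted_token_ids[int(offsets[e])] = i
        offsets[e] += 1

    return sorted_token_ids, expert_ids, torch.tensor([total], dtype=torch.int32)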
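After the patch, callers go through the remaining Python wrapper instead of
pre-allocating output buffers themselves. The positional call below copies the
updated benchmark; treating the result as a (sorted_token_ids, expert_ids,
num_tokens_post_pad) triple is an assumption inferred from the buffers the
deleted code used to allocate and fill, and the random topk_ids are only a
stand-in for real router output.

import torch

from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
    moe_align_block_size,
)

# Stand-in routing: 256 tokens, top-8 slots over 64 experts (CUDA assumed).
topk_ids = torch.randint(0, 64, (256, 8), dtype=torch.int32, device="cuda")
sorted_token_ids, expert_ids, num_tokens_post_pad = moe_align_block_size(
    topk_ids, 256, 64  # block_size=256, num_experts=64, as in the benchmark
)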