From 826b82a260ebb1ea7edd04a3278d5fb9b103a76e Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 1 May 2024 16:47:59 -0700
Subject: [PATCH] [Misc] Fix expert_ids shape in MoE (#4517)

---
 vllm/model_executor/layers/fused_moe/fused_moe.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index b4f81527141a8..3cb0419404625 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -203,14 +203,15 @@ def moe_align_block_size(
     - The padding ensures that the total number of tokens is now divisible
         by block_size for proper block matrix operations.
     """
-    sorted_ids = torch.empty(
-        (topk_ids.numel() + num_experts * (block_size - 1), ),
-        dtype=torch.int32,
-        device=topk_ids.device)
-    expert_ids = torch.empty((topk_ids.numel() + num_experts, ),
+    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
+    sorted_ids = torch.empty((max_num_tokens_padded, ),
                              dtype=torch.int32,
                              device=topk_ids.device)
     sorted_ids.fill_(topk_ids.numel())
+    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
+    expert_ids = torch.empty((max_num_m_blocks, ),
+                             dtype=torch.int32,
+                             device=topk_ids.device)
     num_tokens_post_pad = torch.empty((1),
                                       dtype=torch.int32,
                                       device=topk_ids.device)
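
Note (not part of the patch): a minimal sketch of the sizing logic the change implements, using plain Python in place of triton.cdiv and hypothetical helper/variable names mirroring those in moe_align_block_size. The point is that expert_ids stores one expert index per block of block_size padded token slots, so it only needs ceil(max_num_tokens_padded / block_size) entries, not one entry per routed token.

    # Sketch only; names follow the patched moe_align_block_size for clarity.
    def expert_ids_length(num_topk_ids: int, num_experts: int,
                          block_size: int) -> int:
        # Worst case: each expert's token count is padded up to a multiple
        # of block_size, adding at most (block_size - 1) slots per expert.
        max_num_tokens_padded = num_topk_ids + num_experts * (block_size - 1)
        # One expert id per block of block_size sorted token slots, so
        # ceil-divide the padded length by block_size (== triton.cdiv).
        return -(-max_num_tokens_padded // block_size)

    # Example: 64 routed (token, expert) pairs, 8 experts, block_size 16
    # -> padded length 64 + 8 * 15 = 184 -> 12 blocks, so expert_ids needs
    # 12 entries, versus the 64 + 8 = 72 the old code allocated.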