[Misc] Fix expert_ids shape in MoE (#4517)

Woosuk Kwon 2024-05-01 16:47:59 -07:00 committed by GitHub
parent c9d852d601
commit 826b82a260


@@ -203,14 +203,15 @@ def moe_align_block_size(
     - The padding ensures that the total number of tokens is now divisible
       by block_size for proper block matrix operations.
     """
-    sorted_ids = torch.empty(
-        (topk_ids.numel() + num_experts * (block_size - 1), ),
-        dtype=torch.int32,
-        device=topk_ids.device)
-    expert_ids = torch.empty((topk_ids.numel() + num_experts, ),
+    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
+    sorted_ids = torch.empty((max_num_tokens_padded, ),
+                             dtype=torch.int32,
+                             device=topk_ids.device)
+    sorted_ids.fill_(topk_ids.numel())
+    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
+    expert_ids = torch.empty((max_num_m_blocks, ),
                              dtype=torch.int32,
                              device=topk_ids.device)
-    sorted_ids.fill_(topk_ids.numel())
     num_tokens_post_pad = torch.empty((1),
                                       dtype=torch.int32,
                                       device=topk_ids.device)
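
For intuition, here is a minimal sketch of the shape arithmetic behind the fix (the example sizes are made up, not taken from the commit). expert_ids stores one expert index per block of block_size sorted token slots, so its length should be the number of blocks, triton.cdiv(max_num_tokens_padded, block_size), rather than the old topk_ids.numel() + num_experts count; math.ceil stands in for triton.cdiv below.

import math

# Hypothetical sizes for illustration only (not from the commit).
num_tokens, top_k = 6, 2            # topk_ids.numel() == 12
num_experts, block_size = 4, 8

numel = num_tokens * top_k
# Worst case: each expert's token list needs up to (block_size - 1) pad slots.
max_num_tokens_padded = numel + num_experts * (block_size - 1)    # 12 + 28 == 40
# One expert index per block of sorted token slots, i.e. triton.cdiv(...):
max_num_m_blocks = math.ceil(max_num_tokens_padded / block_size)  # ceil(40 / 8) == 5

old_shape = numel + num_experts                                   # 16
print(max_num_m_blocks, old_shape)                                # 5 16
# The old allocation (16) does not track the block count (5); the new
# shape matches one entry per block, which is what the kernel fills in.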