mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-21 15:57:00 +08:00
[Misc] Fix expert_ids shape in MoE (#4517)
This commit is contained in:
parent
c9d852d601
commit
826b82a260
@ -203,14 +203,15 @@ def moe_align_block_size(
|
|||||||
- The padding ensures that the total number of tokens is now divisible
|
- The padding ensures that the total number of tokens is now divisible
|
||||||
by block_size for proper block matrix operations.
|
by block_size for proper block matrix operations.
|
||||||
"""
|
"""
|
||||||
sorted_ids = torch.empty(
|
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||||
(topk_ids.numel() + num_experts * (block_size - 1), ),
|
sorted_ids = torch.empty((max_num_tokens_padded, ),
|
||||||
dtype=torch.int32,
|
|
||||||
device=topk_ids.device)
|
|
||||||
expert_ids = torch.empty((topk_ids.numel() + num_experts, ),
|
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=topk_ids.device)
|
device=topk_ids.device)
|
||||||
sorted_ids.fill_(topk_ids.numel())
|
sorted_ids.fill_(topk_ids.numel())
|
||||||
|
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
|
||||||
|
expert_ids = torch.empty((max_num_m_blocks, ),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=topk_ids.device)
|
||||||
num_tokens_post_pad = torch.empty((1),
|
num_tokens_post_pad = torch.empty((1),
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=topk_ids.device)
|
device=topk_ids.device)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user