hack for topk ids

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
Tyler Michael Smith 2025-06-23 18:02:29 -04:00
parent c40692bf9a
commit 14a6efb83e

View File

@ -1296,6 +1296,34 @@ class FusedMoE(torch.nn.Module):
indices_type: Optional[torch.dtype] = None):
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
# Number of tokens in the current batch
num_tokens = hidden_states.shape[0]
# Infer how many experts exist from the router-logit dimension
global_num_experts = router_logits.shape[-1]
# Choose a dtype for the indices
if indices_type is None:
indices_type = torch.long
# Random expert IDs, uniform in [0, global_num_experts)
topk_ids = torch.randint(
low=0,
high=global_num_experts,
size=(num_tokens, top_k),
dtype=indices_type,
device=hidden_states.device,
)
# All-ones weights
topk_weights = torch.ones(
(num_tokens, top_k),
dtype=torch.float32,
device=hidden_states.device,
)
return topk_weights, topk_ids
# DeekSeekv2 uses grouped_top_k
if use_grouped_topk:
assert topk_group is not None