From 14a6efb83ef3de660c497ae3c59b54390f9fe99c Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 23 Jun 2025 18:02:29 -0400 Subject: [PATCH] hack for topk ids: return random expert IDs and all-ones weights instead of routing Signed-off-by: Tyler Michael Smith --- vllm/model_executor/layers/fused_moe/layer.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index c1bae033c2b4b..2be69a15f0616 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1296,6 +1296,34 @@ class FusedMoE(torch.nn.Module): indices_type: Optional[torch.dtype] = None): from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk + # Number of tokens in the current batch + num_tokens = hidden_states.shape[0] + + # Infer how many experts exist from the router-logit dimension + global_num_experts = router_logits.shape[-1] + + # Choose a dtype for the indices + if indices_type is None: + indices_type = torch.long + + # Random expert IDs, uniform in [0, global_num_experts); drawn independently, so a token's row may repeat an expert + topk_ids = torch.randint( + low=0, + high=global_num_experts, + size=(num_tokens, top_k), + dtype=indices_type, + device=hidden_states.device, + ) + + # All-ones weights (not normalized across top_k) + topk_weights = torch.ones( + (num_tokens, top_k), + dtype=torch.float32, + device=hidden_states.device, + ) + + return topk_weights, topk_ids + # DeekSeekv2 uses grouped_top_k if use_grouped_topk: assert topk_group is not None