mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-27 14:43:38 +08:00
hack for topk ids
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
parent
c40692bf9a
commit
14a6efb83e
@@ -1296,6 +1296,34 @@ class FusedMoE(torch.nn.Module):
|
||||
indices_type: Optional[torch.dtype] = None):
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
|
||||
# Number of tokens in the current batch
|
||||
num_tokens = hidden_states.shape[0]
|
||||
|
||||
# Infer how many experts exist from the router-logit dimension
|
||||
global_num_experts = router_logits.shape[-1]
|
||||
|
||||
# Choose a dtype for the indices
|
||||
if indices_type is None:
|
||||
indices_type = torch.long
|
||||
|
||||
# Random expert IDs, uniform in [0, global_num_experts)
|
||||
topk_ids = torch.randint(
|
||||
low=0,
|
||||
high=global_num_experts,
|
||||
size=(num_tokens, top_k),
|
||||
dtype=indices_type,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
|
||||
# All-ones weights
|
||||
topk_weights = torch.ones(
|
||||
(num_tokens, top_k),
|
||||
dtype=torch.float32,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
|
||||
return topk_weights, topk_ids
|
||||
|
||||
# DeepSeek-V2 uses grouped_top_k
|
||||
if use_grouped_topk:
|
||||
assert topk_group is not None
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user