mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-09 06:09:09 +08:00
hack for topk ids
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
parent
c40692bf9a
commit
14a6efb83e
@@ -1296,6 +1296,34 @@ class FusedMoE(torch.nn.Module):
|
|||||||
indices_type: Optional[torch.dtype] = None):
|
indices_type: Optional[torch.dtype] = None):
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||||
|
|
||||||
|
# Number of tokens in the current batch
|
||||||
|
num_tokens = hidden_states.shape[0]
|
||||||
|
|
||||||
|
# Infer how many experts exist from the router-logit dimension
|
||||||
|
global_num_experts = router_logits.shape[-1]
|
||||||
|
|
||||||
|
# Choose a dtype for the indices
|
||||||
|
if indices_type is None:
|
||||||
|
indices_type = torch.long
|
||||||
|
|
||||||
|
# Random expert IDs, uniform in [0, global_num_experts)
|
||||||
|
topk_ids = torch.randint(
|
||||||
|
low=0,
|
||||||
|
high=global_num_experts,
|
||||||
|
size=(num_tokens, top_k),
|
||||||
|
dtype=indices_type,
|
||||||
|
device=hidden_states.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# All-ones weights
|
||||||
|
topk_weights = torch.ones(
|
||||||
|
(num_tokens, top_k),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=hidden_states.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
return topk_weights, topk_ids
|
||||||
|
|
||||||
# DeepSeek-V2 uses grouped_top_k
|
# DeepSeek-V2 uses grouped_top_k
|
||||||
if use_grouped_topk:
|
if use_grouped_topk:
|
||||||
assert topk_group is not None
|
assert topk_group is not None
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user