mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 01:05:01 +08:00
[Bugfix] Fix deepseekv3 grouped topk error (#13474)
Signed-off-by: Chen-XiaoBing <chenxb002@whu.edu.cn>
This commit is contained in:
parent
992e5c3d34
commit
ed6e9075d3
@ -939,15 +939,17 @@ def grouped_topk(hidden_states: torch.Tensor,
|
||||
else:
|
||||
raise ValueError(f"Unsupported scoring function: {scoring_func}")
|
||||
|
||||
num_token = scores.shape[0]
|
||||
if e_score_correction_bias is not None:
|
||||
# Store original scores before applying correction bias. We use biased
|
||||
# scores for expert selection but original scores for routing weights
|
||||
original_scores = scores
|
||||
scores = scores + e_score_correction_bias.unsqueeze(0)
|
||||
|
||||
num_token = scores.shape[0]
|
||||
group_scores = scores.view(num_token, num_expert_group,
|
||||
-1).max(dim=-1).values # [n, n_group]
|
||||
group_scores = (scores.view(num_token, num_expert_group,
|
||||
-1).topk(2, dim=-1)[0].sum(dim=-1))
|
||||
else:
|
||||
group_scores = scores.view(num_token, num_expert_group,
|
||||
-1).max(dim=-1).values # [n, n_group]
|
||||
group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
|
||||
sorted=False)[1] # [n, top_k_group]
|
||||
group_mask = torch.zeros_like(group_scores) # [n, n_group]
|
||||
@ -955,7 +957,8 @@ def grouped_topk(hidden_states: torch.Tensor,
|
||||
score_mask = group_mask.unsqueeze(-1).expand(
|
||||
num_token, num_expert_group,
|
||||
scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e]
|
||||
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
|
||||
tmp_scores = scores.masked_fill(~score_mask.bool(),
|
||||
float("-inf")) # [n, e]
|
||||
|
||||
if e_score_correction_bias is not None:
|
||||
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user