[main][BugFix] Fix an accuracy bug in Qwen3-next-MTP during batched inference (#30632)

Signed-off-by: drslark <slarksblood@qq.com>
This commit is contained in:
drslark 2025-12-14 17:32:16 +08:00 committed by GitHub
parent dcb31196da
commit add1b9d3de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -211,7 +211,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
spec_token_masks = torch.repeat_interleave(
spec_sequence_masks, query_lens
)
index = torch.argsort(spec_token_masks)
index = torch.argsort(spec_token_masks, stable=True)
num_non_spec_tokens = num_prefill_tokens + num_decode_tokens
non_spec_token_indx = index[:num_non_spec_tokens]
spec_token_indx = index[num_non_spec_tokens:]