[BugFix] Fix index error in ngram_proposer (#29779)

Signed-off-by: Bradley <bradley.b.pitt@gmail.com>
This commit is contained in:
usberkeley 2025-12-02 12:48:11 +08:00 committed by GitHub
parent 53bf71b0f0
commit 81fe3f82af
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 35 additions and 3 deletions

View File

@ -2,7 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
from vllm.config import (
ModelConfig,
SpeculativeConfig,
VllmConfig,
)
from vllm.v1.spec_decode.ngram_proposer import (
NgramProposer,
_find_longest_matched_ngram_and_propose_tokens,
@ -167,6 +171,34 @@ def test_ngram_proposer():
assert np.array_equal(result[0], np.array([3, 1]))
assert np.array_equal(result[1], np.array([]))
# Test non-contiguous indices: requests 0 and 2 need proposals,
# request 1 is in prefill
proposer = get_ngram_proposer(min_n=2, max_n=2, k=2)
max_model_len = 20
token_ids_cpu = np.zeros((3, max_model_len), dtype=np.int32)
token_ids_cpu[0, :5] = [1, 2, 3, 1, 2]
token_ids_cpu[1, :3] = [4, 5, 6]
token_ids_cpu[2, :5] = [7, 8, 9, 7, 8]
num_tokens_no_spec = np.array([5, 3, 5], dtype=np.int32)
sampled_token_ids = [[2], [], [8]] # Empty list for request 1 simulates prefill
result = proposer.propose(
sampled_token_ids=sampled_token_ids,
req_ids=["0", "1", "2"],
num_tokens_no_spec=num_tokens_no_spec,
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert len(result) == 3
assert np.array_equal(result[0], [3, 1])
assert len(result[1]) == 0
assert np.array_equal(result[2], [9, 7])
# Verify internal arrays written to correct indices
assert proposer.valid_ngram_num_drafts[0] == 2
assert proposer.valid_ngram_num_drafts[1] == 0
assert proposer.valid_ngram_num_drafts[2] == 2
assert np.array_equal(proposer.valid_ngram_draft[0, :2], [3, 1])
assert np.array_equal(proposer.valid_ngram_draft[2, :2], [9, 7])
# test if 0 threads available: can happen if TP size > CPU count
ngram_proposer = get_ngram_proposer(min_n=2, max_n=2, k=2)
ngram_proposer.num_numba_thread_available = 0

View File

@ -196,9 +196,9 @@ def batch_propose_numba(
k=k,
)
valid_ngram_num_drafts[i] = drafter_output.shape[0]
valid_ngram_num_drafts[idx] = drafter_output.shape[0]
if len(drafter_output):
valid_ngram_draft[i, : drafter_output.shape[0]] = drafter_output
valid_ngram_draft[idx, : drafter_output.shape[0]] = drafter_output
@jit(nopython=True)