From 81fe3f82af2e30c93ddb106b7a84b7730b982e66 Mon Sep 17 00:00:00 2001 From: usberkeley <150880684+usberkeley@users.noreply.github.com> Date: Tue, 2 Dec 2025 12:48:11 +0800 Subject: [PATCH] [BugFix] Fix index error in ngram_proposer (#29779) Signed-off-by: Bradley --- tests/v1/spec_decode/test_ngram.py | 34 ++++++++++++++++++++++++++- vllm/v1/spec_decode/ngram_proposer.py | 4 ++-- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py index 692c39282c372..6bc412abe8695 100644 --- a/tests/v1/spec_decode/test_ngram.py +++ b/tests/v1/spec_decode/test_ngram.py @@ -2,7 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import numpy as np -from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig +from vllm.config import ( + ModelConfig, + SpeculativeConfig, + VllmConfig, +) from vllm.v1.spec_decode.ngram_proposer import ( NgramProposer, _find_longest_matched_ngram_and_propose_tokens, @@ -167,6 +171,34 @@ def test_ngram_proposer(): assert np.array_equal(result[0], np.array([3, 1])) assert np.array_equal(result[1], np.array([])) + # Test non-contiguous indices: requests 0 and 2 need proposals, + # request 1 is in prefill + proposer = get_ngram_proposer(min_n=2, max_n=2, k=2) + max_model_len = 20 + token_ids_cpu = np.zeros((3, max_model_len), dtype=np.int32) + token_ids_cpu[0, :5] = [1, 2, 3, 1, 2] + token_ids_cpu[1, :3] = [4, 5, 6] + token_ids_cpu[2, :5] = [7, 8, 9, 7, 8] + num_tokens_no_spec = np.array([5, 3, 5], dtype=np.int32) + sampled_token_ids = [[2], [], [8]] # Empty list for request 1 simulates prefill + result = proposer.propose( + sampled_token_ids=sampled_token_ids, + req_ids=["0", "1", "2"], + num_tokens_no_spec=num_tokens_no_spec, + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), + ) + assert len(result) == 3 + assert np.array_equal(result[0], [3, 1]) + assert len(result[1]) == 0 + assert np.array_equal(result[2], [9, 7]) + # Verify internal arrays written to correct indices + assert proposer.valid_ngram_num_drafts[0] == 2 + assert proposer.valid_ngram_num_drafts[1] == 0 + assert proposer.valid_ngram_num_drafts[2] == 2 + assert np.array_equal(proposer.valid_ngram_draft[0, :2], [3, 1]) + assert np.array_equal(proposer.valid_ngram_draft[2, :2], [9, 7]) + # test if 0 threads available: can happen if TP size > CPU count ngram_proposer = get_ngram_proposer(min_n=2, max_n=2, k=2) ngram_proposer.num_numba_thread_available = 0 diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 10b3f0aa040e5..1273ca12c3600 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -196,9 +196,9 @@ def batch_propose_numba( k=k, ) - valid_ngram_num_drafts[i] = drafter_output.shape[0] + valid_ngram_num_drafts[idx] = drafter_output.shape[0] if len(drafter_output): - valid_ngram_draft[i, : drafter_output.shape[0]] = drafter_output + valid_ngram_draft[idx, : drafter_output.shape[0]] = drafter_output @jit(nopython=True)