From 5c538c37b26f8ebb9250dfc6a1866ec4aaac4299 Mon Sep 17 00:00:00 2001
From: Benjamin Chislett
Date: Wed, 12 Mar 2025 01:12:41 -0400
Subject: [PATCH] [V1][Bugfix][Spec Decode] Fix incorrect outputs in V1 speculative decoding due to batch indexing (#14645)

Signed-off-by: Benjamin Chislett
---
 tests/v1/e2e/test_ngram_spec_decode.py | 61 +++++++++++++++++++++-----
 vllm/v1/worker/gpu_model_runner.py     |  4 +-
 2 files changed, 50 insertions(+), 15 deletions(-)

diff --git a/tests/v1/e2e/test_ngram_spec_decode.py b/tests/v1/e2e/test_ngram_spec_decode.py
index 150caa150a59..519a74cab84b 100644
--- a/tests/v1/e2e/test_ngram_spec_decode.py
+++ b/tests/v1/e2e/test_ngram_spec_decode.py
@@ -1,4 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
+import random
+
 import pytest
 
 from vllm import LLM, SamplingParams
@@ -6,16 +8,41 @@ from vllm import LLM, SamplingParams
 
 @pytest.fixture
 def test_prompts():
-    return [
-        "Can you repeat the sentence ten times, this is a sentence.",
-        "Can you repeat the sentence ten times, this is a test.",
-    ]
+    prompt_types = ["repeat", "sentence"]
+    num_prompts = 100
+    prompts = []
+
+    random.seed(0)
+    random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
+
+    # Generate a mixed batch of prompts, some of which can be easily
+    # predicted by n-gram matching and some which likely cannot.
+    for kind in random_prompt_type_choices:
+        word_choices = ["test", "temp", "hello", "where"]
+        word = random.choice(word_choices)
+        if kind == "repeat":
+            prompt = f"""
+            please repeat the word '{word}' 10 times.
+            give no other output than the word at least ten times in a row,
+            in lowercase with spaces between each word and without quotes.
+            """
+        elif kind == "sentence":
+            prompt = f"""
+            please give a ten-word sentence that
+            uses the word {word} at least once.
+            give no other output than that simple sentence without quotes.
+            """
+        else:
+            raise ValueError(f"Unknown prompt type: {kind}")
+        prompts.append([{"role": "user", "content": prompt}])
+
+    return prompts
 
 
 @pytest.fixture
 def sampling_config():
     # Only support greedy for now
-    return SamplingParams(temperature=0, max_tokens=30, ignore_eos=False)
+    return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)
 
 
 @pytest.fixture
@@ -32,18 +59,28 @@ def test_ngram_correctness(monkeypatch, test_prompts, sampling_config,
 
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
-        ref_llm = LLM(model=model_name)
-        ref_outputs = ref_llm.generate(test_prompts, sampling_config)
+        ref_llm = LLM(model=model_name, max_model_len=1024)
+        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
         del ref_llm
 
         spec_llm = LLM(model=model_name,
                        speculative_model='[ngram]',
                        ngram_prompt_lookup_max=5,
                        ngram_prompt_lookup_min=3,
-                       num_speculative_tokens=3)
-        spec_outputs = spec_llm.generate(test_prompts, sampling_config)
+                       num_speculative_tokens=3,
+                       max_model_len=1024)
+        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
+        matches = 0
+        misses = 0
         for ref_output, spec_output in zip(ref_outputs, spec_outputs):
-            assert ref_output.outputs[0].text == spec_output.outputs[0].text, \
-                (f"ref_output: {ref_output.outputs[0].text},"
-                 f"spec_output: {spec_output.outputs[0].text}")
+            if ref_output.outputs[0].text == spec_output.outputs[0].text:
+                matches += 1
+            else:
+                misses += 1
+                print(f"ref_output: {ref_output.outputs[0].text}")
+                print(f"spec_output: {spec_output.outputs[0].text}")
+
+        # Heuristic: expect at least 70% of the prompts to match exactly
+        # Upon failure, inspect the outputs to check for inaccuracy.
+        assert matches > int(0.7 * len(ref_outputs))
         del spec_llm
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 732792885fb3..df7ca70924bf 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1015,11 +1015,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         else:
             target_probs = self.model.sampler.compute_probs(
                 logits, sampling_metadata)
-        scheduled_request_ids = scheduler_output.num_scheduled_tokens.keys(
-        )
         draft_token_ids = [
             scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
-            for req_id in scheduled_request_ids
+            for req_id in self.input_batch.req_ids
         ]
         sampler_output = self.rejection_sampler(draft_token_ids,
                                                 target_probs,
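
For context on the one-line fix in gpu_model_runner.py: draft token ids were
gathered in the iteration order of scheduler_output.num_scheduled_tokens.keys(),
while target_probs are laid out in the row order of the runner's persistent
batch (self.input_batch.req_ids), so when the two orders diverged a request
could be rejection-sampled against another request's draft tokens. A minimal
standalone Python sketch of that hazard, using hypothetical request ids rather
than real vLLM structures:

    # Hypothetical draft tokens proposed per request (not vLLM data).
    scheduled_spec_decode_tokens = {"req-b": [7, 8], "req-a": [1, 2]}
    # Order in which the scheduler happened to record the requests.
    scheduled_order = ["req-b", "req-a"]
    # Row order of the runner's persistent batch.
    batch_req_ids = ["req-a", "req-b"]

    # Buggy gather: row 0 carries req-b's drafts but is consumed as req-a's.
    buggy = [scheduled_spec_decode_tokens.get(r, []) for r in scheduled_order]
    assert buggy == [[7, 8], [1, 2]]

    # Fixed gather: index by the same batch order the sampler consumes.
    fixed = [scheduled_spec_decode_tokens.get(r, []) for r in batch_req_ids]
    assert fixed == [[1, 2], [7, 8]]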