[V1][Bugfix][Spec Decode] Fix incorrect outputs in V1 speculative decoding due to batch indexing (#14645)
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
commit 5c538c37b2
parent e22ee1e7a2
@@ -1,4 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
+import random
+
 import pytest
 
 from vllm import LLM, SamplingParams
@@ -6,16 +8,41 @@ from vllm import LLM, SamplingParams
 @pytest.fixture
 def test_prompts():
-    return [
-        "Can you repeat the sentence ten times, this is a sentence.",
-        "Can you repeat the sentence ten times, this is a test.",
-    ]
+    prompt_types = ["repeat", "sentence"]
+    num_prompts = 100
+    prompts = []
+
+    random.seed(0)
+    random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
+
+    # Generate a mixed batch of prompts, some of which can be easily
+    # predicted by n-gram matching and some which likely cannot.
+    for kind in random_prompt_type_choices:
+        word_choices = ["test", "temp", "hello", "where"]
+        word = random.choice(word_choices)
+        if kind == "repeat":
+            prompt = f"""
+            please repeat the word '{word}' 10 times.
+            give no other output than the word at least ten times in a row,
+            in lowercase with spaces between each word and without quotes.
+            """
+        elif kind == "sentence":
+            prompt = f"""
+            please give a ten-word sentence that
+            uses the word {word} at least once.
+            give no other output than that simple sentence without quotes.
+            """
+        else:
+            raise ValueError(f"Unknown prompt type: {kind}")
+        prompts.append([{"role": "user", "content": prompt}])
+
+    return prompts
 
 
 @pytest.fixture
 def sampling_config():
     # Only support greedy for now
-    return SamplingParams(temperature=0, max_tokens=30, ignore_eos=False)
+    return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)
 
 
 @pytest.fixture
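
The fixture's comment explains the intent of the mix: prompt-lookup ('[ngram]') speculation proposes draft tokens by matching the most recent n-gram of the context against earlier prompt text, so the "repeat the word" prompts should be speculated almost perfectly while the free-form sentence prompts mostly should not. A toy sketch of that lookup idea, purely illustrative and not vLLM's ngram proposer (integer token IDs stand in for real tokenizer output):

# Toy prompt-lookup proposer: if the last `n` context tokens also occur in the
# prompt, propose the `k` tokens that followed that occurrence. Illustrative
# only; vLLM's ngram worker handles min/max lookup windows and batching.
def propose_draft(prompt_tokens, output_tokens, n=3, k=3):
    pattern = (prompt_tokens + output_tokens)[-n:]
    for i in range(len(prompt_tokens) - n + 1):
        if prompt_tokens[i:i + n] == pattern:
            return prompt_tokens[i + n:i + n + k]
    return []  # no match -> nothing to speculate for this request

# "repeat the word" prompt: the continuation is trivially found in the prompt.
print(propose_draft([7, 7, 7, 7, 7, 7], [7, 7, 7]))  # -> [7, 7, 7]
# free-form sentence: the generated suffix rarely reappears in the prompt.
print(propose_draft([1, 2, 3, 4, 5, 6], [9, 8, 7]))  # -> []
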
@@ -32,18 +59,28 @@ def test_ngram_correctness(monkeypatch, test_prompts, sampling_config,
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
 
-        ref_llm = LLM(model=model_name)
-        ref_outputs = ref_llm.generate(test_prompts, sampling_config)
+        ref_llm = LLM(model=model_name, max_model_len=1024)
+        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
         del ref_llm
 
         spec_llm = LLM(model=model_name,
                        speculative_model='[ngram]',
                        ngram_prompt_lookup_max=5,
                        ngram_prompt_lookup_min=3,
-                       num_speculative_tokens=3)
-        spec_outputs = spec_llm.generate(test_prompts, sampling_config)
+                       num_speculative_tokens=3,
+                       max_model_len=1024)
+        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
+        matches = 0
+        misses = 0
         for ref_output, spec_output in zip(ref_outputs, spec_outputs):
-            assert ref_output.outputs[0].text == spec_output.outputs[0].text, \
-                (f"ref_output: {ref_output.outputs[0].text},"
-                 f"spec_output: {spec_output.outputs[0].text}")
+            if ref_output.outputs[0].text == spec_output.outputs[0].text:
+                matches += 1
+            else:
+                misses += 1
+                print(f"ref_output: {ref_output.outputs[0].text}")
+                print(f"spec_output: {spec_output.outputs[0].text}")
+
+        # Heuristic: expect at least 70% of the prompts to match exactly
+        # Upon failure, inspect the outputs to check for inaccuracy.
+        assert matches > int(0.7 * len(ref_outputs))
         del spec_llm
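
Two mechanical changes in this hunk follow from the new fixture: `generate` becomes `chat` because each test prompt is now a single-turn conversation (a list of role/content messages) rather than a plain string, and the exact-match assertion becomes a 70% heuristic because, even with a lossless rejection sampler and greedy sampling, small numerical differences between the reference and speculative runs can occasionally flip a near-tied token in a 100-prompt batch. A small sketch of the data shape and the threshold arithmetic (the example message content is illustrative):

# Shape of one element of test_prompts as consumed by LLM.chat:
example_conversation = [
    {"role": "user", "content": "please repeat the word 'test' 10 times. ..."},
]

# Threshold used by the assertion `matches > int(0.7 * len(ref_outputs))`:
num_prompts = 100
threshold = int(0.7 * num_prompts)     # 70
min_matches_to_pass = threshold + 1    # at least 71 exact matches required
print(threshold, min_matches_to_pass)  # 70 71
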
@@ -1015,11 +1015,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         else:
             target_probs = self.model.sampler.compute_probs(
                 logits, sampling_metadata)
-            scheduled_request_ids = scheduler_output.num_scheduled_tokens.keys(
-            )
             draft_token_ids = [
                 scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
-                for req_id in scheduled_request_ids
+                for req_id in self.input_batch.req_ids
             ]
             sampler_output = self.rejection_sampler(draft_token_ids,
                                                     target_probs,
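
The one functional line above is the actual fix. `target_probs` is computed for the rows of the persistent input batch, in `self.input_batch.req_ids` order, but the removed code gathered each request's draft tokens in the iteration order of `scheduler_output.num_scheduled_tokens`. When those two orders differ (for example after finished requests are removed and the batch is reordered), the rejection sampler verified one request's draft tokens against another request's probabilities, which is the incorrect-output symptom named in the commit title. A minimal sketch of the mismatch, using plain dictionaries and lists as stand-ins for the real scheduler output and input batch:

# Stand-ins for the real objects (illustration only). Rows of target_probs
# follow the persistent batch order below: row 0 -> "req-B", row 1 -> "req-A".
input_batch_req_ids = ["req-B", "req-A"]
scheduled_spec_decode_tokens = {"req-A": [11, 12], "req-B": [21, 22]}
num_scheduled_tokens = {"req-A": 3, "req-B": 3}   # dict order != batch order

# Before the fix: drafts gathered in scheduler-dict order, so row 0 of
# target_probs (req-B) is checked against req-A's draft tokens.
buggy = [scheduled_spec_decode_tokens.get(r, []) for r in num_scheduled_tokens]
print(buggy)  # [[11, 12], [21, 22]]

# After the fix: drafts gathered in the same order as the batch / target_probs.
fixed = [scheduled_spec_decode_tokens.get(r, []) for r in input_batch_req_ids]
print(fixed)  # [[21, 22], [11, 12]]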