[Feature] [Spec decode]: Enable MLPSpeculator/Medusa and prompt_logprobs with ChunkedPrefill (#10132)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: wallashss <wallashss@ibm.com>
Co-authored-by: wallashss <wallashss@ibm.com>
Nicolò Lucchesi 2025-01-27 22:38:35 +01:00 committed by GitHub
parent 2bc3fbba0c
commit 6116ca8cd7
16 changed files with 468 additions and 165 deletions
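Below is a minimal usage sketch of the combination this change enables, loosely based on the engine arguments exercised by the tests in this diff (model names, chunk sizes, and the draft/target pairing are illustrative only; an MLPSpeculator or Medusa checkpoint would be passed via `speculative_model` in the same way):

from vllm import LLM, SamplingParams

# Speculative decoding combined with chunked prefill scheduling, plus per-token
# logprobs and prompt_logprobs; this is the combination the tests below exercise.
llm = LLM(
    model="JackFram/llama-160m",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=3,
    enable_chunked_prefill=True,
    max_num_batched_tokens=4,
    max_num_seqs=4,
    disable_logprobs_during_spec_decoding=False,
)
params = SamplingParams(temperature=0.0, max_tokens=32,
                        logprobs=2, prompt_logprobs=2)
outputs = llm.generate(["The president of the United States is"], params)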

View File

@ -2,6 +2,7 @@ from itertools import cycle
from typing import List, Optional, Sequence, Tuple, Union from typing import List, Optional, Sequence, Tuple, Union
import pytest import pytest
import torch
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed import cleanup_dist_env_and_memory
@ -154,6 +155,8 @@ def _check_logprobs_when_output_disabled(
spec_pos_logprob) = next(iter(spec_pos_logprobs.items())) spec_pos_logprob) = next(iter(spec_pos_logprobs.items()))
assert spec_pos_logprob.rank == -1 assert spec_pos_logprob.rank == -1
assert spec_pos_logprob.logprob == 0.0 assert spec_pos_logprob.logprob == 0.0
if isinstance(spec_pos_logprob_token_id, torch.Tensor):
spec_pos_logprob_token_id = spec_pos_logprob_token_id.item()
assert spec_pos_logprob_token_id in baseline_pos_logprobs assert spec_pos_logprob_token_id in baseline_pos_logprobs
@ -244,7 +247,8 @@ def run_equality_correctness_test_tp(model,
batch_size: int, batch_size: int,
max_output_len: int, max_output_len: int,
seed: int = 0, seed: int = 0,
temperature: float = 0.0): temperature: float = 0.0,
logprobs: Optional[int] = None):
"""Helper method that compares the outputs of both the baseline LLM and """Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
the same when temperature is zero. the same when temperature is zero.
@ -257,7 +261,6 @@ def run_equality_correctness_test_tp(model,
results = [] results = []
prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))] prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
for args, env in ((arg1, env1), (arg2, env2)): for args, env in ((arg1, env1), (arg2, env2)):
with RemoteOpenAIServer(model, with RemoteOpenAIServer(model,
args, args,
@ -269,12 +272,14 @@ def run_equality_correctness_test_tp(model,
prompt=prompts, prompt=prompts,
max_tokens=max_output_len, max_tokens=max_output_len,
seed=seed, seed=seed,
temperature=temperature) temperature=temperature,
logprobs=logprobs)
results.append({ results.append({
"test": "test":
"seeded_sampling", "seeded_sampling",
"text": [choice.text for choice in completion.choices], "text": [choice.text for choice in completion.choices],
"logprobs": [choice.logprobs for choice in completion.choices],
"finish_reason": "finish_reason":
[choice.finish_reason for choice in completion.choices], [choice.finish_reason for choice in completion.choices],
"usage": "usage":
@ -284,7 +289,15 @@ def run_equality_correctness_test_tp(model,
n = len(results) // 2 n = len(results) // 2
arg1_results = results[:n] arg1_results = results[:n]
arg2_results = results[n:] arg2_results = results[n:]
# Separate logprobs to avoid asserting exact equality.
arg1_logprobs = [r.pop("logprobs") for r in arg1_results]
arg2_logprobs = [r.pop("logprobs") for r in arg2_results]
for arg1_result, arg2_result in zip(arg1_results, arg2_results): for arg1_result, arg2_result in zip(arg1_results, arg2_results):
assert arg1_result == arg2_result, ( assert arg1_result == arg2_result, (
f"Results for {model=} are not the same with {arg1=} and {arg2=}. " f"Results for {model=} are not the same with {arg1=} and {arg2=}. "
f"{arg1_result=} != {arg2_result=}") f"{arg1_result=} != {arg2_result=}")
if logprobs:
for logs1, logs2 in zip(arg1_logprobs, arg2_logprobs):
for l1, l2 in zip(logs1, logs2):
assert l1.tokens == l2.tokens

View File

@ -2,6 +2,8 @@
tensor parallelism. tensor parallelism.
""" """
from typing import Optional
import pytest import pytest
import torch import torch
@ -154,15 +156,20 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
"--speculative-draft-tensor-parallel-size", "--speculative-draft-tensor-parallel-size",
"1", "1",
])]) ])])
@pytest.mark.parametrize("logprobs", [None, 2])
@pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs, def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs,
logprobs: Optional[int],
batch_size: int, seed: int): batch_size: int, seed: int):
"""Verify spec decode works well with same and different TP size for """Verify spec decode works well with same and different TP size for
the draft model with chunked prefill. the draft model with chunked prefill.
""" """
if logprobs:
test_llm_kwargs.extend(
["--disable_logprobs_during_spec_decoding", "False"])
run_equality_correctness_test_tp(model, run_equality_correctness_test_tp(model,
common_llm_kwargs, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
@ -171,4 +178,5 @@ def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
batch_size, batch_size,
max_output_len=32, max_output_len=32,
seed=seed, seed=seed,
temperature=0.0) temperature=0.0,
logprobs=logprobs)

View File

@ -4,26 +4,27 @@ import pytest
from vllm import SamplingParams from vllm import SamplingParams
from ..utils import maybe_enable_chunked_prefill
from .conftest import run_equality_correctness_test from .conftest import run_equality_correctness_test
@pytest.mark.parametrize( @pytest.mark.parametrize(
"common_llm_kwargs", "common_llm_kwargs",
[{ [{
"model_name": "JackFram/llama-68m", "model_name": "JackFram/llama-160m",
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", @pytest.mark.parametrize("test_llm_kwargs",
[{ [{
"speculative_model": "JackFram/llama-160m", "speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3, "num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False, "disable_logprobs_during_spec_decoding": False,
}, { }, {
"speculative_model": "JackFram/llama-160m", "speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3, "num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": True, "disable_logprobs_during_spec_decoding": True,
}]) }])
@ -36,12 +37,15 @@ from .conftest import run_equality_correctness_test
]) ])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6]) @pytest.mark.parametrize("logprobs", [1, 6])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4, 12])
def test_logprobs_equality(vllm_runner, common_llm_kwargs, def test_logprobs_equality(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int, test_llm_kwargs, batch_size: int, output_len: int,
seed: int, logprobs: int): seed: int, logprobs: int, prefill_chunk_size: int):
"""Verify output logprobs are equal with and without speculative decoding. """Verify output logprobs are equal with and without speculative decoding,
as well as with and without chunked prefill.
""" """
maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(vllm_runner,
common_llm_kwargs, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,

View File

@ -21,6 +21,7 @@ correctess for the target model outputs.
import pytest import pytest
from ..utils import maybe_enable_chunked_prefill
from .conftest import run_equality_correctness_test from .conftest import run_equality_correctness_test
# main model # main model
@ -67,12 +68,14 @@ PRECISION = "float32"
]) ])
@pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, batch_size: int, output_len: int,
seed: int): seed: int, prefill_chunk_size: int):
"""Verify greedy equality with different batch size.""" """Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(vllm_runner,
common_llm_kwargs, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
@ -119,12 +122,15 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6]) @pytest.mark.parametrize("logprobs", [1, 6])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, batch_size: int, output_len: int,
seed: int, logprobs: int): seed: int, logprobs: int,
prefill_chunk_size: int):
"""Verify greedy equality with different batch size.""" """Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(vllm_runner,
common_llm_kwargs, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
@ -167,12 +173,14 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
]) ])
@pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_medusa_e2e_greedy_correctness_cuda_graph( def test_medusa_e2e_greedy_correctness_cuda_graph(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int): seed: int, prefill_chunk_size: int):
"""Verify greedy equality with cuda graph enabled and different """Verify greedy equality with cuda graph enabled and different
batch sizes.""" batch sizes."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(vllm_runner,
common_llm_kwargs, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
@ -217,13 +225,15 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(
]) ])
@pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_medusa_e2e_greedy_correctness_with_preemption( def test_medusa_e2e_greedy_correctness_with_preemption(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int): seed: int, prefill_chunk_size: int):
"""Verify greedy equality, even when some sequences are preempted mid- """Verify greedy equality, even when some sequences are preempted mid-
generation. generation.
""" """
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(vllm_runner,
common_llm_kwargs, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
@ -267,13 +277,15 @@ def test_medusa_e2e_greedy_correctness_with_preemption(
32, 32,
]) ])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_medusa_different_k(vllm_runner, common_llm_kwargs, def test_medusa_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int, test_llm_kwargs, batch_size: int, output_len: int,
seed: int): seed: int, prefill_chunk_size: int):
"""Verify that medusa speculative decoding produces exact equality """Verify that medusa speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens. to without spec decode with different values of num_speculative_tokens.
""" """
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(vllm_runner,
common_llm_kwargs, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
@ -313,14 +325,17 @@ def test_medusa_different_k(vllm_runner, common_llm_kwargs,
32, 32,
]) ])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_medusa_disable_queue(vllm_runner, common_llm_kwargs, def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, test_llm_kwargs, batch_size: int,
output_len: int, seed: int): output_len: int, seed: int,
prefill_chunk_size: int):
"""Verify that medusa speculative decoding produces exact equality """Verify that medusa speculative decoding produces exact equality
to without spec decode when speculation is disabled for large to without spec decode when speculation is disabled for large
batch sizes. batch sizes.
""" """
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(vllm_runner,
common_llm_kwargs, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
@ -361,12 +376,14 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
32, 32,
]) ])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int): output_len: int, seed: int, prefill_chunk_size: int):
"""Verify that speculative decoding generates the same output """Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer. with batch expansion scorer and mqa scorer.
""" """
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(vllm_runner,
common_llm_kwargs, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,

View File

@ -25,6 +25,7 @@ import pytest
from vllm.model_executor.layers.vocab_parallel_embedding import pad_vocab_size from vllm.model_executor.layers.vocab_parallel_embedding import pad_vocab_size
from ..utils import maybe_enable_chunked_prefill
from .conftest import run_equality_correctness_test from .conftest import run_equality_correctness_test
# main model # main model
@ -66,14 +67,16 @@ PRECISION = "float32"
@pytest.mark.parametrize("output_len", [ @pytest.mark.parametrize("output_len", [
128, 128,
]) ])
@pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("batch_size", [4, 32])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, batch_size: int, output_len: int,
seed: int): seed: int, prefill_chunk_size: int):
"""Verify greedy equality with different batch size.""" """Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(vllm_runner,
common_llm_kwargs, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
@ -116,12 +119,19 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6]) @pytest.mark.parametrize("logprobs", [1, 6])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int, batch_size: int, output_len: int, seed: int,
logprobs: int): logprobs: int, prefill_chunk_size: int):
"""Verify greedy equality with different batch size.""" """Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
# NOTE: This test is sensitive enough that, if we don't enable chunked
# prefill scheduling on the baseline too, we get slightly different logprobs,
# ending up sampling different tokens at the tail (i.e. the top tokens don't
# change).
# TL;DR: sd+cp == org+cp but sd+cp != org. Is this expected?
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(vllm_runner,
common_llm_kwargs, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
@ -162,12 +172,15 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("output_len", [2048]) @pytest.mark.parametrize("output_len", [2048])
@pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs, def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int): batch_size: int, output_len: int,
prefill_chunk_size: int, seed: int):
"""Verify acceptance rate with different batch size and large output """Verify acceptance rate with different batch size and large output
length.""" length."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(vllm_runner,
common_llm_kwargs, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
@ -204,13 +217,17 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("output_len", [64]) @pytest.mark.parametrize("output_len", [64])
@pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("temperature", [1.0]) @pytest.mark.parametrize("temperature", [1.0])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs, def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, batch_size: int, output_len: int,
temperature: float, seed: int): temperature: float,
prefill_chunk_size: int, seed: int):
"""Verify seeded runs produce the same output.""" """Verify seeded runs produce the same output."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(vllm_runner,
common_llm_kwargs, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
@ -266,14 +283,16 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
128, 128,
]) ])
@pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_greedy_correctness_with_preemption( def test_mlp_e2e_greedy_correctness_with_preemption(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int): prefill_chunk_size: int, seed: int):
"""Verify greedy equality, even when some sequences are preempted mid- """Verify greedy equality, even when some sequences are preempted mid-
generation. generation.
""" """
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(vllm_runner,
common_llm_kwargs, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
@ -317,12 +336,14 @@ def test_mlp_e2e_greedy_correctness_with_preemption(
]) ])
@pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
def test_mlp_e2e_greedy_correctness_with_padding( def test_mlp_e2e_greedy_correctness_with_padding(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int): prefill_chunk_size: int, seed: int):
"""Verify greedy equality when the vocab dimension is padded """Verify greedy equality when the vocab dimension is padded
""" """
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
# Default pad_to is 64, test model has vocab_size of 32000 # Default pad_to is 64, test model has vocab_size of 32000
def patched_pad_vocab_size(vocab_size, pad_to=None): def patched_pad_vocab_size(vocab_size, pad_to=None):
@ -373,14 +394,16 @@ def test_mlp_e2e_greedy_correctness_with_padding(
# Use smaller output len for fast test. # Use smaller output len for fast test.
32, 32,
]) ])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
def test_mlp_different_k(vllm_runner, common_llm_kwargs, def test_mlp_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, seed: int, test_llm_kwargs, batch_size: int,
output_len: int): prefill_chunk_size: int, seed: int, output_len: int):
"""Verify that mlp speculative decoding produces exact equality """Verify that mlp speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens. to without spec decode with different values of num_speculative_tokens.
""" """
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(vllm_runner,
common_llm_kwargs, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
@ -418,15 +441,21 @@ def test_mlp_different_k(vllm_runner, common_llm_kwargs,
# Use smaller output len for fast test. # Use smaller output len for fast test.
32, 32,
]) ])
# Speculative decoding is disabled when sequences reach decoding and the batch
# consists of single-token requests. Hence we set `max_num_seqs`
# >= `speculative_disable_by_batch_size` to test feature interaction.
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
def test_mlp_disable_queue(vllm_runner, common_llm_kwargs, def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, seed: int, test_llm_kwargs, batch_size: int,
prefill_chunk_size: int, seed: int,
output_len: int): output_len: int):
"""Verify that mlp speculative decoding produces exact equality """Verify that mlp speculative decoding produces exact equality
to without spec decode when speculation is disabled for large to without spec decode when speculation is disabled for large
batch sizes. batch sizes.
""" """
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(vllm_runner,
common_llm_kwargs, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
@ -460,13 +489,15 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
# Use smaller output len for fast test. # Use smaller output len for fast test.
32, 32,
]) ])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int): output_len: int, prefill_chunk_size: int, seed: int):
"""Verify that speculative decoding generates the same output """Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer. with batch expansion scorer and mqa scorer.
""" """
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(vllm_runner,
common_llm_kwargs, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,

View File

@ -147,20 +147,20 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
}, },
]) ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs",
{ [{
"speculative_model": "JackFram/llama-68m", "speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
"enable_chunked_prefill": False, "enable_chunked_prefill": False,
}, "disable_logprobs_during_spec_decoding": False
{ }, {
"speculative_model": "JackFram/llama-68m", "speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5, "num_speculative_tokens": 3,
"enable_chunked_prefill": True, "enable_chunked_prefill": True,
"max_num_batched_tokens": 4, "max_num_batched_tokens": 4,
"max_num_seqs": 4, "max_num_seqs": 4,
}, "disable_logprobs_during_spec_decoding": False
]) }])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
[ [
@ -192,6 +192,9 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
batch_size, batch_size,
max_output_len=output_len, max_output_len=output_len,
seed=seed, seed=seed,
prompt_logprobs=2,
logprobs=2,
disable_logprobs=False,
temperature=0.0, temperature=0.0,
ensure_all_accepted=ensure_all_accepted) ensure_all_accepted=ensure_all_accepted)

View File

@ -26,6 +26,7 @@ for the target model outputs.
import pytest import pytest
from ..utils import maybe_enable_chunked_prefill
from .conftest import run_equality_correctness_test from .conftest import run_equality_correctness_test
@ -49,11 +50,13 @@ from .conftest import run_equality_correctness_test
"speculative_model": "[ngram]", "speculative_model": "[ngram]",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3, "ngram_prompt_lookup_max": 3,
"speculative_disable_mqa_scorer": False,
}, },
{ {
"speculative_model": "[ngram]", "speculative_model": "[ngram]",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3, "ngram_prompt_lookup_max": 3,
"speculative_disable_mqa_scorer": True,
}, },
]) ])
@pytest.mark.parametrize("output_len", [ @pytest.mark.parametrize("output_len", [
@ -68,15 +71,7 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
batch_size: int, output_len: int, batch_size: int, output_len: int,
prefill_chunk_size: int, seed: int): prefill_chunk_size: int, seed: int):
"""Verify greedy equality on a tiny model with different batch size.""" """Verify greedy equality on a tiny model with different batch size."""
if prefill_chunk_size > 0: maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)
common_llm_kwargs.update(
**{
"enable_chunked_prefill": True,
"max_num_batched_tokens": prefill_chunk_size,
"max_num_seqs": prefill_chunk_size
})
else:
common_llm_kwargs["enable_chunked_prefill"] = False
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(vllm_runner,
common_llm_kwargs, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,

View File

@ -60,6 +60,7 @@ def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
num_gpu_blocks = 2048 // block_size num_gpu_blocks = 2048 // block_size
scorer_worker = create_worker(Worker, model_name, block_size, scorer_worker = create_worker(Worker, model_name, block_size,
num_gpu_blocks, seed) num_gpu_blocks, seed)
scorer_worker.model_runner.disable_logprobs = True # accessed by mqa_scorer
scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor = True scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor = True
scorer_worker.model_runner.model.sampler.\ scorer_worker.model_runner.model.sampler.\
should_modify_greedy_probs_inplace = True should_modify_greedy_probs_inplace = True

View File

@ -754,6 +754,7 @@ def test_populate_seq_ids_with_bonus_tokens():
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
accepted_token_ids=accepted_token_ids, accepted_token_ids=accepted_token_ids,
target_logprobs=target_token_logprobs, target_logprobs=target_token_logprobs,
prompt_logprobs=None,
k=k, k=k,
stage_times=(0, 0, 0)) stage_times=(0, 0, 0))
# Verify that _seq_with_bonus_token_in_last_step contains the following: # Verify that _seq_with_bonus_token_in_last_step contains the following:

View File

@ -274,3 +274,15 @@ def create_batch(batch_size,
prompts, num_gpu_blocks, block_size, final_prompt_lens, prompts, num_gpu_blocks, block_size, final_prompt_lens,
prev_output_tokens, seq_ids) prev_output_tokens, seq_ids)
return seq_group_metadata_list, prompts, prev_output_tokens return seq_group_metadata_list, prompts, prev_output_tokens
def maybe_enable_chunked_prefill(prefill_chunk_size, llm_kwargs):
if prefill_chunk_size > 0:
llm_kwargs.update(
**{
"enable_chunked_prefill": True,
"max_num_batched_tokens": prefill_chunk_size,
"max_num_seqs": prefill_chunk_size
})
else:
llm_kwargs["enable_chunked_prefill"] = False

View File

@ -1685,7 +1685,8 @@ class SpeculativeConfig:
raise ValueError("Expect the batch size threshold of disabling " raise ValueError("Expect the batch size threshold of disabling "
"speculative decoding is > 1, but got " "speculative decoding is > 1, but got "
f"{speculative_disable_by_batch_size=}") f"{speculative_disable_by_batch_size=}")
if (enable_chunked_prefill and speculative_model == "eagle"):
raise ValueError("Chunked prefill and EAGLE are not compatible.")
# TODO: The user should be able to specify revision/max model len # TODO: The user should be able to specify revision/max model len
# for the draft model. It is not currently supported. # for the draft model. It is not currently supported.
draft_revision = None draft_revision = None
@ -1752,12 +1753,6 @@ class SpeculativeConfig:
f"num_speculative_tokens={n_predict}, but " f"num_speculative_tokens={n_predict}, but "
f"{num_speculative_tokens=} was provided.") f"{num_speculative_tokens=} was provided.")
if enable_chunked_prefill and draft_hf_config.model_type in (
"medusa", "mlp_speculator", "eagle"):
raise ValueError(
"Chunked prefill and hidden-state based draft models are "
"not compatible.")
speculative_draft_tensor_parallel_size = \ speculative_draft_tensor_parallel_size = \
SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size( SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size(
target_parallel_config, target_parallel_config,

View File

@ -1010,8 +1010,23 @@ class LLMEngine:
self.speculative_config self.speculative_config
# Organize outputs by [step][sequence group] instead of # Organize outputs by [step][sequence group] instead of
# [sequence group][step]. # [sequence group][step].
if self.scheduler_config.is_multi_step:
outputs_by_sequence_group = create_output_by_sequence_group( outputs_by_sequence_group = create_output_by_sequence_group(
outputs, num_seq_groups=len(seq_group_metadata_list)) outputs, len(seq_group_metadata_list))
elif self.speculative_config:
# Decodes are multi-step while prefills are not, outputting at
# most 1 token. Separate them so that we can trigger chunk
# processing without having to pad or copy over prompts K times
# to match the decode structure (costly with prompt_logprobs).
num_prefills = sum(sg.is_prompt
for sg in seq_group_metadata_list)
prefills, decodes = outputs[:num_prefills], outputs[
num_prefills:]
outputs_by_sequence_group = create_output_by_sequence_group(
decodes,
num_seq_groups=len(seq_group_metadata_list) - num_prefills)
outputs_by_sequence_group = [p.outputs for p in prefills
] + outputs_by_sequence_group
# We have outputs for multiple steps submitted in a single burst, # We have outputs for multiple steps submitted in a single burst,
# so invalidate is_first_step_output. # so invalidate is_first_step_output.
is_first_step_output = None is_first_step_output = None
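A toy illustration of the regrouping performed above when chunked prefill and speculative decoding are both active (the string entries are hypothetical stand-ins for SamplerOutput contents):

# Burst from the spec-decode worker: one output per prefill chunk first,
# then one output per decode step covering all decode sequences.
prefills = [["p0"], ["p1"]]                      # each prefill's .outputs
decode_steps = [["d0_s0", "d1_s0"],              # step 0 for the 2 decode seqs
                ["d0_s1", "d1_s1"]]              # step 1

# Transpose decodes from [step][sequence] to [sequence][step], mirroring
# create_output_by_sequence_group, then prepend the per-prefill outputs.
outputs_by_sequence_group = prefills + [list(s) for s in zip(*decode_steps)]
# -> [['p0'], ['p1'], ['d0_s0', 'd0_s1'], ['d1_s0', 'd1_s1']]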

View File

@ -83,13 +83,13 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
if not non_spec_indices: if not non_spec_indices:
# All sequence groups in batch have spec decoding enabled # All sequence groups in batch have spec decoding enabled
contracted = self._contract_batch_all_spec( return self._contract_batch_all_spec(
target_sampler_output=target_sampler_output, target_sampler_output=target_sampler_output,
proposals=proposals, proposals=proposals,
) )
else: else:
# Batch has a mix of spec decode enabled and disabled seq groups # Batch has a mix of spec decode enabled and disabled seq groups
contracted = self._contract_batch( return self._contract_batch(
execute_model_req.seq_group_metadata_list, execute_model_req.seq_group_metadata_list,
target_sampler_output=target_sampler_output, target_sampler_output=target_sampler_output,
proposals=proposals, proposals=proposals,
@ -99,14 +99,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
k=execute_model_req.num_lookahead_slots, k=execute_model_req.num_lookahead_slots,
) )
all_tokens, all_probs, spec_logprobs, all_hidden_states = contracted
return SpeculativeScores(
probs=all_probs,
token_ids=all_tokens,
logprobs=spec_logprobs,
hidden_states=all_hidden_states,
)
def _expand_batch( def _expand_batch(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
@ -143,13 +135,57 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
return (spec_indices, non_spec_indices, target_seq_group_metadata_list, return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
num_scoring_tokens) num_scoring_tokens)
def _contract_non_speculative(
self, scores: SpeculativeScores,
seq_group_metadata_list: List[SequenceGroupMetadata],
non_spec_indices: List[int], non_spec_outputs: SpeculativeScores,
has_prompt_log: bool) -> SpeculativeScores:
"""
Augment input `scores` with the outputs of non-speculative requests.
This includes decode requests with speculation turned off, as well
as prefill requests when `enable_chunked_prefill` is set.
For the latter, prefills are further separated into terminal and
non-terminal chunks (from which no token is sampled).
"""
if not non_spec_indices:
return scores
if has_prompt_log:
# When prompt_logprobs is enabled, prefills yield output token
# (and respective prob) in the last entry (prompt|out):
# [.|.|.|prefill0_out|.|prefill1_out|decode0_out|..].
# With chunked prefill, non-terminal chunks have -1 on each
# position: they're still picked, but they're discarded later.
seq_meta = seq_group_metadata_list
nospec_sizes = torch.tensor([
seq_meta[i].token_chunk_size if seq_meta[i].is_prompt else 1
for i in non_spec_indices
])
nospec_sampled_token_idxs = torch.cumsum(nospec_sizes, 0).add_(-1)
else:
# In this case only sampled tokens are returned, select all.
nospec_sampled_token_idxs = list(
range(len(non_spec_outputs.token_ids)))
scores.token_ids[non_spec_indices, :1] = \
non_spec_outputs.token_ids[nospec_sampled_token_idxs].unsqueeze(1)
scores.probs[non_spec_indices, :1, :] = \
non_spec_outputs.probs[nospec_sampled_token_idxs].unsqueeze(1)
scores.logprobs[non_spec_indices, :1, :] = \
non_spec_outputs.logprobs[nospec_sampled_token_idxs].unsqueeze(1)
if scores.hidden_states is not None:
assert non_spec_outputs.hidden_states is not None
scores.hidden_states[non_spec_indices, :1, :] = \
non_spec_outputs.hidden_states[nospec_sampled_token_idxs].unsqueeze(1)
return scores
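# Worked example for the `has_prompt_log` branch above, with hypothetical
# non-speculative entries: a terminal prefill chunk of 5 prompt tokens, a
# non-terminal chunk of 3 tokens, and a plain decode (size 1):
#   nospec_sizes                      -> tensor([5, 3, 1])
#   torch.cumsum(nospec_sizes, 0) - 1 -> tensor([4, 7, 8])
# i.e. the last slot of each (prompt|out) slice, which holds the sampled token
# for the terminal chunk and the decode, and the -1 placeholder for the
# non-terminal chunk.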
def _contract_batch( def _contract_batch(
self, contracted_seq_group_metadata_list: List[SequenceGroupMetadata], self,
target_sampler_output: SamplerOutput, proposals: SpeculativeProposals, contracted_seq_group_metadata_list: List[SequenceGroupMetadata],
num_scoring_tokens: int, non_spec_indices: List[int], target_sampler_output: SamplerOutput,
spec_indices: List[int], k: int proposals: SpeculativeProposals, num_scoring_tokens: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, non_spec_indices: List[int], spec_indices: List[int],
Optional[torch.Tensor]]: k: int) -> SpeculativeScores:
"""Contract the expanded batch back into its original size. """Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original This maps the scores of speculative tokens back to their original
sequences. sequences.
@ -195,23 +231,28 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
else: else:
all_hidden_states = None all_hidden_states = None
# Rule out prefills that produce no tokens. has_prompt_log = any((sg.sampling_params.prompt_logprobs
and sg.sampling_params.prompt_logprobs > 0)
for sg in contracted_seq_group_metadata_list)
# When prompt logprobs are enabled, the returned tensors grow from
# n_sampled entries (requests with do_sample=True) to n_prompt+n_prefills.
# We adjust the stride accordingly to pick out the generated tokens and
# their probs, but pass prompt_logprobs on as-is.
prompt_logprobs = None
if (not self._scorer_worker.model_runner.disable_logprobs\
and has_prompt_log):
prompt_logprobs = [
o.prompt_logprobs for o in target_sampler_output.outputs
]
elif not has_prompt_log:
# When prompt logprobs are not to be returned,
# we can ignore non-terminal chunks (no out token).
non_spec_indices = [ non_spec_indices = [
idx for idx in non_spec_indices idx for idx in non_spec_indices
if contracted_seq_group_metadata_list[idx].do_sample if contracted_seq_group_metadata_list[idx].do_sample
] ]
if len(non_spec_indices):
all_tokens[non_spec_indices, :1] = \
non_spec_target_token_ids.unsqueeze(1)
all_probs[non_spec_indices, :1, :] = \
non_spec_target_probs.unsqueeze(1)
all_logprobs[non_spec_indices, :1, :] = \
non_spec_target_logprobs.unsqueeze(1)
if all_hidden_states is not None:
assert non_spec_target_hidden_states is not None
all_hidden_states[non_spec_indices, :1, :] = \
non_spec_target_hidden_states.unsqueeze(1)
# "Contract" speculative.
if spec_indices: if spec_indices:
all_tokens[spec_indices] = target_token_ids all_tokens[spec_indices] = target_token_ids
all_probs[spec_indices] = target_probs all_probs[spec_indices] = target_probs
@ -219,14 +260,27 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
if all_hidden_states is not None: if all_hidden_states is not None:
all_hidden_states[spec_indices] = target_hidden_states all_hidden_states[spec_indices] = target_hidden_states
return all_tokens, all_probs, all_logprobs, all_hidden_states spec_scores = SpeculativeScores(probs=all_probs,
token_ids=all_tokens,
logprobs=all_logprobs,
hidden_states=all_hidden_states,
prompt_logprobs=prompt_logprobs)
non_spec_outputs = SpeculativeScores(
probs=non_spec_target_probs,
token_ids=non_spec_target_token_ids,
logprobs=non_spec_target_logprobs,
hidden_states=non_spec_target_hidden_states)
# Contract remaining nonspec entries based on non_spec_indices, if any.
return self._contract_non_speculative(
spec_scores, contracted_seq_group_metadata_list, non_spec_indices,
non_spec_outputs, has_prompt_log)
def _contract_batch_all_spec( def _contract_batch_all_spec(
self, self,
target_sampler_output: SamplerOutput, target_sampler_output: SamplerOutput,
proposals: SpeculativeProposals, proposals: SpeculativeProposals,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, ) -> SpeculativeScores:
Optional[torch.Tensor]]:
"""Contract the expanded batch back into its original size. """Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original This maps the scores of speculative tokens back to their original
sequences. sequences.
@ -250,8 +304,11 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
target_hidden_states = target_hidden_states.reshape( target_hidden_states = target_hidden_states.reshape(
*target_token_ids.shape, target_hidden_states.shape[-1]) *target_token_ids.shape, target_hidden_states.shape[-1])
return (target_token_ids, target_probs, target_logprobs, return SpeculativeScores(probs=target_probs,
target_hidden_states) token_ids=target_token_ids,
logprobs=target_logprobs,
hidden_states=target_hidden_states,
prompt_logprobs=None)
def _create_scoring_model_input( def _create_scoring_model_input(
self, self,

View File

@ -1,10 +1,10 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Set, Union from typing import List, Optional, Set, Union
import torch import torch
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest, PromptLogprobs
from vllm.worker.worker_base import WorkerBase from vllm.worker.worker_base import WorkerBase
@ -54,6 +54,10 @@ class SpeculativeScores:
# Optional last hidden states from the scoring model. # Optional last hidden states from the scoring model.
hidden_states: Optional[torch.Tensor] = None hidden_states: Optional[torch.Tensor] = None
# Scoring model may also return logprobs for prompt tokens
# for each request, when chunked prefill is enabled.
prompt_logprobs: Optional[List[PromptLogprobs]] = None
def __repr__(self): def __repr__(self):
return (f"SpeculativeScores(" return (f"SpeculativeScores("
f"probs={self.probs.shape}, " f"probs={self.probs.shape}, "

View File

@ -72,9 +72,15 @@ class MQAScorer(SpeculativeScorer):
target_token_ids = target_sampler_output.sampled_token_ids target_token_ids = target_sampler_output.sampled_token_ids
target_probs = target_sampler_output.sampled_token_probs target_probs = target_sampler_output.sampled_token_probs
target_logprobs = target_sampler_output.logprobs target_logprobs = target_sampler_output.logprobs
prompt_logprobs = None
# If all requests have the same number of query tokens, we can avoid # If all requests have the same number of query tokens, we can avoid
# the for loop to build output for better performance. # the for loop to build output for better performance.
if min(all_proposal_lengths) == k: if min(all_proposal_lengths) == k:
# Regular decodes only.
assert all(not sg.is_prompt
for sg in target_seq_group_metadata_list
if sg.is_prompt)
bs, _ = proposals.proposal_token_ids.shape bs, _ = proposals.proposal_token_ids.shape
all_tokens = target_token_ids.reshape(bs, k + 1) all_tokens = target_token_ids.reshape(bs, k + 1)
all_probs = target_probs.reshape(bs, k + 1, self._vocab_size) all_probs = target_probs.reshape(bs, k + 1, self._vocab_size)
@ -88,11 +94,47 @@ class MQAScorer(SpeculativeScorer):
all_logprobs = target_logprobs.new_full(size=all_probs.shape, all_logprobs = target_logprobs.new_full(size=all_probs.shape,
fill_value=-float("inf")) fill_value=-float("inf"))
target_token_ids = target_token_ids.flatten() target_token_ids = target_token_ids.flatten()
start_loc = 0
for i, (proposed_len, seq_meta) in enumerate( # When prompt logprobs is enabled, lens of returned tensors go from
zip(all_proposal_lengths, target_seq_group_metadata_list)): # n_sampled (requests with do_sample=True) to n_prompt+n_prefills.
# We adjust stride accordingly to get the generated tokens and
# their probs, but pass on prompt_logprobs as is, since it may be
# that n_prompts >> K.
has_prompt_log = any((sg.sampling_params.prompt_logprobs
and sg.sampling_params.prompt_logprobs > 0)
for sg in target_seq_group_metadata_list)
# TODO (NickLucche) we should surface `disable_logprobs` so as not to
# break the abstraction when reading its value.
if (not self._scorer_worker.model_runner.disable_logprobs\
and has_prompt_log):
prompt_logprobs = [
o.prompt_logprobs for o in target_sampler_output.outputs
]
# Split loop into prefill|decode for readability.
start_loc, i = 0, 0
while i < len(target_seq_group_metadata_list
) and target_seq_group_metadata_list[i].is_prompt:
seq_meta = target_seq_group_metadata_list[i]
end_loc = start_loc
if has_prompt_log:
end_loc += seq_meta.token_chunk_size
elif seq_meta.do_sample:
end_loc += 1
# Skip chunks with no output tokens. # Skip chunks with no output tokens.
if seq_meta.do_sample: if seq_meta.do_sample:
# Get sampled token (last position in chunk) and its prob.
all_tokens[i, 0] = target_token_ids[end_loc - 1]
all_probs[i, 0] = target_probs[end_loc - 1]
all_logprobs[i, 0] = target_logprobs[end_loc - 1]
i += 1
start_loc = end_loc
# Decodes.
while i < len(target_seq_group_metadata_list):
proposed_len, seq_meta = all_proposal_lengths[
i], target_seq_group_metadata_list[i]
output_len = proposed_len + 1 output_len = proposed_len + 1
end_loc = start_loc + output_len end_loc = start_loc + output_len
all_tokens[ all_tokens[
@ -101,6 +143,7 @@ class MQAScorer(SpeculativeScorer):
all_logprobs[ all_logprobs[
i, :output_len] = target_logprobs[start_loc:end_loc] i, :output_len] = target_logprobs[start_loc:end_loc]
start_loc = end_loc start_loc = end_loc
i += 1
hidden_states = None hidden_states = None
if target_sampler_output.hidden_states is not None: if target_sampler_output.hidden_states is not None:
@ -110,4 +153,5 @@ class MQAScorer(SpeculativeScorer):
return SpeculativeScores(probs=all_probs, return SpeculativeScores(probs=all_probs,
token_ids=all_tokens, token_ids=all_tokens,
logprobs=all_logprobs, logprobs=all_logprobs,
hidden_states=hidden_states) hidden_states=hidden_states,
prompt_logprobs=prompt_logprobs)
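A small self-contained trace of the prefill stride handling added above, assuming `has_prompt_log` is true so each prompt chunk occupies `token_chunk_size` rows of the flattened scorer output (chunk sizes below are illustrative):

# Hypothetical layout: prefill A (terminal, 4 prompt tokens, samples 1 token),
# prefill B (non-terminal, 3 tokens, nothing sampled); decodes follow after.
prefill_chunks = [("A", 4, True), ("B", 3, False)]  # (name, token_chunk_size, do_sample)

start_loc = 0
for name, token_chunk_size, do_sample in prefill_chunks:
    end_loc = start_loc + token_chunk_size  # walk the full chunk width
    if do_sample:
        # The sampled token and its prob sit in the last slot of the chunk.
        print(f"{name}: sampled token read from position {end_loc - 1}")
    start_loc = end_loc
# Decodes then continue from start_loc with a stride of proposed_len + 1 each.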

View File

@ -563,27 +563,26 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
(seq_id, seq_data) for sg in \ (seq_id, seq_data) for sg in \
execute_model_req.seq_group_metadata_list \ execute_model_req.seq_group_metadata_list \
for seq_id, seq_data in sg.seq_data.items() for seq_id, seq_data in sg.seq_data.items()
if sg.do_sample # ignore empty token sequences
] ]
completion_seq_group_output_list: List[ completion_seq_group_output_list: List[
CompletionSequenceGroupOutput] = [] CompletionSequenceGroupOutput] = []
output_index = 0 output_index = 0
# Make sure the non-terminal prefill chunks are still aligned with # Make sure the non-terminal prefill chunks are still aligned with
# their own empty output. # their own empty output.
for seq_group_meta in execute_model_req.seq_group_metadata_list: for idx, seq_group_meta in enumerate(
# Since we can get chunks here, we dont always have a sampled token execute_model_req.seq_group_metadata_list):
# (only on last chunk) but we still have to provide an output. needs_prompt_logprobs = seq_output_prompt_logprobs[idx]
if not seq_group_meta.do_sample: seq_id, seq_data = seq_data_entries[idx]
completion_seq_group_output_list.append(
CompletionSequenceGroupOutput(samples=[],
prompt_logprobs=None))
else:
# Sequence with output.
seq_id, seq_data = seq_data_entries[output_index]
needs_prompt_logprobs = seq_output_prompt_logprobs[
output_index]
if needs_prompt_logprobs: if needs_prompt_logprobs:
prompt_token_ids = seq_data.get_prompt_token_ids() prompt_token_ids = seq_data.get_prompt_token_ids()
# Some of these sequences may belong to non-terminal chunks,
# which may still have to report logprobs for prompts.
start = 1 if seq_data._num_computed_tokens == 0 \
else seq_data._num_computed_tokens
end = (seq_data._num_computed_tokens + \
seq_group_meta.token_chunk_size)
prompt_token_ids = prompt_token_ids[start:end]
prompt_logprobs = [ prompt_logprobs = [
create_logprobs_output( create_logprobs_output(
token_id=p_token_id, token_id=p_token_id,
@ -591,12 +590,20 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
token_id_logprob=0.0, token_id_logprob=0.0,
topk_token_ids=[], topk_token_ids=[],
topk_logprobs=[], topk_logprobs=[],
) ) for p_token_id in prompt_token_ids
# no prompt logprobs for the first token
for p_token_id in prompt_token_ids[1:]
] ]
else: else:
prompt_logprobs = None prompt_logprobs = None
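# Worked example of the slicing above (hypothetical 7-token prompt split into
# chunks of 4 + 3 tokens):
#   chunk 1: _num_computed_tokens == 0, token_chunk_size == 4
#            -> start = 1, end = 4: logprobs for prompt tokens [1:4]
#               (the very first prompt token never gets a logprob).
#   chunk 2: _num_computed_tokens == 4, token_chunk_size == 3
#            -> start = 4, end = 7: logprobs for prompt tokens [4:7].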
# Since we can get chunks here, we don't always have a sampled token
# (only on the last chunk), but we still have to provide an output.
if not seq_group_meta.do_sample:
completion_seq_group_output_list.append(
CompletionSequenceGroupOutput(
samples=[], prompt_logprobs=prompt_logprobs))
continue
# Sequence with output.
completion_seq_group_output_list.append( completion_seq_group_output_list.append(
create_sequence_group_output( create_sequence_group_output(
token_id=sampled_token_ids_list[output_index][0], token_id=sampled_token_ids_list[output_index][0],
@ -624,24 +631,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
assert len(sampler_output) == 1 assert len(sampler_output) == 1
sampler_output = sampler_output[0] sampler_output = sampler_output[0]
# Store hidden states from target model execution. # Store hidden states from target model execution, BxD.
hidden_states = sampler_output.hidden_states hidden_states = sampler_output.hidden_states
if hidden_states is not None: if hidden_states is not None:
# remove hidden_states for prompt tokens # Only decodes and prefill terminal chunks need a hidden state.
# TODO Enable `return_hidden_states`: prefill chunks hidden states seq_group_meta_with_hidden = [
# are pruned by the logits processor. Also, they should be arranged sg for sg in execute_model_req.seq_group_metadata_list
# back into full-prefill latent. Address it to enable MLPSpeculator. if sg.do_sample
if any(seq.is_prompt ]
for seq in execute_model_req.seq_group_metadata_list): if any(seq.is_prompt for seq in seq_group_meta_with_hidden):
# Drop hidden_states with no prediction (eg non-terminal chunks)
hidden_states = hidden_states[ hidden_states = hidden_states[
torch.where(sampler_output.sampled_token_ids - torch.where(sampler_output.sampled_token_ids -
VLLM_INVALID_TOKEN_ID)[0]] VLLM_INVALID_TOKEN_ID)[0]]
if self.previous_hidden_states is None: if self.previous_hidden_states is None and len(
seq_group_meta_with_hidden):
self.previous_hidden_states = HiddenStates( self.previous_hidden_states = HiddenStates(
hidden_states, execute_model_req.seq_group_metadata_list) hidden_states, seq_group_meta_with_hidden)
else: elif self.previous_hidden_states and len(
self.previous_hidden_states.update( seq_group_meta_with_hidden):
hidden_states, execute_model_req.seq_group_metadata_list) self.previous_hidden_states.update(hidden_states,
seq_group_meta_with_hidden)
if not skip_proposer: if not skip_proposer:
# We prepare the prefill hidden states here so that there no # We prepare the prefill hidden states here so that there no
@ -752,13 +762,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
] ]
if len(non_spec_indices): if len(non_spec_indices):
all_hidden_states = proposal_scores.hidden_states all_hidden_states = proposal_scores.hidden_states
# TODO fix `return_hidden_states`, same as in `_run_no_spec`
if all_hidden_states is not None: if all_hidden_states is not None:
prefill_hidden_states = all_hidden_states[non_spec_indices] prefill_hidden_states = all_hidden_states[non_spec_indices]
execute_model_req.previous_hidden_states = \ execute_model_req.previous_hidden_states = \
prepare_prefill_hidden_states(prefill_hidden_states) prepare_prefill_hidden_states(prefill_hidden_states)
# Sync proposer KV cache for prefills. # Sync proposer KV cache for prefills.
prefill_req = execute_model_req.clone(non_spec_seqs) prefill_req = execute_model_req.clone(non_spec_seqs)
# TODO avoid sampling here?
self.proposer_worker.execute_model(prefill_req) self.proposer_worker.execute_model(prefill_req)
with Timer() as verification_timer: with Timer() as verification_timer:
@ -774,6 +784,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
execute_model_req.seq_group_metadata_list, execute_model_req.seq_group_metadata_list,
accepted_token_ids, accepted_token_ids,
target_logprobs=target_logprobs, target_logprobs=target_logprobs,
prompt_logprobs=proposal_scores.prompt_logprobs
if not self._disable_logprobs else None,
k=execute_model_req.num_lookahead_slots, k=execute_model_req.num_lookahead_slots,
stage_times=stage_times) stage_times=stage_times)
@ -845,19 +857,32 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# metadata. # metadata.
accepted_token_ids[original_indices] = accepted_token_ids.clone() accepted_token_ids[original_indices] = accepted_token_ids.clone()
# B x K+1 x D
hidden_states = proposal_scores.hidden_states hidden_states = proposal_scores.hidden_states
if hidden_states is not None: if hidden_states is not None:
# Only get terminal hidden states for next step
terminal_metadata = [
sg for sg in seq_group_metadata_list if sg.do_sample
]
# Contract hidden states based on accepted tokens # Contract hidden states based on accepted tokens
hs_size = hidden_states.shape[-1] hs_size = hidden_states.shape[-1]
accepted_index = accepted_token_ids + 1 # Convert -1 to 0 accepted_index = accepted_token_ids + 1 # Convert -1 to 0
accepted_index = accepted_index.count_nonzero(dim=1).add_(-1) accepted_index = accepted_index.count_nonzero(dim=1).add_(-1) # b
index = accepted_index[:, None, None].expand(-1, 1, hs_size) # Drop non-terminal prefill chunks hidden states.
hidden_states = hidden_states[
accepted_index != VLLM_INVALID_TOKEN_ID]
accepted_index = accepted_index[
accepted_index != VLLM_INVALID_TOKEN_ID]
assert len(accepted_index) == hidden_states.shape[0] == len(
terminal_metadata)
index = accepted_index[:, None, None].expand(-1, 1,
hs_size) # b x 1 x d
second_last_token_hidden_states = hidden_states[:, -2] # b x d second_last_token_hidden_states = hidden_states[:, -2] # b x d
hidden_states = hidden_states.gather(1, index).squeeze(1) # b x d hidden_states = hidden_states.gather(1, index).squeeze(1) # b x d
# Store hidden states from target model for subsequent decode step # Store hidden states from target model for subsequent decode step
self.previous_hidden_states = HiddenStates( self.previous_hidden_states = HiddenStates(
hidden_states, seq_group_metadata_list, hidden_states, terminal_metadata,
second_last_token_hidden_states) second_last_token_hidden_states)
return accepted_token_ids, logprobs return accepted_token_ids, logprobs
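# Worked example of the hidden-state contraction above (K = 2, hypothetical
# batch of 3):
#   accepted_token_ids = [[11, 12, -1],   # decode, 2 tokens accepted
#                         [-1, -1, -1],   # non-terminal prefill chunk
#                         [21, -1, -1]]   # terminal prefill chunk, 1 token
#   accepted_index = (ids + 1).count_nonzero(dim=1) - 1  -> [1, -1, 0]
# Rows equal to VLLM_INVALID_TOKEN_ID (-1) are dropped, so only the two
# sequences that actually produced a token keep a hidden state.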
@ -866,6 +891,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1] accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1]
target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size] target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size]
prompt_logprobs: Optional[
torch.Tensor], # shape: [nprompt_tokens, vocab_size]
k: int, k: int,
stage_times: Tuple[float, float, float], stage_times: Tuple[float, float, float],
) -> List[SamplerOutput]: ) -> List[SamplerOutput]:
@ -909,15 +936,89 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# Construct the output on a per-step, per-sequence basis. # Construct the output on a per-step, per-sequence basis.
# Non-terminal prefill chunks will end up here as rows with just -1s # Non-terminal prefill chunks will end up here as rows with just -1s
# i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]] # i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]] while
# terminal chunks will only have one generated token at time 0.
sampler_output_list: List[SamplerOutput] = [] sampler_output_list: List[SamplerOutput] = []
# Prefills are not multi-step (they return at most 1 token); to avoid
# padding or repetition to fit the decode structure, we separate them.
for i, sg in enumerate(seq_group_metadata_list):
if not sg.is_prompt:
# Requests are ordered as prefills|decodes=>no more prefills.
break
num_logprobs = num_logprobs_per_seq[i]
seq_kwargs = dict(token_id=-1,
token_id_logprob_rank=0,
token_id_logprob=-float('inf'),
topk_token_ids=[-1] * num_logprobs,
topk_logprobs=[-float('inf')] * num_logprobs,
seq_id=seq_ids[i])
# Terminal chunk, has token.
if sg.do_sample:
seq_kwargs.update(
dict(
token_id=accepted_token_ids[i][0].item(),
token_id_logprob_rank=accepted_token_id_ranks_by_step[
0][i],
token_id_logprob=accepted_token_id_logprobs_by_step[0]
[i],
topk_token_ids=topk_indices_by_step[0][i]
[:num_logprobs],
# output only so step is 0
topk_logprobs=topk_logprobs_by_step[0][i]
[:num_logprobs],
))
needs_plogs = (sg.sampling_params.prompt_logprobs
and sg.sampling_params.prompt_logprobs > 0)
plogs = None
if prompt_logprobs is not None:
# Even non-terminal prompt chunks can have logprobs here.
plogs = prompt_logprobs[i]
elif needs_plogs:
# Prompt logprobs are requested but `_disable_logprobs` is set.
seq_data = next(iter(sg.seq_data.values()))
# Get only the tokens in this chunk!
prompt_token_ids = seq_data.get_prompt_token_ids()
prompt_token_ids = prompt_token_ids[
seq_data.
_num_computed_tokens:seq_data._num_computed_tokens +
sg.token_chunk_size]
is_first_chunk = seq_data._num_computed_tokens == 0
# There's no prob generated for the first token in a sequence.
if is_first_chunk:
prompt_token_ids = prompt_token_ids[1:]
plogs = [
create_logprobs_output(
token_id=p_token_id,
token_id_logprob_rank=-1,
token_id_logprob=0.0,
topk_token_ids=[],
topk_logprobs=[],
) for p_token_id in prompt_token_ids
]
seq_kwargs.update(dict(prompt_logprobs=plogs))
sampler_output_list.append(
SamplerOutput(
outputs=[create_sequence_group_output(
**seq_kwargs)])) # type: ignore
# Decodes, create one SamplerOutput per-step (at most K+1).
for step_index in range(num_steps): for step_index in range(num_steps):
if all(token_id == -1 if all(token_id == -1 for sg, token_id in zip(
for token_id in accepted_token_ids_by_step[step_index]): seq_group_metadata_list,
accepted_token_ids_by_step[step_index])
if not sg.is_prompt):
break break
step_output_token_ids: List[CompletionSequenceGroupOutput] = [] step_output_token_ids: List[CompletionSequenceGroupOutput] = []
for sequence_index in range(batch_size): for sequence_index in range(batch_size):
seq_meta = seq_group_metadata_list[sequence_index]
# Prompts already processed above.
if seq_meta.is_prompt:
continue
# Each sequence may have a different num_logprobs; retrieve it. # Each sequence may have a different num_logprobs; retrieve it.
num_logprobs = num_logprobs_per_seq[sequence_index] num_logprobs = num_logprobs_per_seq[sequence_index]
step_output_token_ids.append( step_output_token_ids.append(
@ -952,6 +1053,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# This is periodic because the rejection sampler emits metrics # This is periodic because the rejection sampler emits metrics
# periodically. # periodically.
self._maybe_log_stage_times(*stage_times) self._maybe_log_stage_times(*stage_times)
# When chunked prefill is enabled, the first `n_prefills` entries contain
# the prefill SamplerOutputs; the rest are decodes in multi-step format.
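# Example of the returned ordering with chunked prefill on: 2 prefill chunks
# followed by decodes accepted over 3 steps produce
#   [prefill0_out, prefill1_out, decode_step0, decode_step1, decode_step2]
# where each decode_step SamplerOutput covers every decode sequence at that
# step (prefill sequences are skipped in those steps).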
return sampler_output_list return sampler_output_list
def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float, def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,