[Feature] [Spec decode]: Combine chunked prefill with speculative decoding (#9291)

Signed-off-by: NickLucche <nlucches@redhat.com>
Nicolò Lucchesi 2024-11-07 17:15:14 +01:00 committed by GitHub
parent ae62fd17c0
commit 9d43afcc53
GPG Key ID: B5690EEEBB952194
17 changed files with 476 additions and 146 deletions
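Before the file-by-file diff, a quick illustration of what this change enables. The sketch below mirrors the engine arguments used by the new tests (chunked prefill plus a draft model); it is an illustrative assumption-based example, not documented usage, and the exact constructor signature may differ across vLLM versions.

from vllm import LLM, SamplingParams

# Both features enabled together, mirroring the test kwargs in this commit.
# Model names and values are taken from the tests; tune them for real use.
llm = LLM(
    model="JackFram/llama-68m",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=5,
    enable_chunked_prefill=True,
    max_num_batched_tokens=4,   # tiny chunk size, only to force mixed batches
    max_num_seqs=4,
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))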

View File

@@ -5,40 +5,6 @@ from vllm import SamplingParams
from .conftest import get_output_from_llm_generator
@pytest.mark.parametrize("common_llm_kwargs", [{
"model": "JackFram/llama-68m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"enable_chunked_prefill": True,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_chunked_prefill(test_llm_generator):
"""Verify that speculative decoding with chunked prefill fails.
"""
output_len = 128
temperature = 0.0
prompts = [
"Hello, my name is",
]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
with pytest.raises(ValueError,
match="Speculative decoding and chunked prefill"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
@pytest.mark.parametrize("common_llm_kwargs", [{
"model": "meta-llama/Llama-2-7b-chat-hf",
"speculative_model": "JackFram/llama-68m",

View File

@@ -62,6 +62,16 @@ from .conftest import (get_output_from_llm_generator,
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
# Chunked prefill enabled with small value
# to make sure we get mixed batches.
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
{
# Verify the detokenizer assertions in the test work when spec
@@ -141,6 +151,14 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
},
])
@pytest.mark.parametrize(
@@ -204,6 +222,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize(
@@ -255,6 +281,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize("max_output_len", [
@@ -300,6 +334,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize("batch_size", [1])
@@ -347,6 +389,14 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize("batch_size", [32])
@@ -397,6 +447,14 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize(
@@ -454,6 +512,14 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize("batch_size", [2])
@@ -503,6 +569,15 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"speculative_max_model_len": 32,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
"speculative_max_model_len": 32,
},
])
@pytest.mark.parametrize("batch_size", [8])
@@ -551,6 +626,15 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs,
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_disable_by_batch_size": 2,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_disable_by_batch_size": 2,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
},
])
@pytest.mark.parametrize("batch_size", [8])
@@ -590,10 +674,17 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs,
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"enable_chunked_prefill": False,
}
# Try a range of common k, as well as large speculation.
for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
] + [{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
} for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
@@ -636,11 +727,19 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"spec_decoding_acceptance_method": "typical_acceptance_sampler",
"enable_chunked_prefill": False
}
# Try a range of common k.
for k in [1, 2, 3]
] + [{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"spec_decoding_acceptance_method": "typical_acceptance_sampler",
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
} for k in [1, 2, 3]])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize(
"output_len",

View File

@@ -50,18 +50,33 @@ from .conftest import run_equality_correctness_test
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
},
])
@pytest.mark.parametrize("output_len", [
256,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
prefill_chunk_size: int, seed: int):
"""Verify greedy equality on a tiny model with different batch size."""
if prefill_chunk_size > 0:
common_llm_kwargs.update(
**{
"enable_chunked_prefill": True,
"max_num_batched_tokens": prefill_chunk_size,
"max_num_seqs": prefill_chunk_size
})
else:
common_llm_kwargs["enable_chunked_prefill"] = False
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
@@ -151,6 +166,16 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"enable_chunked_prefill": False,
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"enable_chunked_prefill": True,
"speculative_disable_mqa_scorer": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize(
@@ -251,6 +276,15 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_disable_by_batch_size": 4
}, {
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_disable_by_batch_size": 4,
"enable_chunked_prefill": True,
"speculative_disable_mqa_scorer": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(

View File

@@ -118,7 +118,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)
for sg in seq_group_metadata_list:
sg.is_prompt = False
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
@ -147,7 +148,7 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
def test_ngram_algo_correctness_for_batches_match_all():
"""Verify our ngram algo find the right candidate in the prompt
For the scenario find candidate in all batches
"""
block_size = 32
@ -192,6 +193,10 @@ def test_ngram_algo_correctness_for_batches_match_all():
block_size,
final_prompt_lens=final_prompt_lens)
# Normally drafter is run on decode requests only; here we check the output
# of the ngram worker as it is the sole proposer that has no forward.
for sg in seq_group_metadata_list:
sg.is_prompt = False
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,

View File

@@ -46,12 +46,14 @@ def assert_score_equal(score1: SpeculativeScores,
@pytest.mark.parametrize('max_propose_len', [1, 3, 5])
@pytest.mark.parametrize('mixed_propose_len', [True])
@pytest.mark.parametrize('device', ['cuda'])
@pytest.mark.parametrize('prefill_chunking', [False, True])
def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
mixed_propose_len: bool, device: str,
prefill_chunking: bool) -> None:
"""
Compare the batch expansion scorer and mqa scorer return the same score.
We test for both queries with the same propose length and different
propose length, as well as mixed prefill-decode batches.
"""
seed = 0
block_size = 32
@@ -67,16 +69,37 @@ def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
if not mixed_propose_len:
propose_lens = [max_propose_len] * batch_size
else:
# There must be at least 1 decode request, otherwise
# we have nothing to score (`_run_no_spec`).
non_zero_cnt = random.randint(1, batch_size)
propose_lens = [max_propose_len
] * non_zero_cnt + [0] * (batch_size - non_zero_cnt)
random.shuffle(propose_lens)
proposals = create_proposal(propose_lens, vocab_size, device)
seq_group_metadatalist, _, _ = create_batch(batch_size,
max_propose_len,
block_size=block_size,
num_gpu_blocks=num_gpu_blocks)
if mixed_propose_len and prefill_chunking and (n_prefills :=
batch_size - non_zero_cnt):
prefill, _, _ = create_batch(n_prefills,
None,
prefill_chunk_size=4,
block_size=block_size,
num_gpu_blocks=num_gpu_blocks,
seq_ids=list(
range(batch_size,
batch_size + n_prefills)))
# re-order to guarantee prefill|decode order
target_group_metadatalist = [
seq_group_metadatalist[i] for i, p in enumerate(propose_lens)
if p > 0
]
seq_group_metadatalist = prefill + target_group_metadatalist
propose_lens = [0] * n_prefills + [p for p in propose_lens if p > 0]
proposals = create_proposal(propose_lens, vocab_size, device)
requests = ExecuteModelRequest(seq_group_metadatalist,
num_lookahead_slots=max_propose_len)

View File

@@ -10,6 +10,7 @@ import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, SequenceOutput
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
SpecDecodeWorkerMetrics)
@@ -819,3 +820,84 @@ def test_handle_finished_requests():
# and 'request-3' are removed from seq_with_bonus_token_in_last_step.
assert worker._seq_with_bonus_token_in_last_step == \
{4,5,10}
@pytest.mark.parametrize('k', [3])
@pytest.mark.parametrize('batch_size', [2, 32])
@pytest.mark.parametrize("batch_composition",
["prefill_only", "decode_only", "mixed"])
@torch.inference_mode()
def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str):
"""
Verify SpecDecodeWorker calls match the expected flow.
"""
vocab_size = 32_000
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(draft_worker,
target_worker,
mock_spec_decode_sampler("rejection_sampler"),
disable_logprobs=False,
metrics_collector=metrics_collector)
exception_secret = 'artificial stop'
worker.scorer = mock_worker(BatchExpansionTop1Scorer)
worker.scorer.score_proposals.side_effect = ValueError(exception_secret)
# Create batch with combination of terminal/non-terminal prefill chunks
# and decodes (different seq_ids).
decodes, _, _ = create_batch(batch_size, k)
# Pre-chunking here, get 'batch_size' chunks.
prefill, _, _ = create_batch(batch_size,
k,
prefill_chunk_size=4,
seq_ids=list(range(batch_size,
batch_size * 2)))
if batch_composition == "prefill_only":
n_prefills = batch_size
elif batch_composition == "decode_only":
n_prefills = 0
else:
n_prefills = random.randint(1, batch_size - 1)
n_decodes = batch_size - n_prefills
prefill = random.sample(prefill, n_prefills)
decodes = random.sample(decodes, n_decodes)
target_group_metadata_list = prefill + decodes
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=target_group_metadata_list,
num_lookahead_slots=k)
target_token_ids = torch.randint(low=0,
high=vocab_size,
size=(1, batch_size * (k + 1)),
dtype=torch.int64,
device='cuda')
target_token_probs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='cuda')
target_token_logprobs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='cuda')
target_output = create_sampler_output_list(target_token_ids,
target_token_probs,
target_token_logprobs)
target_worker.execute_model.return_value = [target_output[0]]
if not len(decodes):
worker.execute_model(execute_model_req=execute_model_req)
# no spec run (prefill only)
draft_worker.execute_model.assert_called_once_with(execute_model_req)
target_worker.execute_model.assert_called_once_with(execute_model_req)
else:
# Decode-only run OR mixed batch, scorer call fails (it's mocked)
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=execute_model_req)
# but first draft still counted
assert draft_worker.get_spec_proposals.call_count == 1

View File

@@ -146,6 +146,41 @@ def create_seq_group_metadata_from_prompts(
return seq_grou_metadata_list
def create_chunked_seq_group_metadata_from_prompt(
prompt: List[int],
num_gpu_blocks: int,
chunk_size: int,
block_size: int,
seq_id: Optional[int] = None) -> List[SequenceGroupMetadata]:
if seq_id is None:
seq_id = 0
free_gpu_blocks = list(range(num_gpu_blocks))
block_allocations = [
free_gpu_blocks.pop()
for _ in range(round_up_to_next_block(len(prompt), block_size))
]
seq_group_metadata_list = []
for i, idx in enumerate(range(0, len(prompt), chunk_size)):
chunk_ids = prompt[idx:idx + chunk_size]
data = SequenceData.from_seqs(prompt)
data.update_num_computed_tokens(idx)
seq_data = {i: data}
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=str(seq_id),
is_prompt=True,
do_sample=idx + chunk_size >= len(prompt), # terminal chunk
seq_data=seq_data,
sampling_params=SamplingParams(temperature=0.0),
block_tables={i: block_allocations},
token_chunk_size=len(chunk_ids)))
return seq_group_metadata_list
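For orientation, the helper above splits a single prompt into prefill-chunk-sized SequenceGroupMetadata entries, with do_sample set only on the terminal chunk. A hypothetical usage, assuming it is called from within these test utilities:

# A 10-token prompt with chunk_size=4 yields chunks of 4, 4 and 2 tokens.
chunks = create_chunked_seq_group_metadata_from_prompt(
    prompt=list(range(10)), num_gpu_blocks=128, chunk_size=4, block_size=16)
assert [c.token_chunk_size for c in chunks] == [4, 4, 2]
assert [c.do_sample for c in chunks] == [False, False, True]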
def assert_logprobs_dict_allclose(
actual_logprobs: List[Dict[int, Logprob]],
expected_logprobs: List[Dict[int, Logprob]]) -> None:
@@ -198,7 +233,8 @@ def create_batch(batch_size,
prev_output_token_len: int = 10,
seq_ids: Optional[List[int]] = None,
num_gpu_blocks: Optional[int] = None,
block_size: Optional[int] = None,
prefill_chunk_size: Optional[int] = None):
if block_size is None:
block_size = 8
@ -213,15 +249,28 @@ def create_batch(batch_size,
prompt_lens = prompt_len
prompts = [[next(iterator) for _ in range(p_len)] for p_len in prompt_lens]
prev_output_tokens = [[
next(iterator) for _ in range(prev_output_token_len)
] for _ in range(batch_size)]
final_prompt_lens = [
len(prompt) + len(prev_output_token) + k + 1
for prompt, prev_output_token in zip(prompts, prev_output_tokens)
]
if prefill_chunk_size:
# Create a batch of chunked prompts.
if not seq_ids:
seq_ids = list(range(len(prompts)))
seq_group_metadata_list = []
for p, sid in zip(prompts, seq_ids):
seq_group_metadata_list += \
create_chunked_seq_group_metadata_from_prompt(
p, num_gpu_blocks, prefill_chunk_size, block_size, sid)
seq_group_metadata_list = seq_group_metadata_list[:batch_size]
prev_output_tokens = []
else:
prev_output_tokens = [[
next(iterator) for _ in range(prev_output_token_len)
] for _ in range(batch_size)]
final_prompt_lens = [
len(prompt) + len(prev_output_token) + k + 1
for prompt, prev_output_token in zip(prompts, prev_output_tokens)
]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size, final_prompt_lens,
prev_output_tokens, seq_ids)
return seq_group_metadata_list, prompts, prev_output_tokens

View File

@@ -276,7 +276,11 @@ class FlashAttentionMetadata(AttentionMetadata):
max_query_len=self.max_query_len,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
# Batch may be composed of prefill|decodes, adjust query start
# indices to refer to the start of decodes. E.g.
# in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
query_start_loc=(self.query_start_loc[self.num_prefills:] -
self.query_start_loc[self.num_prefills])
if self.query_start_loc is not None else None,
seq_start_loc=self.seq_start_loc[self.num_prefills:]
if self.seq_start_loc is not None else None,
@ -903,7 +907,9 @@ def unified_flash_attention(
# Decoding run.
# Use flash_attn_varlen_func kernel for speculative decoding
# because different queries might have different lengths.
assert decode_meta.max_decode_query_len is not None
# use only for actual varlen decoding
if decode_meta.max_decode_query_len > 1:
assert attn_type == AttentionType.DECODER, (
"Only decoder-only models support max_decode_query_len > 1")
@ -949,8 +955,6 @@ def unified_flash_attention(
assert prefill_output is not None
return prefill_output.view(num_prefill_query_tokens, hidden_size)
# Chunked prefill does not work with speculative decoding.
# Therefore, the query length for decode should be 1 in chunked prefill.
assert decode_meta is not None
decode_output = decode_output.squeeze(1)
output = torch.cat([prefill_output, decode_output], dim=0)
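As an aside, the query_start_loc adjustment introduced in this file can be reproduced in isolation. This is only an illustrative sketch of the re-basing arithmetic, not code from the diff:

import torch

# Mirrors the comment in the diff: tokens = [3 prefill | 6 decode],
# cumulative query_start_loc over the whole batch, one prefill sequence.
query_start_loc = torch.tensor([0, 3, 9])
num_prefills = 1

qs = query_start_loc[num_prefills:]  # tensor([3, 9])
decode_query_start_loc = qs - qs[0]  # tensor([0, 6]); indices now start at the decode run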

View File

@@ -192,6 +192,12 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph,
)
# Batch may be composed of prefill|decodes, adjust query start indices
# to refer to the start of decodes when the two are split apart.
# E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
if self._cached_decode_metadata.query_start_loc is not None:
qs = self._cached_decode_metadata.query_start_loc
self._cached_decode_metadata.query_start_loc = qs - qs[0]
return self._cached_decode_metadata
def advance_step(self,

View File

@@ -272,6 +272,13 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables)
# Batch may be composed of prefill|decodes, adjust query start indices
# to refer to the start of decodes when the two are split apart.
# E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
if self._cached_decode_metadata.query_start_loc is not None:
qs = self._cached_decode_metadata.query_start_loc
self._cached_decode_metadata.query_start_loc = qs - qs[0]
return self._cached_decode_metadata

View File

@@ -192,7 +192,6 @@ class ModelConfig:
self.max_logprobs = max_logprobs
self.disable_sliding_window = disable_sliding_window
self.skip_tokenizer_init = skip_tokenizer_init
self.hf_config = get_config(self.model, trust_remote_code, revision,
code_revision, rope_scaling, rope_theta,
config_format)
@ -1317,13 +1316,6 @@ class SpeculativeConfig:
"speculative decoding is > 1, but got " "speculative decoding is > 1, but got "
f"{speculative_disable_by_batch_size=}") f"{speculative_disable_by_batch_size=}")
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# If the feature combo become valid
if enable_chunked_prefill:
raise ValueError(
"Speculative decoding and chunked prefill are "
f"currently mutually exclusive ({enable_chunked_prefill=}).")
# TODO: The user should be able to specify revision/max model len
# for the draft model. It is not currently supported.
draft_revision = None
@ -1390,6 +1382,12 @@ class SpeculativeConfig:
f"num_speculative_tokens={n_predict}, but " f"num_speculative_tokens={n_predict}, but "
f"{num_speculative_tokens=} was provided.") f"{num_speculative_tokens=} was provided.")
if enable_chunked_prefill and draft_hf_config.model_type in (
"medusa", "mlp_speculator", "eagle"):
raise ValueError(
"Chunked prefill and hidden-state based draft models are "
"not compatible.")
draft_model_config.max_model_len = (
SpeculativeConfig._maybe_override_draft_max_model_len(
speculative_max_model_len,

View File

@ -1147,6 +1147,7 @@ class Scheduler:
# Update swapped requests.
self.swapped.extend(running_scheduled.swapped_out)
# Put prefills first due to Attention backend ordering assumption.
return SchedulerOutputs(
scheduled_seq_groups=(prefills.seq_groups +
running_scheduled.prefill_seq_groups +

View File

@@ -134,10 +134,12 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
sample for sample in samples
if sample.output_token != VLLM_INVALID_TOKEN_ID
]
assert valid_samples
# When both spec-decode and pre-fill chunking are enabled, we
# don't have guaranteed samples here (e.g. all -1s).
if valid_samples:
self._process_seq_outputs(seq, valid_samples,
sequence_group.sampling_params)
def _process_decode_and_stop(self, seq: Sequence,
sampling_params: SamplingParams) -> None:

View File

@@ -90,7 +90,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
else:
# Batch has a mix of spec decode enabled and disabled seq groups
contracted = self._contract_batch(
execute_model_req.seq_group_metadata_list,
target_sampler_output=target_sampler_output,
proposals=proposals,
num_scoring_tokens=num_scoring_tokens,
@ -126,7 +126,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
split_batch_by_proposal_len(
seq_group_metadata_list, proposal_lens_list)
spec_expanded_seqs = self._create_scoring_model_input(
seq_group_metadata_list=spec_seqs,
proposal_token_ids=proposal_token_ids_list,
# NOTE: We determine the seq ids in the expanded batch using the # NOTE: We determine the seq ids in the expanded batch using the
@ -135,16 +135,19 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
seq_ids=get_all_seq_ids(seq_group_metadata_list)),
)
num_scoring_tokens = len(spec_expanded_seqs)
# Batch speculative and non-speculative (e.g. chunked prefill) requests
# but make sure order is prefill|decode due to backend requirement.
target_seq_group_metadata_list = non_spec_seqs + spec_expanded_seqs
return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
num_scoring_tokens)
def _contract_batch(
self, contracted_seq_group_metadata_list: List[SequenceGroupMetadata],
target_sampler_output: SamplerOutput, proposals: SpeculativeProposals,
num_scoring_tokens: int, non_spec_indices: List[int],
spec_indices: List[int], k: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor]]:
"""Contract the expanded batch back into its original size.
@ -154,6 +157,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
contracted_bs is the original batch size, and the batch size that the
target_sampler_output will be contracted to.
"""
contracted_bs = len(contracted_seq_group_metadata_list)
(target_token_ids, target_probs, target_logprobs, target_hidden_states,
non_spec_target_token_ids, non_spec_target_probs,
non_spec_target_logprobs,
@ -166,8 +170,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
# The number of tokens in the expanded batch used for speculation is
# equal to the total expanded batch size minus the number of samples for
# non-speculative sequences, prefill chunks with no out tokens included
non_spec_expanded_bs = len(non_spec_indices)
spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs
target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
@ -191,7 +195,12 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
else:
all_hidden_states = None
# Rule out prefills that produce no tokens.
non_spec_indices = [
idx for idx in non_spec_indices
if contracted_seq_group_metadata_list[idx].do_sample
]
if len(non_spec_indices):
all_tokens[non_spec_indices, :1] = \
non_spec_target_token_ids.unsqueeze(1)
all_probs[non_spec_indices, :1, :] = \
@ -290,9 +299,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
This function creates K+1 target SequenceGroupMetadata to take
advantage of the bonus token.
"""
assert not input_seq_group_metadata.is_prompt, (
"Speculating on "
"prompts not yet supported")
assert len(input_seq_group_metadata.seq_data) == 1, (
"Beam search "
"not supported in speculative decoding")
@ -390,27 +396,22 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
# and non spec sequences) and should be removed in the future. It can be
# done by supporting per-sequence proposal lens.
#
# First samples are non-speculative, latter samples are from speculative
# scoring (prefill|decode order).
split_sizes = (sampler_output.sampled_token_ids.numel() -
num_scoring_tokens, num_scoring_tokens)
(non_spec_probs,
spec_probs) = sampler_output.sampled_token_probs.split(split_sizes)
(non_spec_sampled_tokens, spec_sampled_tokens
) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
(non_spec_logprobs,
spec_logprobs) = sampler_output.logprobs.split(split_sizes)
if sampler_output.hidden_states is not None:
(non_spec_hidden_states, spec_hidden_states
) = sampler_output.hidden_states.split(split_sizes)
else:
non_spec_hidden_states, spec_hidden_states = None, None
return (spec_sampled_tokens, spec_probs, spec_logprobs,
spec_hidden_states, non_spec_sampled_tokens, non_spec_probs,

View File

@@ -21,6 +21,11 @@ class MQAScorer(SpeculativeScorer):
all_proposal_lengths = proposals.proposal_lens.tolist()
for i, seq_group_metadata in enumerate(
execute_model_req.seq_group_metadata_list):
if all_proposal_lengths[i] == 0:
# Keep prompt seqs untouched (keep computed_tokens for chunks).
target_seq_group_metadata_list.append(seq_group_metadata)
continue
seq_data_dict = seq_group_metadata.seq_data
assert len(seq_data_dict) == 1
seq_id = next(iter(seq_data_dict.keys()))
@ -40,8 +45,7 @@ class MQAScorer(SpeculativeScorer):
new_seq_data.update_num_computed_tokens(
len(prompt_token_ids) + len(output_token_ids) - 1)
# Ensure that the new decode sequence has at least one token.
# because we only use mqa scorer in the decoding stage.
assert len(output_token_ids) >= 1
new_seq_data_dict = {target_seq_id: new_seq_data}
@ -54,7 +58,6 @@ class MQAScorer(SpeculativeScorer):
target_seq_id: seq_group_metadata.block_tables[seq_id],
},
lora_request=None,
token_chunk_size=1,
)
target_seq_group_metadata_list.append(new_seq_group_metadata)
@ -77,6 +80,7 @@ class MQAScorer(SpeculativeScorer):
all_probs = target_probs.reshape(bs, k + 1, self._vocab_size)
all_logprobs = target_logprobs.reshape(bs, k + 1, self._vocab_size)
else:
# We either have decodes with different lens or prefill+decodes.
all_tokens = target_token_ids.new_full(size=(bs, k + 1),
fill_value=-1)
all_probs = target_probs.new_zeros(*all_tokens.shape,
@ -85,15 +89,18 @@ class MQAScorer(SpeculativeScorer):
fill_value=-float("inf")) fill_value=-float("inf"))
target_token_ids = target_token_ids.flatten()
start_loc = 0
for i, (proposed_len, seq_meta) in enumerate(
zip(all_proposal_lengths, target_seq_group_metadata_list)):
# Skip chunks with no output tokens.
if seq_meta.do_sample:
output_len = proposed_len + 1
end_loc = start_loc + output_len
all_tokens[
i, :output_len] = target_token_ids[start_loc:end_loc]
all_probs[i, :output_len] = target_probs[start_loc:end_loc]
all_logprobs[
i, :output_len] = target_logprobs[start_loc:end_loc]
start_loc = end_loc
hidden_states = None
if target_sampler_output.hidden_states is not None:

View File

@@ -418,7 +418,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# none of the requests in the batch have spec decoding enabled.
# In any of these cases, the proposer and scorer workers
# are called normally.
# We expect `num_speculative_tokens` to be None for prefills.
no_spec = all(
sgm.is_prompt for sgm in execute_model_req.seq_group_metadata_list
) or num_lookahead_slots == 0 or disable_all_speculation or all(
sgm.num_speculative_tokens == 0
for sgm in execute_model_req.seq_group_metadata_list)
@ -484,7 +487,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
def _serialize_sampler_output_no_logprobs(
self, execute_model_req: ExecuteModelRequest,
sampler_output: SamplerOutput) -> List[SamplerOutput]:
"""
Creates and returns a `SamplerOutput` with only the token IDs being
serialized to CPU and populated in `CompletionSequenceGroupOutput`.
@ -514,41 +517,56 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
if any(seq_output_prompt_logprobs) else \
sampler_output.sampled_token_ids).tolist()
seq_data_entries = [
(seq_id, seq_data) for sg in \
execute_model_req.seq_group_metadata_list \
for seq_id, seq_data in sg.seq_data.items()
if sg.do_sample # ignore empty token sequences
]
completion_seq_group_output_list: List[
CompletionSequenceGroupOutput] = []
output_index = 0
# Make sure the non-terminal prefill chunks are still aligned with
# their own empty output.
for seq_group_meta in execute_model_req.seq_group_metadata_list:
# Since we can get chunks here, we dont always have a sampled token
# (only on last chunk) but we still have to provide an output.
if not seq_group_meta.do_sample:
completion_seq_group_output_list.append(
CompletionSequenceGroupOutput(samples=[],
prompt_logprobs=None))
else:
# Sequence with output.
seq_id, seq_data = seq_data_entries[output_index]
needs_prompt_logprobs = seq_output_prompt_logprobs[
output_index]
if needs_prompt_logprobs:
prompt_token_ids = seq_data.get_prompt_token_ids()
prompt_logprobs = [
create_logprobs_output(
token_id=p_token_id,
token_id_logprob_rank=-1,
token_id_logprob=0.0,
topk_token_ids=[],
topk_logprobs=[],
)
# no prompt logprobs for the first token
for p_token_id in prompt_token_ids[1:]
]
else:
prompt_logprobs = None
completion_seq_group_output_list.append(
create_sequence_group_output(
token_id=sampled_token_ids_list[output_index][0],
token_id_logprob_rank=-1,
token_id_logprob=0.0,
seq_id=seq_id,
topk_token_ids=[],
topk_logprobs=[],
prompt_logprobs=prompt_logprobs))
output_index += 1
return [SamplerOutput(outputs=completion_seq_group_output_list)]
@nvtx_range("spec_decode_worker._run_no_spec") @nvtx_range("spec_decode_worker._run_no_spec")
def _run_no_spec(self, execute_model_req: ExecuteModelRequest, def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
@ -568,6 +586,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
hidden_states = sampler_output.hidden_states
if hidden_states is not None:
# remove hidden_states for prompt tokens
# TODO Enable `return_hidden_states`: prefill chunks hidden states
# are pruned by the logits processor. Also, they should be arranged
# back into full-prefill latent. Address it to enable MLPSpeculator.
if any(seq.is_prompt
for seq in execute_model_req.seq_group_metadata_list):
hidden_states = hidden_states[
@ -593,14 +614,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
execute_model_req=execute_model_req, sampler_output=sampler_output)
if self._disable_logprobs else
[sampler_output])
# Clear device tensors from sampler output. This reduces communication
# overhead when the engine runs in a different process than the workers.
sampler_output.sampled_token_probs = None
sampler_output.sampled_token_ids = None
sampler_output.logprobs = None
return sampler_output_to_return
def _run_non_driver_rank(self) -> bool:
"""Run proposer and verifier model in non-driver workers. This is used
@ -644,9 +665,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
This invokes the proposer worker to get k speculative tokens for each
sequence, then scores each speculative token using the scoring worker.
When `enable_chunked_prefill` is set, scorer will batch decodes and
prefills, while proposer will sync its KV-cache by running an extra
forward on prefills.
Returns a list of SamplerOutput, each containing a single token per
sequence.
"""
# With prefill chunking, expect requests to have prompts first
# so that backend gets prefill|decode.
assert num_lookahead_slots == execute_model_req.num_lookahead_slots
# Pass last hidden states from target model to proposer
@ -671,6 +698,25 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposals,
)
_, (non_spec_seqs, non_spec_indices) = split_batch_by_proposal_len(
execute_model_req.seq_group_metadata_list, proposals.proposal_lens)
# With prefill chunking enabled, `non_spec_seqs` contains prefills too:
# discard decodes that have already been processed by proposer.
non_spec_indices = [
idx for idx in non_spec_indices
if execute_model_req.seq_group_metadata_list[idx].is_prompt
]
if len(non_spec_indices):
all_hidden_states = proposal_scores.hidden_states
# TODO fix `return_hidden_states`, same as in `_run_no_spec`
if all_hidden_states is not None:
prefill_hidden_states = all_hidden_states[non_spec_indices]
execute_model_req.previous_hidden_states = \
prepare_prefill_hidden_states(prefill_hidden_states)
# Sync proposer KV cache for prefills.
prefill_req = execute_model_req.clone(non_spec_seqs)
self.proposer_worker.execute_model(prefill_req)
with Timer() as verification_timer:
accepted_token_ids, target_logprobs = self._verify_tokens(
execute_model_req.seq_group_metadata_list, proposal_scores,
@ -769,7 +815,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.previous_hidden_states = HiddenStates(
hidden_states, seq_group_metadata_list,
second_last_token_hidden_states)
return accepted_token_ids, logprobs return accepted_token_ids, logprobs
def _create_output_sampler_list(
@ -819,6 +864,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
# Construct the output on a per-step, per-sequence basis.
# Non-terminal prefill chunks will end up here as rows with just -1s
# i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]]
sampler_output_list: List[SamplerOutput] = []
for step_index in range(num_steps):
if all(token_id == -1
@ -861,7 +908,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# This is periodic because the rejection sampler emits metrics
# periodically.
self._maybe_log_stage_times(*stage_times)
return sampler_output_list return sampler_output_list
def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,

View File

@@ -109,7 +109,6 @@ class Top1Proposer(SpeculativeProposer):
proposal_probs=proposal_probs,
proposal_lens=proposal_lens,
no_proposals=maybe_sampler_output is None)
return proposals return proposals
def _split_by_proposal_len(
@ -127,9 +126,10 @@ class Top1Proposer(SpeculativeProposer):
nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
nonzero_proposal_len_indices: List[int] = []
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
# The speculative decoding for this request has either been disabled
# (e.g. due to high traffic) or this is a prompt request.
if (seq_group_metadata.is_prompt
or seq_group_metadata.num_speculative_tokens == 0):
proposal_lens.append(0)
continue