mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:04:58 +08:00
[Feature] [Spec decode]: Combine chunked prefill with speculative decoding (#9291)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent
ae62fd17c0
commit
9d43afcc53
@ -5,40 +5,6 @@ from vllm import SamplingParams
|
||||
from .conftest import get_output_from_llm_generator
|
||||
|
||||
|
||||
@pytest.mark.parametrize("common_llm_kwargs", [{
|
||||
"model": "JackFram/llama-68m",
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
||||
{
|
||||
"enable_chunked_prefill": True,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_spec_decode_xfail_chunked_prefill(test_llm_generator):
|
||||
"""Verify that speculative decoding with chunked prefill fails.
|
||||
"""
|
||||
output_len = 128
|
||||
temperature = 0.0
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
]
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=output_len,
|
||||
ignore_eos=True,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError,
|
||||
match="Speculative decoding and chunked prefill"):
|
||||
get_output_from_llm_generator(test_llm_generator, prompts,
|
||||
sampling_params)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("common_llm_kwargs", [{
|
||||
"model": "meta-llama/Llama-2-7b-chat-hf",
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
|
||||
@ -62,6 +62,16 @@ from .conftest import (get_output_from_llm_generator,
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
# Chunked prefill enabled with small value
|
||||
# to make sure we get mixed batches.
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
},
|
||||
{
|
||||
# Verify the detokenizer assertions in the test work when spec
|
||||
@ -141,6 +151,14 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
@ -204,6 +222,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
@ -255,6 +281,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("max_output_len", [
|
||||
@ -300,6 +334,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1])
|
||||
@ -347,6 +389,14 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [32])
|
||||
@ -397,6 +447,14 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
@ -454,6 +512,14 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@ -503,6 +569,15 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
|
||||
# Artificially limit the draft model max model len; this forces vLLM
|
||||
# to skip speculation once the sequences grow beyond 32-k tokens.
|
||||
"speculative_max_model_len": 32,
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4,
|
||||
"speculative_max_model_len": 32,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@ -551,6 +626,15 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs,
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_disable_by_batch_size": 2,
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_disable_by_batch_size": 2,
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@ -590,10 +674,17 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs,
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": k,
|
||||
"enable_chunked_prefill": False,
|
||||
}
|
||||
# Try a range of common k, as well as large speculation.
|
||||
for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
|
||||
])
|
||||
] + [{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": k,
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4,
|
||||
} for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
@ -636,11 +727,19 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": k,
|
||||
"spec_decoding_acceptance_method": "typical_acceptance_sampler"
|
||||
"spec_decoding_acceptance_method": "typical_acceptance_sampler",
|
||||
"enable_chunked_prefill": False
|
||||
}
|
||||
# Try a range of common k.
|
||||
for k in [1, 2, 3]
|
||||
])
|
||||
] + [{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": k,
|
||||
"spec_decoding_acceptance_method": "typical_acceptance_sampler",
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
} for k in [1, 2, 3]])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
|
||||
@ -50,18 +50,33 @@ from .conftest import run_equality_correctness_test
|
||||
"num_speculative_tokens": 5,
|
||||
"ngram_prompt_lookup_max": 3,
|
||||
},
|
||||
{
|
||||
"speculative_model": "[ngram]",
|
||||
"num_speculative_tokens": 5,
|
||||
"ngram_prompt_lookup_max": 3,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
256,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
prefill_chunk_size: int, seed: int):
|
||||
"""Verify greedy equality on a tiny model with different batch size."""
|
||||
if prefill_chunk_size > 0:
|
||||
common_llm_kwargs.update(
|
||||
**{
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": prefill_chunk_size,
|
||||
"max_num_seqs": prefill_chunk_size
|
||||
})
|
||||
else:
|
||||
common_llm_kwargs["enable_chunked_prefill"] = False
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
@ -151,6 +166,16 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
"speculative_model": "[ngram]",
|
||||
"num_speculative_tokens": 5,
|
||||
"ngram_prompt_lookup_max": 3,
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_model": "[ngram]",
|
||||
"num_speculative_tokens": 5,
|
||||
"ngram_prompt_lookup_max": 3,
|
||||
"enable_chunked_prefill": True,
|
||||
"speculative_disable_mqa_scorer": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
@ -251,6 +276,15 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
|
||||
"num_speculative_tokens": 5,
|
||||
"ngram_prompt_lookup_max": 3,
|
||||
"speculative_disable_by_batch_size": 4
|
||||
}, {
|
||||
"speculative_model": "[ngram]",
|
||||
"num_speculative_tokens": 5,
|
||||
"ngram_prompt_lookup_max": 3,
|
||||
"speculative_disable_by_batch_size": 4,
|
||||
"enable_chunked_prefill": True,
|
||||
"speculative_disable_mqa_scorer": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@ -118,7 +118,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
for sg in seq_group_metadata_list:
|
||||
sg.is_prompt = False
|
||||
proposals = proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
@ -147,7 +148,7 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
|
||||
def test_ngram_algo_correctness_for_batches_match_all():
|
||||
"""Verify our ngram algo find the right candidate in the prompt
|
||||
|
||||
For the scenario find candidate in all batchs
|
||||
For the scenario find candidate in all batches
|
||||
"""
|
||||
|
||||
block_size = 32
|
||||
@ -192,6 +193,10 @@ def test_ngram_algo_correctness_for_batches_match_all():
|
||||
block_size,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
# Normally drafter is run on decode requests only; here we check the output
|
||||
# of the ngram worker as it is the sole proposer that has no forward.
|
||||
for sg in seq_group_metadata_list:
|
||||
sg.is_prompt = False
|
||||
proposals = proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
|
||||
@ -46,12 +46,14 @@ def assert_score_equal(score1: SpeculativeScores,
|
||||
@pytest.mark.parametrize('max_propose_len', [1, 3, 5])
|
||||
@pytest.mark.parametrize('mixed_propose_len', [True])
|
||||
@pytest.mark.parametrize('device', ['cuda'])
|
||||
@pytest.mark.parametrize('prefill_chunking', [False, True])
|
||||
def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
|
||||
mixed_propose_len: bool, device: str) -> None:
|
||||
mixed_propose_len: bool, device: str,
|
||||
prefill_chunking: bool) -> None:
|
||||
"""
|
||||
Compare the batch expansion scorer and mqa scorer return the same score.
|
||||
We test for both queries with the same propose length and different
|
||||
propose length.
|
||||
propose length, as well as mixed prefill-decode batches.
|
||||
"""
|
||||
seed = 0
|
||||
block_size = 32
|
||||
@ -67,16 +69,37 @@ def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
|
||||
if not mixed_propose_len:
|
||||
propose_lens = [max_propose_len] * batch_size
|
||||
else:
|
||||
non_zero_cnt = random.randint(0, batch_size)
|
||||
# There must be at least 1 decode request, otherwise
|
||||
# we have nothing to score (`_run_no_spec`).
|
||||
non_zero_cnt = random.randint(1, batch_size)
|
||||
propose_lens = [max_propose_len
|
||||
] * non_zero_cnt + [0] * (batch_size - non_zero_cnt)
|
||||
random.shuffle(propose_lens)
|
||||
|
||||
proposals = create_proposal(propose_lens, vocab_size, device)
|
||||
seq_group_metadatalist, _, _ = create_batch(batch_size,
|
||||
max_propose_len,
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=num_gpu_blocks)
|
||||
|
||||
if mixed_propose_len and prefill_chunking and (n_prefills :=
|
||||
batch_size - non_zero_cnt):
|
||||
prefill, _, _ = create_batch(n_prefills,
|
||||
None,
|
||||
prefill_chunk_size=4,
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
seq_ids=list(
|
||||
range(batch_size,
|
||||
batch_size + n_prefills)))
|
||||
# re-order to guarantee prefill|decode order
|
||||
target_group_metadatalist = [
|
||||
seq_group_metadatalist[i] for i, p in enumerate(propose_lens)
|
||||
if p > 0
|
||||
]
|
||||
seq_group_metadatalist = prefill + target_group_metadatalist
|
||||
propose_lens = [0] * n_prefills + [p for p in propose_lens if p > 0]
|
||||
|
||||
proposals = create_proposal(propose_lens, vocab_size, device)
|
||||
requests = ExecuteModelRequest(seq_group_metadatalist,
|
||||
num_lookahead_slots=max_propose_len)
|
||||
|
||||
|
||||
@ -10,6 +10,7 @@ import torch
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import ExecuteModelRequest, SequenceOutput
|
||||
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposals
|
||||
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
|
||||
SpecDecodeWorkerMetrics)
|
||||
@ -819,3 +820,84 @@ def test_handle_finished_requests():
|
||||
# and 'request-3' are removed from seq_with_bonus_token_in_last_step.
|
||||
assert worker._seq_with_bonus_token_in_last_step == \
|
||||
{4,5,10}
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [3])
|
||||
@pytest.mark.parametrize('batch_size', [2, 32])
|
||||
@pytest.mark.parametrize("batch_composition",
|
||||
["prefill_only", "decode_only", "mixed"])
|
||||
@torch.inference_mode()
|
||||
def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str):
|
||||
"""
|
||||
Verify SpecDecodeWorker calls match the expected flow.
|
||||
"""
|
||||
vocab_size = 32_000
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
worker = SpecDecodeWorker(draft_worker,
|
||||
target_worker,
|
||||
mock_spec_decode_sampler("rejection_sampler"),
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector)
|
||||
exception_secret = 'artificial stop'
|
||||
worker.scorer = mock_worker(BatchExpansionTop1Scorer)
|
||||
worker.scorer.score_proposals.side_effect = ValueError(exception_secret)
|
||||
|
||||
# Create batch with combination of terminal/non-terminal prefill chunks
|
||||
# and decodes (different seq_ids).
|
||||
decodes, _, _ = create_batch(batch_size, k)
|
||||
# Pre-chunking here, get 'batch_size' chunks.
|
||||
prefill, _, _ = create_batch(batch_size,
|
||||
k,
|
||||
prefill_chunk_size=4,
|
||||
seq_ids=list(range(batch_size,
|
||||
batch_size * 2)))
|
||||
|
||||
if batch_composition == "prefill_only":
|
||||
n_prefills = batch_size
|
||||
elif batch_composition == "decode_only":
|
||||
n_prefills = 0
|
||||
else:
|
||||
n_prefills = random.randint(1, batch_size - 1)
|
||||
n_decodes = batch_size - n_prefills
|
||||
|
||||
prefill = random.sample(prefill, n_prefills)
|
||||
decodes = random.sample(decodes, n_decodes)
|
||||
target_group_metadata_list = prefill + decodes
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=target_group_metadata_list,
|
||||
num_lookahead_slots=k)
|
||||
|
||||
target_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(1, batch_size * (k + 1)),
|
||||
dtype=torch.int64,
|
||||
device='cuda')
|
||||
target_token_probs = torch.rand(1,
|
||||
batch_size * (k + 1),
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='cuda')
|
||||
target_token_logprobs = torch.rand(1,
|
||||
batch_size * (k + 1),
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='cuda')
|
||||
target_output = create_sampler_output_list(target_token_ids,
|
||||
target_token_probs,
|
||||
target_token_logprobs)
|
||||
|
||||
target_worker.execute_model.return_value = [target_output[0]]
|
||||
|
||||
if not len(decodes):
|
||||
worker.execute_model(execute_model_req=execute_model_req)
|
||||
# no spec run (prefill only)
|
||||
draft_worker.execute_model.assert_called_once_with(execute_model_req)
|
||||
target_worker.execute_model.assert_called_once_with(execute_model_req)
|
||||
else:
|
||||
# Decode-only run OR mixed batch, scorer call fails (it's mocked)
|
||||
with pytest.raises(ValueError, match=exception_secret):
|
||||
worker.execute_model(execute_model_req=execute_model_req)
|
||||
# but first draft still counted
|
||||
assert draft_worker.get_spec_proposals.call_count == 1
|
||||
|
||||
@ -146,6 +146,41 @@ def create_seq_group_metadata_from_prompts(
|
||||
return seq_grou_metadata_list
|
||||
|
||||
|
||||
def create_chunked_seq_group_metadata_from_prompt(
|
||||
prompt: List[int],
|
||||
num_gpu_blocks: int,
|
||||
chunk_size: int,
|
||||
block_size: int,
|
||||
seq_id: Optional[int] = None) -> List[SequenceGroupMetadata]:
|
||||
|
||||
if seq_id is None:
|
||||
seq_id = 0
|
||||
|
||||
free_gpu_blocks = list(range(num_gpu_blocks))
|
||||
|
||||
block_allocations = [
|
||||
free_gpu_blocks.pop()
|
||||
for _ in range(round_up_to_next_block(len(prompt), block_size))
|
||||
]
|
||||
|
||||
seq_group_metadata_list = []
|
||||
for i, idx in enumerate(range(0, len(prompt), chunk_size)):
|
||||
chunk_ids = prompt[idx:idx + chunk_size]
|
||||
data = SequenceData.from_seqs(prompt)
|
||||
data.update_num_computed_tokens(idx)
|
||||
seq_data = {i: data}
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
request_id=str(seq_id),
|
||||
is_prompt=True,
|
||||
do_sample=idx + chunk_size >= len(prompt), # terminal chunk
|
||||
seq_data=seq_data,
|
||||
sampling_params=SamplingParams(temperature=0.0),
|
||||
block_tables={i: block_allocations},
|
||||
token_chunk_size=len(chunk_ids)))
|
||||
return seq_group_metadata_list
|
||||
|
||||
|
||||
def assert_logprobs_dict_allclose(
|
||||
actual_logprobs: List[Dict[int, Logprob]],
|
||||
expected_logprobs: List[Dict[int, Logprob]]) -> None:
|
||||
@ -198,7 +233,8 @@ def create_batch(batch_size,
|
||||
prev_output_token_len: int = 10,
|
||||
seq_ids: Optional[List[int]] = None,
|
||||
num_gpu_blocks: Optional[int] = None,
|
||||
block_size: Optional[int] = None):
|
||||
block_size: Optional[int] = None,
|
||||
prefill_chunk_size: Optional[int] = None):
|
||||
if block_size is None:
|
||||
block_size = 8
|
||||
|
||||
@ -213,15 +249,28 @@ def create_batch(batch_size,
|
||||
prompt_lens = prompt_len
|
||||
|
||||
prompts = [[next(iterator) for _ in range(p_len)] for p_len in prompt_lens]
|
||||
prev_output_tokens = [[
|
||||
next(iterator) for _ in range(prev_output_token_len)
|
||||
] for _ in range(batch_size)]
|
||||
final_prompt_lens = [
|
||||
len(prompt) + len(prev_output_token) + k + 1
|
||||
for prompt, prev_output_token in zip(prompts, prev_output_tokens)
|
||||
]
|
||||
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts, num_gpu_blocks, block_size, final_prompt_lens,
|
||||
prev_output_tokens, seq_ids)
|
||||
if prefill_chunk_size:
|
||||
# Create a batch of chunked prompts.
|
||||
if not seq_ids:
|
||||
seq_ids = list(range(len(prompts)))
|
||||
seq_group_metadata_list = []
|
||||
for p, sid in zip(prompts, seq_ids):
|
||||
seq_group_metadata_list += \
|
||||
create_chunked_seq_group_metadata_from_prompt(
|
||||
p, num_gpu_blocks, prefill_chunk_size, block_size, sid)
|
||||
seq_group_metadata_list = seq_group_metadata_list[:batch_size]
|
||||
prev_output_tokens = []
|
||||
else:
|
||||
prev_output_tokens = [[
|
||||
next(iterator) for _ in range(prev_output_token_len)
|
||||
] for _ in range(batch_size)]
|
||||
final_prompt_lens = [
|
||||
len(prompt) + len(prev_output_token) + k + 1
|
||||
for prompt, prev_output_token in zip(prompts, prev_output_tokens)
|
||||
]
|
||||
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts, num_gpu_blocks, block_size, final_prompt_lens,
|
||||
prev_output_tokens, seq_ids)
|
||||
return seq_group_metadata_list, prompts, prev_output_tokens
|
||||
|
||||
@ -276,7 +276,11 @@ class FlashAttentionMetadata(AttentionMetadata):
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.max_decode_seq_len,
|
||||
query_start_loc=self.query_start_loc[self.num_prefills:]
|
||||
# Batch may be composed of prefill|decodes, adjust query start
|
||||
# indices to refer to the start of decodes. E.g.
|
||||
# in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
|
||||
query_start_loc=(self.query_start_loc[self.num_prefills:] -
|
||||
self.query_start_loc[self.num_prefills])
|
||||
if self.query_start_loc is not None else None,
|
||||
seq_start_loc=self.seq_start_loc[self.num_prefills:]
|
||||
if self.seq_start_loc is not None else None,
|
||||
@ -903,7 +907,9 @@ def unified_flash_attention(
|
||||
# Decoding run.
|
||||
# Use flash_attn_varlen_func kernel for speculative decoding
|
||||
# because different queries might have different lengths.
|
||||
|
||||
assert decode_meta.max_decode_query_len is not None
|
||||
# use only for actual varlen decoding
|
||||
if decode_meta.max_decode_query_len > 1:
|
||||
assert attn_type == AttentionType.DECODER, (
|
||||
"Only decoder-only models support max_decode_query_len > 1")
|
||||
@ -949,8 +955,6 @@ def unified_flash_attention(
|
||||
assert prefill_output is not None
|
||||
return prefill_output.view(num_prefill_query_tokens, hidden_size)
|
||||
|
||||
# Chunked prefill does not work with speculative decoding.
|
||||
# Therefore, the query length for decode should be 1 in chunked prefill.
|
||||
assert decode_meta is not None
|
||||
decode_output = decode_output.squeeze(1)
|
||||
output = torch.cat([prefill_output, decode_output], dim=0)
|
||||
|
||||
@ -192,6 +192,12 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
block_tables=self.block_tables[self.num_prefills:],
|
||||
use_cuda_graph=self.use_cuda_graph,
|
||||
)
|
||||
# Batch may be composed of prefill|decodes, adjust query start indices
|
||||
# to refer to the start of decodes when the two are split apart.
|
||||
# E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
|
||||
if self._cached_decode_metadata.query_start_loc is not None:
|
||||
qs = self._cached_decode_metadata.query_start_loc
|
||||
self._cached_decode_metadata.query_start_loc = qs - qs[0]
|
||||
return self._cached_decode_metadata
|
||||
|
||||
def advance_step(self,
|
||||
|
||||
@ -272,6 +272,13 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
max_encoder_seq_len=self.max_encoder_seq_len,
|
||||
cross_slot_mapping=self.cross_slot_mapping,
|
||||
cross_block_tables=self.cross_block_tables)
|
||||
|
||||
# Batch may be composed of prefill|decodes, adjust query start indices
|
||||
# to refer to the start of decodes when the two are split apart.
|
||||
# E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
|
||||
if self._cached_decode_metadata.query_start_loc is not None:
|
||||
qs = self._cached_decode_metadata.query_start_loc
|
||||
self._cached_decode_metadata.query_start_loc = qs - qs[0]
|
||||
return self._cached_decode_metadata
|
||||
|
||||
|
||||
|
||||
@ -192,7 +192,6 @@ class ModelConfig:
|
||||
self.max_logprobs = max_logprobs
|
||||
self.disable_sliding_window = disable_sliding_window
|
||||
self.skip_tokenizer_init = skip_tokenizer_init
|
||||
|
||||
self.hf_config = get_config(self.model, trust_remote_code, revision,
|
||||
code_revision, rope_scaling, rope_theta,
|
||||
config_format)
|
||||
@ -1317,13 +1316,6 @@ class SpeculativeConfig:
|
||||
"speculative decoding is > 1, but got "
|
||||
f"{speculative_disable_by_batch_size=}")
|
||||
|
||||
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
|
||||
# If the feature combo become valid
|
||||
if enable_chunked_prefill:
|
||||
raise ValueError(
|
||||
"Speculative decoding and chunked prefill are "
|
||||
f"currently mutually exclusive ({enable_chunked_prefill=}).")
|
||||
|
||||
# TODO: The user should be able to specify revision/max model len
|
||||
# for the draft model. It is not currently supported.
|
||||
draft_revision = None
|
||||
@ -1390,6 +1382,12 @@ class SpeculativeConfig:
|
||||
f"num_speculative_tokens={n_predict}, but "
|
||||
f"{num_speculative_tokens=} was provided.")
|
||||
|
||||
if enable_chunked_prefill and draft_hf_config.model_type in (
|
||||
"medusa", "mlp_speculator", "eagle"):
|
||||
raise ValueError(
|
||||
"Chunked prefill and hidden-state based draft models are "
|
||||
"not compatible.")
|
||||
|
||||
draft_model_config.max_model_len = (
|
||||
SpeculativeConfig._maybe_override_draft_max_model_len(
|
||||
speculative_max_model_len,
|
||||
|
||||
@ -1147,6 +1147,7 @@ class Scheduler:
|
||||
|
||||
# Update swapped requests.
|
||||
self.swapped.extend(running_scheduled.swapped_out)
|
||||
# Put prefills first due to Attention backend ordering assumption.
|
||||
return SchedulerOutputs(
|
||||
scheduled_seq_groups=(prefills.seq_groups +
|
||||
running_scheduled.prefill_seq_groups +
|
||||
|
||||
@ -134,10 +134,12 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
|
||||
sample for sample in samples
|
||||
if sample.output_token != VLLM_INVALID_TOKEN_ID
|
||||
]
|
||||
assert valid_samples
|
||||
|
||||
self._process_seq_outputs(seq, valid_samples,
|
||||
sequence_group.sampling_params)
|
||||
# When both spec-decode and pre-fill chunking are enabled, we
|
||||
# don't have guaranteed samples here (e.g. all -1s).
|
||||
if valid_samples:
|
||||
self._process_seq_outputs(seq, valid_samples,
|
||||
sequence_group.sampling_params)
|
||||
|
||||
def _process_decode_and_stop(self, seq: Sequence,
|
||||
sampling_params: SamplingParams) -> None:
|
||||
|
||||
@ -90,7 +90,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
else:
|
||||
# Batch has a mix of spec decode enabled and disabled seq groups
|
||||
contracted = self._contract_batch(
|
||||
contracted_bs=len(execute_model_req.seq_group_metadata_list),
|
||||
execute_model_req.seq_group_metadata_list,
|
||||
target_sampler_output=target_sampler_output,
|
||||
proposals=proposals,
|
||||
num_scoring_tokens=num_scoring_tokens,
|
||||
@ -126,7 +126,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
split_batch_by_proposal_len(
|
||||
seq_group_metadata_list, proposal_lens_list)
|
||||
|
||||
target_seq_group_metadata_list = self._create_scoring_model_input(
|
||||
spec_expanded_seqs = self._create_scoring_model_input(
|
||||
seq_group_metadata_list=spec_seqs,
|
||||
proposal_token_ids=proposal_token_ids_list,
|
||||
# NOTE: We determine the seq ids in the expanded batch using the
|
||||
@ -135,16 +135,19 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
seq_ids=get_all_seq_ids(seq_group_metadata_list)),
|
||||
)
|
||||
|
||||
num_scoring_tokens = len(target_seq_group_metadata_list)
|
||||
target_seq_group_metadata_list.extend(non_spec_seqs)
|
||||
num_scoring_tokens = len(spec_expanded_seqs)
|
||||
# Batch speculative and non-speculative (e.g. chunked prefill) requests
|
||||
# but make sure order is prefill|decode due to backend requirement.
|
||||
target_seq_group_metadata_list = non_spec_seqs + spec_expanded_seqs
|
||||
|
||||
return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
|
||||
num_scoring_tokens)
|
||||
|
||||
def _contract_batch(
|
||||
self, contracted_bs: int, target_sampler_output: SamplerOutput,
|
||||
proposals: SpeculativeProposals, num_scoring_tokens: int,
|
||||
non_spec_indices: List[int], spec_indices: List[int], k: int
|
||||
self, contracted_seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
target_sampler_output: SamplerOutput, proposals: SpeculativeProposals,
|
||||
num_scoring_tokens: int, non_spec_indices: List[int],
|
||||
spec_indices: List[int], k: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
|
||||
Optional[torch.Tensor]]:
|
||||
"""Contract the expanded batch back into its original size.
|
||||
@ -154,6 +157,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
contracted_bs is the original batch size, and the batch size that the
|
||||
target_sampler_output will be contracted to.
|
||||
"""
|
||||
contracted_bs = len(contracted_seq_group_metadata_list)
|
||||
(target_token_ids, target_probs, target_logprobs, target_hidden_states,
|
||||
non_spec_target_token_ids, non_spec_target_probs,
|
||||
non_spec_target_logprobs,
|
||||
@ -166,8 +170,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
|
||||
# The number of tokens in the expanded batch used for speculation is
|
||||
# equal to the total expanded batch size minus the number of samples for
|
||||
# non-speculative sequences.
|
||||
non_spec_expanded_bs = len(non_spec_target_token_ids)
|
||||
# non-speculative sequences, prefill chunks with no out tokens included
|
||||
non_spec_expanded_bs = len(non_spec_indices)
|
||||
spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs
|
||||
|
||||
target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
|
||||
@ -191,7 +195,12 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
else:
|
||||
all_hidden_states = None
|
||||
|
||||
if non_spec_indices:
|
||||
# Rule out prefills that produce no tokens.
|
||||
non_spec_indices = [
|
||||
idx for idx in non_spec_indices
|
||||
if contracted_seq_group_metadata_list[idx].do_sample
|
||||
]
|
||||
if len(non_spec_indices):
|
||||
all_tokens[non_spec_indices, :1] = \
|
||||
non_spec_target_token_ids.unsqueeze(1)
|
||||
all_probs[non_spec_indices, :1, :] = \
|
||||
@ -290,9 +299,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
This function creates K+1 target SequenceGroupMetadata to take
|
||||
advantage of the bonus token.
|
||||
"""
|
||||
assert not input_seq_group_metadata.is_prompt, (
|
||||
"Speculating on "
|
||||
"prompts not yet supported")
|
||||
assert len(input_seq_group_metadata.seq_data) == 1, (
|
||||
"Beam search "
|
||||
"not supported in speculative decoding")
|
||||
@ -390,27 +396,22 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
# and non spec sequences) and should be removed in the future. It can be
|
||||
# done by supporting per-sequence proposal lens.
|
||||
#
|
||||
# First samples are from speculative scoring, latter samples are non-
|
||||
# speculative samples.
|
||||
split_sizes = (num_scoring_tokens,
|
||||
sampler_output.sampled_token_ids.numel() -
|
||||
num_scoring_tokens)
|
||||
(spec_probs, non_spec_probs
|
||||
) = sampler_output.sampled_token_probs.split(split_sizes)
|
||||
(spec_sampled_tokens, non_spec_sampled_tokens
|
||||
# First samples are non-speculative, latter samples are from speculative
|
||||
# scoring (prefill|decode order).
|
||||
split_sizes = (sampler_output.sampled_token_ids.numel() -
|
||||
num_scoring_tokens, num_scoring_tokens)
|
||||
(non_spec_probs,
|
||||
spec_probs) = sampler_output.sampled_token_probs.split(split_sizes)
|
||||
(non_spec_sampled_tokens, spec_sampled_tokens
|
||||
) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
|
||||
(
|
||||
spec_logprobs,
|
||||
non_spec_logprobs,
|
||||
) = sampler_output.logprobs.split(split_sizes)
|
||||
(non_spec_logprobs,
|
||||
spec_logprobs) = sampler_output.logprobs.split(split_sizes)
|
||||
|
||||
if sampler_output.hidden_states is not None:
|
||||
(
|
||||
spec_hidden_states,
|
||||
non_spec_hidden_states,
|
||||
) = sampler_output.hidden_states.split(split_sizes)
|
||||
(non_spec_hidden_states, spec_hidden_states
|
||||
) = sampler_output.hidden_states.split(split_sizes)
|
||||
else:
|
||||
spec_hidden_states, non_spec_hidden_states = None, None
|
||||
non_spec_hidden_states, spec_hidden_states = None, None
|
||||
|
||||
return (spec_sampled_tokens, spec_probs, spec_logprobs,
|
||||
spec_hidden_states, non_spec_sampled_tokens, non_spec_probs,
|
||||
|
||||
@ -21,6 +21,11 @@ class MQAScorer(SpeculativeScorer):
|
||||
all_proposal_lengths = proposals.proposal_lens.tolist()
|
||||
for i, seq_group_metadata in enumerate(
|
||||
execute_model_req.seq_group_metadata_list):
|
||||
if all_proposal_lengths[i] == 0:
|
||||
# Keep prompt seqs untouched (keep computed_tokens for chunks).
|
||||
target_seq_group_metadata_list.append(seq_group_metadata)
|
||||
continue
|
||||
|
||||
seq_data_dict = seq_group_metadata.seq_data
|
||||
assert len(seq_data_dict) == 1
|
||||
seq_id = next(iter(seq_data_dict.keys()))
|
||||
@ -40,8 +45,7 @@ class MQAScorer(SpeculativeScorer):
|
||||
new_seq_data.update_num_computed_tokens(
|
||||
len(prompt_token_ids) + len(output_token_ids) - 1)
|
||||
|
||||
# Ensure that the new sequence has at least one token
|
||||
# because we only use mqa scorer in the decoding stage.
|
||||
# Ensure that the new decode sequence has at least one token.
|
||||
assert len(output_token_ids) >= 1
|
||||
new_seq_data_dict = {target_seq_id: new_seq_data}
|
||||
|
||||
@ -54,7 +58,6 @@ class MQAScorer(SpeculativeScorer):
|
||||
target_seq_id: seq_group_metadata.block_tables[seq_id],
|
||||
},
|
||||
lora_request=None,
|
||||
token_chunk_size=1,
|
||||
)
|
||||
target_seq_group_metadata_list.append(new_seq_group_metadata)
|
||||
|
||||
@ -77,6 +80,7 @@ class MQAScorer(SpeculativeScorer):
|
||||
all_probs = target_probs.reshape(bs, k + 1, self._vocab_size)
|
||||
all_logprobs = target_logprobs.reshape(bs, k + 1, self._vocab_size)
|
||||
else:
|
||||
# We either have decodes with different lens or prefill+decodes.
|
||||
all_tokens = target_token_ids.new_full(size=(bs, k + 1),
|
||||
fill_value=-1)
|
||||
all_probs = target_probs.new_zeros(*all_tokens.shape,
|
||||
@ -85,15 +89,18 @@ class MQAScorer(SpeculativeScorer):
|
||||
fill_value=-float("inf"))
|
||||
target_token_ids = target_token_ids.flatten()
|
||||
start_loc = 0
|
||||
for i, proposed_len in enumerate(all_proposal_lengths):
|
||||
output_len = proposed_len + 1
|
||||
end_loc = start_loc + output_len
|
||||
all_tokens[
|
||||
i, :output_len] = target_token_ids[start_loc:end_loc]
|
||||
all_probs[i, :output_len] = target_probs[start_loc:end_loc]
|
||||
all_logprobs[
|
||||
i, :output_len] = target_logprobs[start_loc:end_loc]
|
||||
start_loc = end_loc
|
||||
for i, (proposed_len, seq_meta) in enumerate(
|
||||
zip(all_proposal_lengths, target_seq_group_metadata_list)):
|
||||
# Skip chunks with no output tokens.
|
||||
if seq_meta.do_sample:
|
||||
output_len = proposed_len + 1
|
||||
end_loc = start_loc + output_len
|
||||
all_tokens[
|
||||
i, :output_len] = target_token_ids[start_loc:end_loc]
|
||||
all_probs[i, :output_len] = target_probs[start_loc:end_loc]
|
||||
all_logprobs[
|
||||
i, :output_len] = target_logprobs[start_loc:end_loc]
|
||||
start_loc = end_loc
|
||||
|
||||
hidden_states = None
|
||||
if target_sampler_output.hidden_states is not None:
|
||||
|
||||
@ -418,7 +418,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
# none of the requests in the batch have spec decoding enabled.
|
||||
# In any of these cases, the proposer and scorer workers
|
||||
# are called normally.
|
||||
no_spec = num_lookahead_slots == 0 or disable_all_speculation or all(
|
||||
# We expect `num_speculative_tokens` to be None for prefills.
|
||||
no_spec = all(
|
||||
sgm.is_prompt for sgm in execute_model_req.seq_group_metadata_list
|
||||
) or num_lookahead_slots == 0 or disable_all_speculation or all(
|
||||
sgm.num_speculative_tokens == 0
|
||||
for sgm in execute_model_req.seq_group_metadata_list)
|
||||
|
||||
@ -484,7 +487,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
|
||||
def _serialize_sampler_output_no_logprobs(
|
||||
self, execute_model_req: ExecuteModelRequest,
|
||||
sampler_output: SamplerOutput) -> SamplerOutput:
|
||||
sampler_output: SamplerOutput) -> List[SamplerOutput]:
|
||||
"""
|
||||
Creates and returns a `SamplerOutput` with only the token IDs being
|
||||
serialized to CPU and populated in `CompletionSequenceGroupOutput`.
|
||||
@ -514,41 +517,56 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
if any(seq_output_prompt_logprobs) else \
|
||||
sampler_output.sampled_token_ids).tolist()
|
||||
|
||||
seq_data_entries = (
|
||||
seq_data_entries = [
|
||||
(seq_id, seq_data) for sg in \
|
||||
execute_model_req.seq_group_metadata_list \
|
||||
for seq_id, seq_data in sg.seq_data.items()
|
||||
)
|
||||
if sg.do_sample # ignore empty token sequences
|
||||
]
|
||||
completion_seq_group_output_list: List[
|
||||
CompletionSequenceGroupOutput] = []
|
||||
for index, ((seq_id, seq_data), needs_prompt_logprobs) in \
|
||||
enumerate(zip(seq_data_entries, seq_output_prompt_logprobs)):
|
||||
if needs_prompt_logprobs:
|
||||
prompt_token_ids = seq_data.get_prompt_token_ids()
|
||||
prompt_logprobs = [
|
||||
create_logprobs_output(
|
||||
token_id=p_token_id,
|
||||
output_index = 0
|
||||
# Make sure the non-terminal prefill chunks are still aligned with
|
||||
# their own empty output.
|
||||
for seq_group_meta in execute_model_req.seq_group_metadata_list:
|
||||
# Since we can get chunks here, we dont always have a sampled token
|
||||
# (only on last chunk) but we still have to provide an output.
|
||||
if not seq_group_meta.do_sample:
|
||||
completion_seq_group_output_list.append(
|
||||
CompletionSequenceGroupOutput(samples=[],
|
||||
prompt_logprobs=None))
|
||||
else:
|
||||
# Sequence with output.
|
||||
seq_id, seq_data = seq_data_entries[output_index]
|
||||
needs_prompt_logprobs = seq_output_prompt_logprobs[
|
||||
output_index]
|
||||
if needs_prompt_logprobs:
|
||||
prompt_token_ids = seq_data.get_prompt_token_ids()
|
||||
prompt_logprobs = [
|
||||
create_logprobs_output(
|
||||
token_id=p_token_id,
|
||||
token_id_logprob_rank=-1,
|
||||
token_id_logprob=0.0,
|
||||
topk_token_ids=[],
|
||||
topk_logprobs=[],
|
||||
)
|
||||
# no prompt logprobs for the first token
|
||||
for p_token_id in prompt_token_ids[1:]
|
||||
]
|
||||
else:
|
||||
prompt_logprobs = None
|
||||
completion_seq_group_output_list.append(
|
||||
create_sequence_group_output(
|
||||
token_id=sampled_token_ids_list[output_index][0],
|
||||
token_id_logprob_rank=-1,
|
||||
token_id_logprob=0.0,
|
||||
seq_id=seq_id,
|
||||
topk_token_ids=[],
|
||||
topk_logprobs=[],
|
||||
)
|
||||
# no prompt logprobs for the first token
|
||||
for p_token_id in prompt_token_ids[1:]
|
||||
]
|
||||
else:
|
||||
prompt_logprobs = None
|
||||
prompt_logprobs=prompt_logprobs))
|
||||
output_index += 1
|
||||
|
||||
completion_seq_group_output_list.append(
|
||||
create_sequence_group_output(
|
||||
token_id=sampled_token_ids_list[index][0],
|
||||
token_id_logprob_rank=-1,
|
||||
token_id_logprob=0.0,
|
||||
seq_id=seq_id,
|
||||
topk_token_ids=[],
|
||||
topk_logprobs=[],
|
||||
prompt_logprobs=prompt_logprobs))
|
||||
return SamplerOutput(outputs=completion_seq_group_output_list)
|
||||
return [SamplerOutput(outputs=completion_seq_group_output_list)]
|
||||
|
||||
@nvtx_range("spec_decode_worker._run_no_spec")
|
||||
def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
|
||||
@ -568,6 +586,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
hidden_states = sampler_output.hidden_states
|
||||
if hidden_states is not None:
|
||||
# remove hidden_states for prompt tokens
|
||||
# TODO Enable `return_hidden_states`: prefill chunks hidden states
|
||||
# are pruned by the logits processor. Also, they should be arranged
|
||||
# back into full-prefill latent. Address it to enable MLPSpeculator.
|
||||
if any(seq.is_prompt
|
||||
for seq in execute_model_req.seq_group_metadata_list):
|
||||
hidden_states = hidden_states[
|
||||
@ -593,14 +614,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
|
||||
execute_model_req=execute_model_req, sampler_output=sampler_output)
|
||||
if self._disable_logprobs else
|
||||
sampler_output)
|
||||
[sampler_output])
|
||||
|
||||
# Clear device tensors from sampler output. This reduces communication
|
||||
# overhead when the engine runs in a different process than the workers.
|
||||
sampler_output.sampled_token_probs = None
|
||||
sampler_output.sampled_token_ids = None
|
||||
sampler_output.logprobs = None
|
||||
return [sampler_output_to_return]
|
||||
return sampler_output_to_return
|
||||
|
||||
def _run_non_driver_rank(self) -> bool:
|
||||
"""Run proposer and verifier model in non-driver workers. This is used
|
||||
@ -644,9 +665,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
This invokes the proposer worker to get k speculative tokens for each
|
||||
sequence, then scores each speculative token using the scoring worker.
|
||||
|
||||
When `enable_chunked_prefill` is set, scorer will batch decodes and
|
||||
prefills, while proposer will sync its KV-cache by running an extra
|
||||
forward on prefills.
|
||||
|
||||
Returns a list of SamplerOutput, each containing a single token per
|
||||
sequence.
|
||||
"""
|
||||
# With prefill chunking, expect requests to have prompts first
|
||||
# so that backend gets prefill|decode.
|
||||
assert num_lookahead_slots == execute_model_req.num_lookahead_slots
|
||||
|
||||
# Pass last hidden states from target model to proposer
|
||||
@ -671,6 +698,25 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
proposals,
|
||||
)
|
||||
|
||||
_, (non_spec_seqs, non_spec_indices) = split_batch_by_proposal_len(
|
||||
execute_model_req.seq_group_metadata_list, proposals.proposal_lens)
|
||||
# With prefill chunking enabled, `non_spec_seqs` contains prefills too:
|
||||
# discard decodes that have already been processed by proposer.
|
||||
non_spec_indices = [
|
||||
idx for idx in non_spec_indices
|
||||
if execute_model_req.seq_group_metadata_list[idx].is_prompt
|
||||
]
|
||||
if len(non_spec_indices):
|
||||
all_hidden_states = proposal_scores.hidden_states
|
||||
# TODO fix `return_hidden_states`, same as in `_run_no_spec`
|
||||
if all_hidden_states is not None:
|
||||
prefill_hidden_states = all_hidden_states[non_spec_indices]
|
||||
execute_model_req.previous_hidden_states = \
|
||||
prepare_prefill_hidden_states(prefill_hidden_states)
|
||||
# Sync proposer KV cache for prefills.
|
||||
prefill_req = execute_model_req.clone(non_spec_seqs)
|
||||
self.proposer_worker.execute_model(prefill_req)
|
||||
|
||||
with Timer() as verification_timer:
|
||||
accepted_token_ids, target_logprobs = self._verify_tokens(
|
||||
execute_model_req.seq_group_metadata_list, proposal_scores,
|
||||
@ -769,7 +815,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
self.previous_hidden_states = HiddenStates(
|
||||
hidden_states, seq_group_metadata_list,
|
||||
second_last_token_hidden_states)
|
||||
|
||||
return accepted_token_ids, logprobs
|
||||
|
||||
def _create_output_sampler_list(
|
||||
@ -819,6 +864,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
|
||||
|
||||
# Construct the output on a per-step, per-sequence basis.
|
||||
# Non-terminal prefill chunks will end up here as rows with just -1s
|
||||
# i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]]
|
||||
sampler_output_list: List[SamplerOutput] = []
|
||||
for step_index in range(num_steps):
|
||||
if all(token_id == -1
|
||||
@ -861,7 +908,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
# This is periodic because the rejection sampler emits metrics
|
||||
# periodically.
|
||||
self._maybe_log_stage_times(*stage_times)
|
||||
|
||||
return sampler_output_list
|
||||
|
||||
def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
|
||||
|
||||
@ -109,7 +109,6 @@ class Top1Proposer(SpeculativeProposer):
|
||||
proposal_probs=proposal_probs,
|
||||
proposal_lens=proposal_lens,
|
||||
no_proposals=maybe_sampler_output is None)
|
||||
|
||||
return proposals
|
||||
|
||||
def _split_by_proposal_len(
|
||||
@ -127,9 +126,10 @@ class Top1Proposer(SpeculativeProposer):
|
||||
nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
|
||||
nonzero_proposal_len_indices: List[int] = []
|
||||
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
|
||||
# The speculative decoding for this request has been disabled
|
||||
# (e.g. due to high traffic).
|
||||
if seq_group_metadata.num_speculative_tokens == 0:
|
||||
# The speculative decoding for this request has either been disabled
|
||||
# (e.g. due to high traffic) or this is a prompt request.
|
||||
if (seq_group_metadata.is_prompt
|
||||
or seq_group_metadata.num_speculative_tokens == 0):
|
||||
proposal_lens.append(0)
|
||||
continue
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user