mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 07:24:56 +08:00)
[Feature] [Spec decode]: Combine chunked prefill with speculative decoding (#9291)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent ae62fd17c0
commit 9d43afcc53
@@ -5,40 +5,6 @@ from vllm import SamplingParams
 from .conftest import get_output_from_llm_generator
 
 
-@pytest.mark.parametrize("common_llm_kwargs", [{
-    "model": "JackFram/llama-68m",
-    "speculative_model": "JackFram/llama-68m",
-    "num_speculative_tokens": 5,
-}])
-@pytest.mark.parametrize("per_test_common_llm_kwargs", [
-    {
-        "enable_chunked_prefill": True,
-    },
-])
-@pytest.mark.parametrize("test_llm_kwargs", [{}])
-@pytest.mark.parametrize("seed", [1])
-def test_spec_decode_xfail_chunked_prefill(test_llm_generator):
-    """Verify that speculative decoding with chunked prefill fails.
-    """
-    output_len = 128
-    temperature = 0.0
-
-    prompts = [
-        "Hello, my name is",
-    ]
-
-    sampling_params = SamplingParams(
-        max_tokens=output_len,
-        ignore_eos=True,
-        temperature=temperature,
-    )
-
-    with pytest.raises(ValueError,
-                       match="Speculative decoding and chunked prefill"):
-        get_output_from_llm_generator(test_llm_generator, prompts,
-                                      sampling_params)
-
-
 @pytest.mark.parametrize("common_llm_kwargs", [{
     "model": "meta-llama/Llama-2-7b-chat-hf",
     "speculative_model": "JackFram/llama-68m",
@@ -62,6 +62,16 @@ from .conftest import (get_output_from_llm_generator,
     {
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": 5,
+        "enable_chunked_prefill": False,
+    },
+    {
+        # Chunked prefill enabled with small value
+        # to make sure we get mixed batches.
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 5,
+        "enable_chunked_prefill": True,
+        "max_num_batched_tokens": 4,
+        "max_num_seqs": 4
     },
     {
         # Verify the detokenizer assertions in the test work when spec
@@ -141,6 +151,14 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
     {
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": 5,
+        "enable_chunked_prefill": False,
+    },
+    {
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 5,
+        "enable_chunked_prefill": True,
+        "max_num_batched_tokens": 4,
+        "max_num_seqs": 4,
     },
 ])
 @pytest.mark.parametrize(
@@ -204,6 +222,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
     {
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": 5,
+        "enable_chunked_prefill": False,
+    },
+    {
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 5,
+        "enable_chunked_prefill": True,
+        "max_num_batched_tokens": 4,
+        "max_num_seqs": 4
     },
 ])
 @pytest.mark.parametrize(
@@ -255,6 +281,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
     {
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": 5,
+        "enable_chunked_prefill": False,
+    },
+    {
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 5,
+        "enable_chunked_prefill": True,
+        "max_num_batched_tokens": 4,
+        "max_num_seqs": 4
     },
 ])
 @pytest.mark.parametrize("max_output_len", [
@@ -300,6 +334,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
     {
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": 5,
+        "enable_chunked_prefill": False,
+    },
+    {
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 5,
+        "enable_chunked_prefill": True,
+        "max_num_batched_tokens": 4,
+        "max_num_seqs": 4
     },
 ])
 @pytest.mark.parametrize("batch_size", [1])
@@ -347,6 +389,14 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
     {
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": 5,
+        "enable_chunked_prefill": False,
+    },
+    {
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 5,
+        "enable_chunked_prefill": True,
+        "max_num_batched_tokens": 4,
+        "max_num_seqs": 4
     },
 ])
 @pytest.mark.parametrize("batch_size", [32])
@@ -397,6 +447,14 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
     {
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": 5,
+        "enable_chunked_prefill": False,
+    },
+    {
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 5,
+        "enable_chunked_prefill": True,
+        "max_num_batched_tokens": 4,
+        "max_num_seqs": 4
     },
 ])
 @pytest.mark.parametrize(
@@ -454,6 +512,14 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
     {
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": 5,
+        "enable_chunked_prefill": False,
+    },
+    {
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 5,
+        "enable_chunked_prefill": True,
+        "max_num_batched_tokens": 4,
+        "max_num_seqs": 4
     },
 ])
 @pytest.mark.parametrize("batch_size", [2])
@@ -503,6 +569,15 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
         # Artificially limit the draft model max model len; this forces vLLM
         # to skip speculation once the sequences grow beyond 32-k tokens.
         "speculative_max_model_len": 32,
+        "enable_chunked_prefill": False,
+    },
+    {
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 5,
+        "enable_chunked_prefill": True,
+        "max_num_batched_tokens": 4,
+        "max_num_seqs": 4,
+        "speculative_max_model_len": 32,
     },
 ])
 @pytest.mark.parametrize("batch_size", [8])
@@ -551,6 +626,15 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs,
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": 5,
         "speculative_disable_by_batch_size": 2,
+        "enable_chunked_prefill": False,
+    },
+    {
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 5,
+        "speculative_disable_by_batch_size": 2,
+        "enable_chunked_prefill": True,
+        "max_num_batched_tokens": 4,
+        "max_num_seqs": 4,
     },
 ])
 @pytest.mark.parametrize("batch_size", [8])
@@ -590,10 +674,17 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs,
     {
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": k,
+        "enable_chunked_prefill": False,
     }
     # Try a range of common k, as well as large speculation.
     for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
-])
+] + [{
+    "speculative_model": "JackFram/llama-68m",
+    "num_speculative_tokens": k,
+    "enable_chunked_prefill": True,
+    "max_num_batched_tokens": 4,
+    "max_num_seqs": 4,
+} for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]])
 @pytest.mark.parametrize("batch_size", [2])
 @pytest.mark.parametrize(
     "output_len",
@@ -636,11 +727,19 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
     {
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": k,
-        "spec_decoding_acceptance_method": "typical_acceptance_sampler"
+        "spec_decoding_acceptance_method": "typical_acceptance_sampler",
+        "enable_chunked_prefill": False
     }
     # Try a range of common k.
     for k in [1, 2, 3]
-])
+] + [{
+    "speculative_model": "JackFram/llama-68m",
+    "num_speculative_tokens": k,
+    "spec_decoding_acceptance_method": "typical_acceptance_sampler",
+    "enable_chunked_prefill": True,
+    "max_num_batched_tokens": 4,
+    "max_num_seqs": 4
+} for k in [1, 2, 3]])
 @pytest.mark.parametrize("batch_size", [1, 32])
 @pytest.mark.parametrize(
     "output_len",
@@ -50,18 +50,33 @@ from .conftest import run_equality_correctness_test
         "num_speculative_tokens": 5,
         "ngram_prompt_lookup_max": 3,
     },
+    {
+        "speculative_model": "[ngram]",
+        "num_speculative_tokens": 5,
+        "ngram_prompt_lookup_max": 3,
+    },
 ])
 @pytest.mark.parametrize("output_len", [
     256,
 ])
 @pytest.mark.parametrize("batch_size", [1, 32])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
 @pytest.mark.parametrize("seed", [1])
 def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                       per_test_common_llm_kwargs,
                                       baseline_llm_kwargs, test_llm_kwargs,
                                       batch_size: int, output_len: int,
-                                      seed: int):
+                                      prefill_chunk_size: int, seed: int):
     """Verify greedy equality on a tiny model with different batch size."""
+    if prefill_chunk_size > 0:
+        common_llm_kwargs.update(
+            **{
+                "enable_chunked_prefill": True,
+                "max_num_batched_tokens": prefill_chunk_size,
+                "max_num_seqs": prefill_chunk_size
+            })
+    else:
+        common_llm_kwargs["enable_chunked_prefill"] = False
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
@@ -151,6 +166,16 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
         "speculative_model": "[ngram]",
         "num_speculative_tokens": 5,
         "ngram_prompt_lookup_max": 3,
+        "enable_chunked_prefill": False,
+    },
+    {
+        "speculative_model": "[ngram]",
+        "num_speculative_tokens": 5,
+        "ngram_prompt_lookup_max": 3,
+        "enable_chunked_prefill": True,
+        "speculative_disable_mqa_scorer": True,
+        "max_num_batched_tokens": 4,
+        "max_num_seqs": 4
     },
 ])
 @pytest.mark.parametrize(
@@ -251,6 +276,15 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
         "num_speculative_tokens": 5,
         "ngram_prompt_lookup_max": 3,
         "speculative_disable_by_batch_size": 4
+    }, {
+        "speculative_model": "[ngram]",
+        "num_speculative_tokens": 5,
+        "ngram_prompt_lookup_max": 3,
+        "speculative_disable_by_batch_size": 4,
+        "enable_chunked_prefill": True,
+        "speculative_disable_mqa_scorer": True,
+        "max_num_batched_tokens": 4,
+        "max_num_seqs": 4
     }])
 @pytest.mark.parametrize("batch_size", [1, 5])
 @pytest.mark.parametrize(
@@ -118,7 +118,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
         num_gpu_blocks,
         block_size,
         final_prompt_lens=final_prompt_lens)
-
+    for sg in seq_group_metadata_list:
+        sg.is_prompt = False
     proposals = proposer.get_spec_proposals(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list,
@@ -147,7 +148,7 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
 def test_ngram_algo_correctness_for_batches_match_all():
     """Verify our ngram algo find the right candidate in the prompt
 
-    For the scenario find candidate in all batchs
+    For the scenario find candidate in all batches
     """
 
     block_size = 32
@@ -192,6 +193,10 @@ def test_ngram_algo_correctness_for_batches_match_all():
         block_size,
         final_prompt_lens=final_prompt_lens)
 
+    # Normally drafter is run on decode requests only; here we check the output
+    # of the ngram worker as it is the sole proposer that has no forward.
+    for sg in seq_group_metadata_list:
+        sg.is_prompt = False
     proposals = proposer.get_spec_proposals(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list,
@@ -46,12 +46,14 @@ def assert_score_equal(score1: SpeculativeScores,
 @pytest.mark.parametrize('max_propose_len', [1, 3, 5])
 @pytest.mark.parametrize('mixed_propose_len', [True])
 @pytest.mark.parametrize('device', ['cuda'])
+@pytest.mark.parametrize('prefill_chunking', [False, True])
 def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
-                mixed_propose_len: bool, device: str) -> None:
+                mixed_propose_len: bool, device: str,
+                prefill_chunking: bool) -> None:
     """
     Compare the batch expansion scorer and mqa scorer return the same score.
     We test for both queries with the same propose length and different
-    propose length.
+    propose length, as well as mixed prefill-decode batches.
     """
     seed = 0
     block_size = 32
@@ -67,16 +69,37 @@ def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
     if not mixed_propose_len:
         propose_lens = [max_propose_len] * batch_size
     else:
-        non_zero_cnt = random.randint(0, batch_size)
+        # There must be at least 1 decode request, otherwise
+        # we have nothing to score (`_run_no_spec`).
+        non_zero_cnt = random.randint(1, batch_size)
         propose_lens = [max_propose_len
                         ] * non_zero_cnt + [0] * (batch_size - non_zero_cnt)
         random.shuffle(propose_lens)
 
-    proposals = create_proposal(propose_lens, vocab_size, device)
     seq_group_metadatalist, _, _ = create_batch(batch_size,
                                                 max_propose_len,
                                                 block_size=block_size,
                                                 num_gpu_blocks=num_gpu_blocks)
+
+    if mixed_propose_len and prefill_chunking and (n_prefills :=
+                                                   batch_size - non_zero_cnt):
+        prefill, _, _ = create_batch(n_prefills,
+                                     None,
+                                     prefill_chunk_size=4,
+                                     block_size=block_size,
+                                     num_gpu_blocks=num_gpu_blocks,
+                                     seq_ids=list(
+                                         range(batch_size,
+                                               batch_size + n_prefills)))
+        # re-order to guarantee prefill|decode order
+        target_group_metadatalist = [
+            seq_group_metadatalist[i] for i, p in enumerate(propose_lens)
+            if p > 0
+        ]
+        seq_group_metadatalist = prefill + target_group_metadatalist
+        propose_lens = [0] * n_prefills + [p for p in propose_lens if p > 0]
+
+    proposals = create_proposal(propose_lens, vocab_size, device)
     requests = ExecuteModelRequest(seq_group_metadatalist,
                                    num_lookahead_slots=max_propose_len)
 
@@ -10,6 +10,7 @@ import torch
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import ExecuteModelRequest, SequenceOutput
+from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
 from vllm.spec_decode.interfaces import SpeculativeProposals
 from vllm.spec_decode.metrics import (AsyncMetricsCollector,
                                       SpecDecodeWorkerMetrics)
@@ -819,3 +820,84 @@ def test_handle_finished_requests():
     # and 'request-3' are removed from seq_with_bonus_token_in_last_step.
     assert worker._seq_with_bonus_token_in_last_step == \
         {4,5,10}
+
+
+@pytest.mark.parametrize('k', [3])
+@pytest.mark.parametrize('batch_size', [2, 32])
+@pytest.mark.parametrize("batch_composition",
+                         ["prefill_only", "decode_only", "mixed"])
+@torch.inference_mode()
+def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str):
+    """
+    Verify SpecDecodeWorker calls match the expected flow.
+    """
+    vocab_size = 32_000
+    draft_worker = mock_worker(cls=MultiStepWorker)
+    target_worker = mock_worker()
+    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
+    worker = SpecDecodeWorker(draft_worker,
+                              target_worker,
+                              mock_spec_decode_sampler("rejection_sampler"),
+                              disable_logprobs=False,
+                              metrics_collector=metrics_collector)
+    exception_secret = 'artificial stop'
+    worker.scorer = mock_worker(BatchExpansionTop1Scorer)
+    worker.scorer.score_proposals.side_effect = ValueError(exception_secret)
+
+    # Create batch with combination of terminal/non-terminal prefill chunks
+    # and decodes (different seq_ids).
+    decodes, _, _ = create_batch(batch_size, k)
+    # Pre-chunking here, get 'batch_size' chunks.
+    prefill, _, _ = create_batch(batch_size,
+                                 k,
+                                 prefill_chunk_size=4,
+                                 seq_ids=list(range(batch_size,
+                                                    batch_size * 2)))
+
+    if batch_composition == "prefill_only":
+        n_prefills = batch_size
+    elif batch_composition == "decode_only":
+        n_prefills = 0
+    else:
+        n_prefills = random.randint(1, batch_size - 1)
+    n_decodes = batch_size - n_prefills
+
+    prefill = random.sample(prefill, n_prefills)
+    decodes = random.sample(decodes, n_decodes)
+    target_group_metadata_list = prefill + decodes
+    execute_model_req = ExecuteModelRequest(
+        seq_group_metadata_list=target_group_metadata_list,
+        num_lookahead_slots=k)
+
+    target_token_ids = torch.randint(low=0,
+                                     high=vocab_size,
+                                     size=(1, batch_size * (k + 1)),
+                                     dtype=torch.int64,
+                                     device='cuda')
+    target_token_probs = torch.rand(1,
+                                    batch_size * (k + 1),
+                                    vocab_size,
+                                    dtype=torch.float32,
+                                    device='cuda')
+    target_token_logprobs = torch.rand(1,
+                                       batch_size * (k + 1),
+                                       vocab_size,
+                                       dtype=torch.float32,
+                                       device='cuda')
+    target_output = create_sampler_output_list(target_token_ids,
+                                               target_token_probs,
+                                               target_token_logprobs)
+
+    target_worker.execute_model.return_value = [target_output[0]]
+
+    if not len(decodes):
+        worker.execute_model(execute_model_req=execute_model_req)
+        # no spec run (prefill only)
+        draft_worker.execute_model.assert_called_once_with(execute_model_req)
+        target_worker.execute_model.assert_called_once_with(execute_model_req)
+    else:
+        # Decode-only run OR mixed batch, scorer call fails (it's mocked)
+        with pytest.raises(ValueError, match=exception_secret):
+            worker.execute_model(execute_model_req=execute_model_req)
+        # but first draft still counted
+        assert draft_worker.get_spec_proposals.call_count == 1
@@ -146,6 +146,41 @@ def create_seq_group_metadata_from_prompts(
     return seq_grou_metadata_list
 
 
+def create_chunked_seq_group_metadata_from_prompt(
+        prompt: List[int],
+        num_gpu_blocks: int,
+        chunk_size: int,
+        block_size: int,
+        seq_id: Optional[int] = None) -> List[SequenceGroupMetadata]:
+
+    if seq_id is None:
+        seq_id = 0
+
+    free_gpu_blocks = list(range(num_gpu_blocks))
+
+    block_allocations = [
+        free_gpu_blocks.pop()
+        for _ in range(round_up_to_next_block(len(prompt), block_size))
+    ]
+
+    seq_group_metadata_list = []
+    for i, idx in enumerate(range(0, len(prompt), chunk_size)):
+        chunk_ids = prompt[idx:idx + chunk_size]
+        data = SequenceData.from_seqs(prompt)
+        data.update_num_computed_tokens(idx)
+        seq_data = {i: data}
+        seq_group_metadata_list.append(
+            SequenceGroupMetadata(
+                request_id=str(seq_id),
+                is_prompt=True,
+                do_sample=idx + chunk_size >= len(prompt),  # terminal chunk
+                seq_data=seq_data,
+                sampling_params=SamplingParams(temperature=0.0),
+                block_tables={i: block_allocations},
+                token_chunk_size=len(chunk_ids)))
+    return seq_group_metadata_list
+
+
 def assert_logprobs_dict_allclose(
         actual_logprobs: List[Dict[int, Logprob]],
         expected_logprobs: List[Dict[int, Logprob]]) -> None:
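The helper above builds the chunked-prefill test batches: a prompt is cut into fixed-size chunks, each chunk records how many tokens were already computed, and only the terminal chunk is allowed to sample. Below is a minimal standalone sketch of that splitting logic in plain Python; the `chunk_prompt` name is hypothetical and not part of the diff.

from typing import List, Tuple

def chunk_prompt(prompt: List[int],
                 chunk_size: int) -> List[Tuple[List[int], int, bool]]:
    """Return (chunk_tokens, num_already_computed, is_terminal) per chunk."""
    chunks = []
    for start in range(0, len(prompt), chunk_size):
        chunk = prompt[start:start + chunk_size]
        is_terminal = start + chunk_size >= len(prompt)
        chunks.append((chunk, start, is_terminal))
    return chunks

# A 10-token prompt with chunk_size=4 yields chunks of 4, 4 and 2 tokens;
# only the last one (is_terminal=True) would get do_sample=True above.
print(chunk_prompt(list(range(10)), 4))
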
@@ -198,7 +233,8 @@ def create_batch(batch_size,
                  prev_output_token_len: int = 10,
                  seq_ids: Optional[List[int]] = None,
                  num_gpu_blocks: Optional[int] = None,
-                 block_size: Optional[int] = None):
+                 block_size: Optional[int] = None,
+                 prefill_chunk_size: Optional[int] = None):
     if block_size is None:
         block_size = 8
 
@@ -213,15 +249,28 @@ def create_batch(batch_size,
         prompt_lens = prompt_len
 
     prompts = [[next(iterator) for _ in range(p_len)] for p_len in prompt_lens]
-    prev_output_tokens = [[
-        next(iterator) for _ in range(prev_output_token_len)
-    ] for _ in range(batch_size)]
-    final_prompt_lens = [
-        len(prompt) + len(prev_output_token) + k + 1
-        for prompt, prev_output_token in zip(prompts, prev_output_tokens)
-    ]
 
-    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
-        prompts, num_gpu_blocks, block_size, final_prompt_lens,
-        prev_output_tokens, seq_ids)
+    if prefill_chunk_size:
+        # Create a batch of chunked prompts.
+        if not seq_ids:
+            seq_ids = list(range(len(prompts)))
+        seq_group_metadata_list = []
+        for p, sid in zip(prompts, seq_ids):
+            seq_group_metadata_list += \
+                create_chunked_seq_group_metadata_from_prompt(
+                    p, num_gpu_blocks, prefill_chunk_size, block_size, sid)
+        seq_group_metadata_list = seq_group_metadata_list[:batch_size]
+        prev_output_tokens = []
+    else:
+        prev_output_tokens = [[
+            next(iterator) for _ in range(prev_output_token_len)
+        ] for _ in range(batch_size)]
+        final_prompt_lens = [
+            len(prompt) + len(prev_output_token) + k + 1
+            for prompt, prev_output_token in zip(prompts, prev_output_tokens)
+        ]
+
+        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+            prompts, num_gpu_blocks, block_size, final_prompt_lens,
+            prev_output_tokens, seq_ids)
     return seq_group_metadata_list, prompts, prev_output_tokens
@@ -276,7 +276,11 @@ class FlashAttentionMetadata(AttentionMetadata):
             max_query_len=self.max_query_len,
             max_prefill_seq_len=0,
             max_decode_seq_len=self.max_decode_seq_len,
-            query_start_loc=self.query_start_loc[self.num_prefills:]
+            # Batch may be composed of prefill|decodes, adjust query start
+            # indices to refer to the start of decodes. E.g.
+            # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
+            query_start_loc=(self.query_start_loc[self.num_prefills:] -
+                             self.query_start_loc[self.num_prefills])
             if self.query_start_loc is not None else None,
             seq_start_loc=self.seq_start_loc[self.num_prefills:]
             if self.seq_start_loc is not None else None,
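The comment's example can be reproduced in isolation: once the prefill portion is sliced off, the cumulative query-start offsets must be rebased so the first decode query starts at zero. A minimal standalone PyTorch sketch (not vLLM code), assuming one prefill covering tokens [0, 3) followed by decode queries covering tokens [3, 9):

import torch

# Cumulative start offsets for each query in a mixed prefill|decode batch.
query_start_loc = torch.tensor([0, 3, 6, 9])
num_prefills = 1

# Keep only the decode part and rebase it so it starts at zero,
# mirroring the adjustment in the decode-metadata property above.
decode_query_start_loc = (query_start_loc[num_prefills:] -
                          query_start_loc[num_prefills])
print(decode_query_start_loc)  # tensor([0, 3, 6])
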
@@ -903,7 +907,9 @@ def unified_flash_attention(
         # Decoding run.
         # Use flash_attn_varlen_func kernel for speculative decoding
         # because different queries might have different lengths.
+
         assert decode_meta.max_decode_query_len is not None
+        # use only for actual varlen decoding
         if decode_meta.max_decode_query_len > 1:
             assert attn_type == AttentionType.DECODER, (
                 "Only decoder-only models support max_decode_query_len > 1")
@@ -949,8 +955,6 @@ def unified_flash_attention(
         assert prefill_output is not None
         return prefill_output.view(num_prefill_query_tokens, hidden_size)
 
-    # Chunked prefill does not work with speculative decoding.
-    # Therefore, the query length for decode should be 1 in chunked prefill.
     assert decode_meta is not None
     decode_output = decode_output.squeeze(1)
     output = torch.cat([prefill_output, decode_output], dim=0)
@@ -192,6 +192,12 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
             block_tables=self.block_tables[self.num_prefills:],
             use_cuda_graph=self.use_cuda_graph,
         )
+        # Batch may be composed of prefill|decodes, adjust query start indices
+        # to refer to the start of decodes when the two are split apart.
+        # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
+        if self._cached_decode_metadata.query_start_loc is not None:
+            qs = self._cached_decode_metadata.query_start_loc
+            self._cached_decode_metadata.query_start_loc = qs - qs[0]
         return self._cached_decode_metadata
 
     def advance_step(self,
@@ -272,6 +272,13 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
                 max_encoder_seq_len=self.max_encoder_seq_len,
                 cross_slot_mapping=self.cross_slot_mapping,
                 cross_block_tables=self.cross_block_tables)
+
+        # Batch may be composed of prefill|decodes, adjust query start indices
+        # to refer to the start of decodes when the two are split apart.
+        # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
+        if self._cached_decode_metadata.query_start_loc is not None:
+            qs = self._cached_decode_metadata.query_start_loc
+            self._cached_decode_metadata.query_start_loc = qs - qs[0]
         return self._cached_decode_metadata
 
 
@@ -192,7 +192,6 @@ class ModelConfig:
         self.max_logprobs = max_logprobs
         self.disable_sliding_window = disable_sliding_window
         self.skip_tokenizer_init = skip_tokenizer_init
-
         self.hf_config = get_config(self.model, trust_remote_code, revision,
                                     code_revision, rope_scaling, rope_theta,
                                     config_format)
@@ -1317,13 +1316,6 @@ class SpeculativeConfig:
                     "speculative decoding is > 1, but got "
                     f"{speculative_disable_by_batch_size=}")
 
-            # Reminder: Please update docs/source/serving/compatibility_matrix.rst
-            # If the feature combo become valid
-            if enable_chunked_prefill:
-                raise ValueError(
-                    "Speculative decoding and chunked prefill are "
-                    f"currently mutually exclusive ({enable_chunked_prefill=}).")
-
             # TODO: The user should be able to specify revision/max model len
             # for the draft model. It is not currently supported.
             draft_revision = None
@@ -1390,6 +1382,12 @@ class SpeculativeConfig:
                         f"num_speculative_tokens={n_predict}, but "
                         f"{num_speculative_tokens=} was provided.")
 
+                if enable_chunked_prefill and draft_hf_config.model_type in (
+                        "medusa", "mlp_speculator", "eagle"):
+                    raise ValueError(
+                        "Chunked prefill and hidden-state based draft models are "
+                        "not compatible.")
+
                 draft_model_config.max_model_len = (
                     SpeculativeConfig._maybe_override_draft_max_model_len(
                         speculative_max_model_len,
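With the blanket ValueError removed from SpeculativeConfig, a configuration that enables both speculative decoding and chunked prefill is now accepted (only hidden-state draft models such as medusa/mlp_speculator/eagle remain rejected). The sketch below shows how such a configuration might look through the offline LLM entry point, assuming it forwards these engine arguments the same way the tests in this commit do; the model names and the tiny max_num_batched_tokens value are taken from those tests.

from vllm import LLM, SamplingParams

# Draft and target models as used in the tests above; the small
# max_num_batched_tokens forces prompts to be split into prefill chunks.
llm = LLM(
    model="JackFram/llama-68m",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=5,
    enable_chunked_prefill=True,
    max_num_batched_tokens=4,
    max_num_seqs=4,
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
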
@@ -1147,6 +1147,7 @@ class Scheduler:
 
         # Update swapped requests.
         self.swapped.extend(running_scheduled.swapped_out)
+        # Put prefills first due to Attention backend ordering assumption.
        return SchedulerOutputs(
             scheduled_seq_groups=(prefills.seq_groups +
                                   running_scheduled.prefill_seq_groups +
@@ -134,10 +134,12 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
             sample for sample in samples
             if sample.output_token != VLLM_INVALID_TOKEN_ID
         ]
-        assert valid_samples
 
-        self._process_seq_outputs(seq, valid_samples,
-                                  sequence_group.sampling_params)
+        # When both spec-decode and pre-fill chunking are enabled, we
+        # don't have guaranteed samples here (e.g. all -1s).
+        if valid_samples:
+            self._process_seq_outputs(seq, valid_samples,
+                                      sequence_group.sampling_params)
 
     def _process_decode_and_stop(self, seq: Sequence,
                                  sampling_params: SamplingParams) -> None:
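With chunked prefill in the batch, a speculative step can yield a sequence whose sampled tokens are all the invalid-token sentinel, so the processor above must now tolerate an empty `valid_samples` list instead of asserting. A standalone sketch of that filtering, using -1 as a stand-in for `VLLM_INVALID_TOKEN_ID`:

INVALID_TOKEN_ID = -1  # stand-in for VLLM_INVALID_TOKEN_ID

def filter_valid_samples(sampled_token_ids):
    """Drop padding samples; may legitimately return an empty list."""
    return [t for t in sampled_token_ids if t != INVALID_TOKEN_ID]

print(filter_valid_samples([1576, -1, -1]))  # [1576] -> process outputs
print(filter_valid_samples([-1, -1]))        # []     -> skip this step
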
@@ -90,7 +90,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         else:
             # Batch has a mix of spec decode enabled and disabled seq groups
             contracted = self._contract_batch(
-                contracted_bs=len(execute_model_req.seq_group_metadata_list),
+                execute_model_req.seq_group_metadata_list,
                 target_sampler_output=target_sampler_output,
                 proposals=proposals,
                 num_scoring_tokens=num_scoring_tokens,
@@ -126,7 +126,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
             split_batch_by_proposal_len(
                 seq_group_metadata_list, proposal_lens_list)
 
-        target_seq_group_metadata_list = self._create_scoring_model_input(
+        spec_expanded_seqs = self._create_scoring_model_input(
             seq_group_metadata_list=spec_seqs,
             proposal_token_ids=proposal_token_ids_list,
             # NOTE: We determine the seq ids in the expanded batch using the
@@ -135,16 +135,19 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
             seq_ids=get_all_seq_ids(seq_group_metadata_list)),
         )
 
-        num_scoring_tokens = len(target_seq_group_metadata_list)
-        target_seq_group_metadata_list.extend(non_spec_seqs)
+        num_scoring_tokens = len(spec_expanded_seqs)
+        # Batch speculative and non-speculative (e.g. chunked prefill) requests
+        # but make sure order is prefill|decode due to backend requirement.
+        target_seq_group_metadata_list = non_spec_seqs + spec_expanded_seqs
 
         return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
                 num_scoring_tokens)
 
     def _contract_batch(
-            self, contracted_bs: int, target_sampler_output: SamplerOutput,
-            proposals: SpeculativeProposals, num_scoring_tokens: int,
-            non_spec_indices: List[int], spec_indices: List[int], k: int
+            self, contracted_seq_group_metadata_list: List[SequenceGroupMetadata],
+            target_sampler_output: SamplerOutput, proposals: SpeculativeProposals,
+            num_scoring_tokens: int, non_spec_indices: List[int],
+            spec_indices: List[int], k: int
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
                Optional[torch.Tensor]]:
         """Contract the expanded batch back into its original size.
@@ -154,6 +157,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         contracted_bs is the original batch size, and the batch size that the
         target_sampler_output will be contracted to.
         """
+        contracted_bs = len(contracted_seq_group_metadata_list)
         (target_token_ids, target_probs, target_logprobs, target_hidden_states,
          non_spec_target_token_ids, non_spec_target_probs,
          non_spec_target_logprobs,
@@ -166,8 +170,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
 
         # The number of tokens in the expanded batch used for speculation is
         # equal to the total expanded batch size minus the number of samples for
-        # non-speculative sequences.
-        non_spec_expanded_bs = len(non_spec_target_token_ids)
+        # non-speculative sequences, prefill chunks with no out tokens included
+        non_spec_expanded_bs = len(non_spec_indices)
         spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs
 
         target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
@@ -191,7 +195,12 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         else:
             all_hidden_states = None
 
-        if non_spec_indices:
+        # Rule out prefills that produce no tokens.
+        non_spec_indices = [
+            idx for idx in non_spec_indices
+            if contracted_seq_group_metadata_list[idx].do_sample
+        ]
+        if len(non_spec_indices):
             all_tokens[non_spec_indices, :1] = \
                 non_spec_target_token_ids.unsqueeze(1)
             all_probs[non_spec_indices, :1, :] = \
@@ -290,9 +299,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         This function creates K+1 target SequenceGroupMetadata to take
         advantage of the bonus token.
         """
-        assert not input_seq_group_metadata.is_prompt, (
-            "Speculating on "
-            "prompts not yet supported")
         assert len(input_seq_group_metadata.seq_data) == 1, (
             "Beam search "
             "not supported in speculative decoding")
@@ -390,27 +396,22 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         # and non spec sequences) and should be removed in the future. It can be
         # done by supporting per-sequence proposal lens.
         #
-        # First samples are from speculative scoring, latter samples are non-
-        # speculative samples.
-        split_sizes = (num_scoring_tokens,
-                       sampler_output.sampled_token_ids.numel() -
-                       num_scoring_tokens)
-        (spec_probs, non_spec_probs
-         ) = sampler_output.sampled_token_probs.split(split_sizes)
-        (spec_sampled_tokens, non_spec_sampled_tokens
+        # First samples are non-speculative, latter samples are from speculative
+        # scoring (prefill|decode order).
+        split_sizes = (sampler_output.sampled_token_ids.numel() -
+                       num_scoring_tokens, num_scoring_tokens)
+        (non_spec_probs,
+         spec_probs) = sampler_output.sampled_token_probs.split(split_sizes)
+        (non_spec_sampled_tokens, spec_sampled_tokens
          ) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
-        (
-            spec_logprobs,
-            non_spec_logprobs,
-        ) = sampler_output.logprobs.split(split_sizes)
+        (non_spec_logprobs,
+         spec_logprobs) = sampler_output.logprobs.split(split_sizes)
 
         if sampler_output.hidden_states is not None:
-            (
-                spec_hidden_states,
-                non_spec_hidden_states,
-            ) = sampler_output.hidden_states.split(split_sizes)
+            (non_spec_hidden_states, spec_hidden_states
+             ) = sampler_output.hidden_states.split(split_sizes)
         else:
-            spec_hidden_states, non_spec_hidden_states = None, None
+            non_spec_hidden_states, spec_hidden_states = None, None
 
         return (spec_sampled_tokens, spec_probs, spec_logprobs,
                 spec_hidden_states, non_spec_sampled_tokens, non_spec_probs,
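After the reordering above, the target model's flat sampler output is laid out as [non-speculative (prefill) samples | speculative scoring tokens], so the split sizes are given in that order. A standalone PyTorch sketch of the same bookkeeping, assuming a toy output of 2 non-speculative samples followed by 6 speculative scoring tokens:

import torch

sampled_token_ids = torch.arange(8)  # 2 non-spec samples, then 6 spec tokens
num_scoring_tokens = 6

# Non-speculative samples come first, speculative scoring tokens last,
# mirroring the prefill|decode ordering of the expanded batch.
split_sizes = (sampled_token_ids.numel() - num_scoring_tokens,
               num_scoring_tokens)
non_spec_sampled_tokens, spec_sampled_tokens = \
    sampled_token_ids.split(split_sizes)
print(non_spec_sampled_tokens.tolist())  # [0, 1]
print(spec_sampled_tokens.tolist())      # [2, 3, 4, 5, 6, 7]
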
@@ -21,6 +21,11 @@ class MQAScorer(SpeculativeScorer):
         all_proposal_lengths = proposals.proposal_lens.tolist()
         for i, seq_group_metadata in enumerate(
                 execute_model_req.seq_group_metadata_list):
+            if all_proposal_lengths[i] == 0:
+                # Keep prompt seqs untouched (keep computed_tokens for chunks).
+                target_seq_group_metadata_list.append(seq_group_metadata)
+                continue
+
             seq_data_dict = seq_group_metadata.seq_data
             assert len(seq_data_dict) == 1
             seq_id = next(iter(seq_data_dict.keys()))
@@ -40,8 +45,7 @@ class MQAScorer(SpeculativeScorer):
             new_seq_data.update_num_computed_tokens(
                 len(prompt_token_ids) + len(output_token_ids) - 1)
 
-            # Ensure that the new sequence has at least one token
-            # because we only use mqa scorer in the decoding stage.
+            # Ensure that the new decode sequence has at least one token.
             assert len(output_token_ids) >= 1
             new_seq_data_dict = {target_seq_id: new_seq_data}
 
@@ -54,7 +58,6 @@ class MQAScorer(SpeculativeScorer):
                     target_seq_id: seq_group_metadata.block_tables[seq_id],
                 },
                 lora_request=None,
-                token_chunk_size=1,
             )
             target_seq_group_metadata_list.append(new_seq_group_metadata)
 
@@ -77,6 +80,7 @@ class MQAScorer(SpeculativeScorer):
             all_probs = target_probs.reshape(bs, k + 1, self._vocab_size)
             all_logprobs = target_logprobs.reshape(bs, k + 1, self._vocab_size)
         else:
+            # We either have decodes with different lens or prefill+decodes.
             all_tokens = target_token_ids.new_full(size=(bs, k + 1),
                                                    fill_value=-1)
             all_probs = target_probs.new_zeros(*all_tokens.shape,
@@ -85,15 +89,18 @@ class MQAScorer(SpeculativeScorer):
                                                fill_value=-float("inf"))
             target_token_ids = target_token_ids.flatten()
             start_loc = 0
-            for i, proposed_len in enumerate(all_proposal_lengths):
-                output_len = proposed_len + 1
-                end_loc = start_loc + output_len
-                all_tokens[
-                    i, :output_len] = target_token_ids[start_loc:end_loc]
-                all_probs[i, :output_len] = target_probs[start_loc:end_loc]
-                all_logprobs[
-                    i, :output_len] = target_logprobs[start_loc:end_loc]
-                start_loc = end_loc
+            for i, (proposed_len, seq_meta) in enumerate(
+                    zip(all_proposal_lengths, target_seq_group_metadata_list)):
+                # Skip chunks with no output tokens.
+                if seq_meta.do_sample:
+                    output_len = proposed_len + 1
+                    end_loc = start_loc + output_len
+                    all_tokens[
+                        i, :output_len] = target_token_ids[start_loc:end_loc]
+                    all_probs[i, :output_len] = target_probs[start_loc:end_loc]
+                    all_logprobs[
+                        i, :output_len] = target_logprobs[start_loc:end_loc]
+                    start_loc = end_loc
 
         hidden_states = None
         if target_sampler_output.hidden_states is not None:
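The loop above scatters a flat stream of scored tokens into a fixed (batch, k+1) grid row by row, while rows belonging to non-terminal prefill chunks keep the -1 fill value. A standalone sketch of that packing, assuming proposal lengths [2, 0, 1] where the zero-length entry is a prefill chunk that produced no output; the variable names here are illustrative, not the scorer's own:

import torch

k = 2
proposal_lens = [2, 0, 1]                 # 0 => prefill chunk, nothing sampled
produces_output = [True, False, True]     # stand-in for seq_meta.do_sample
flat_tokens = torch.tensor([11, 12, 13, 21, 22])  # 3 + 0 + 2 scored tokens

all_tokens = torch.full((len(proposal_lens), k + 1), -1)
start = 0
for row, (plen, has_output) in enumerate(zip(proposal_lens, produces_output)):
    if has_output:
        out_len = plen + 1                # proposals plus the bonus token
        all_tokens[row, :out_len] = flat_tokens[start:start + out_len]
        start += out_len
print(all_tokens)
# tensor([[11, 12, 13],
#         [-1, -1, -1],
#         [21, 22, -1]])
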
|||||||
@ -418,7 +418,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||||||
# none of the requests in the batch have spec decoding enabled.
|
# none of the requests in the batch have spec decoding enabled.
|
||||||
# In any of these cases, the proposer and scorer workers
|
# In any of these cases, the proposer and scorer workers
|
||||||
# are called normally.
|
# are called normally.
|
||||||
no_spec = num_lookahead_slots == 0 or disable_all_speculation or all(
|
# We expect `num_speculative_tokens` to be None for prefills.
|
||||||
|
no_spec = all(
|
||||||
|
sgm.is_prompt for sgm in execute_model_req.seq_group_metadata_list
|
||||||
|
) or num_lookahead_slots == 0 or disable_all_speculation or all(
|
||||||
sgm.num_speculative_tokens == 0
|
sgm.num_speculative_tokens == 0
|
||||||
for sgm in execute_model_req.seq_group_metadata_list)
|
for sgm in execute_model_req.seq_group_metadata_list)
|
||||||
|
|
||||||
@ -484,7 +487,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||||||
|
|
||||||
def _serialize_sampler_output_no_logprobs(
|
def _serialize_sampler_output_no_logprobs(
|
||||||
self, execute_model_req: ExecuteModelRequest,
|
self, execute_model_req: ExecuteModelRequest,
|
||||||
sampler_output: SamplerOutput) -> SamplerOutput:
|
sampler_output: SamplerOutput) -> List[SamplerOutput]:
|
||||||
"""
|
"""
|
||||||
Creates and returns a `SamplerOutput` with only the token IDs being
|
Creates and returns a `SamplerOutput` with only the token IDs being
|
||||||
serialized to CPU and populated in `CompletionSequenceGroupOutput`.
|
serialized to CPU and populated in `CompletionSequenceGroupOutput`.
|
||||||
@@ -514,41 +517,56 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             if any(seq_output_prompt_logprobs) else \
             sampler_output.sampled_token_ids).tolist()
 
-        seq_data_entries = (
+        seq_data_entries = [
             (seq_id, seq_data) for sg in \
             execute_model_req.seq_group_metadata_list \
             for seq_id, seq_data in sg.seq_data.items()
-        )
+            if sg.do_sample  # ignore empty token sequences
+        ]
         completion_seq_group_output_list: List[
             CompletionSequenceGroupOutput] = []
-        for index, ((seq_id, seq_data), needs_prompt_logprobs) in \
-            enumerate(zip(seq_data_entries, seq_output_prompt_logprobs)):
-            if needs_prompt_logprobs:
-                prompt_token_ids = seq_data.get_prompt_token_ids()
-                prompt_logprobs = [
-                    create_logprobs_output(
-                        token_id=p_token_id,
-                        token_id_logprob_rank=-1,
-                        token_id_logprob=0.0,
-                        topk_token_ids=[],
-                        topk_logprobs=[],
-                    )
-                    # no prompt logprobs for the first token
-                    for p_token_id in prompt_token_ids[1:]
-                ]
-            else:
-                prompt_logprobs = None
-
-            completion_seq_group_output_list.append(
-                create_sequence_group_output(
-                    token_id=sampled_token_ids_list[index][0],
-                    token_id_logprob_rank=-1,
-                    token_id_logprob=0.0,
-                    seq_id=seq_id,
-                    topk_token_ids=[],
-                    topk_logprobs=[],
-                    prompt_logprobs=prompt_logprobs))
-        return SamplerOutput(outputs=completion_seq_group_output_list)
+        output_index = 0
+        # Make sure the non-terminal prefill chunks are still aligned with
+        # their own empty output.
+        for seq_group_meta in execute_model_req.seq_group_metadata_list:
+            # Since we can get chunks here, we dont always have a sampled token
+            # (only on last chunk) but we still have to provide an output.
+            if not seq_group_meta.do_sample:
+                completion_seq_group_output_list.append(
+                    CompletionSequenceGroupOutput(samples=[],
+                                                  prompt_logprobs=None))
+            else:
+                # Sequence with output.
+                seq_id, seq_data = seq_data_entries[output_index]
+                needs_prompt_logprobs = seq_output_prompt_logprobs[
+                    output_index]
+                if needs_prompt_logprobs:
+                    prompt_token_ids = seq_data.get_prompt_token_ids()
+                    prompt_logprobs = [
+                        create_logprobs_output(
+                            token_id=p_token_id,
+                            token_id_logprob_rank=-1,
+                            token_id_logprob=0.0,
+                            topk_token_ids=[],
+                            topk_logprobs=[],
+                        )
+                        # no prompt logprobs for the first token
+                        for p_token_id in prompt_token_ids[1:]
+                    ]
+                else:
+                    prompt_logprobs = None
+                completion_seq_group_output_list.append(
+                    create_sequence_group_output(
+                        token_id=sampled_token_ids_list[output_index][0],
+                        token_id_logprob_rank=-1,
+                        token_id_logprob=0.0,
+                        seq_id=seq_id,
+                        topk_token_ids=[],
+                        topk_logprobs=[],
+                        prompt_logprobs=prompt_logprobs))
+                output_index += 1
+
+        return [SamplerOutput(outputs=completion_seq_group_output_list)]
 
     @nvtx_range("spec_decode_worker._run_no_spec")
     def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
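The rewritten serialization walks the request's sequence groups in order and emits an empty output for every non-terminal prefill chunk, so engine-side bookkeeping stays index-aligned with the batch. A standalone sketch of that alignment, assuming `do_sample` flags [False, True, True] for one mid-prefill chunk plus two sequences that actually sampled:

do_sample_flags = [False, True, True]   # per sequence group, in batch order
sampled_tokens = [[1576], [29884]]      # only produced by sampling groups

outputs = []
sample_idx = 0
for do_sample in do_sample_flags:
    if not do_sample:
        outputs.append([])              # placeholder keeps positions aligned
    else:
        outputs.append(sampled_tokens[sample_idx])
        sample_idx += 1
print(outputs)  # [[], [1576], [29884]]
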
@@ -568,6 +586,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         hidden_states = sampler_output.hidden_states
         if hidden_states is not None:
             # remove hidden_states for prompt tokens
+            # TODO Enable `return_hidden_states`: prefill chunks hidden states
+            # are pruned by the logits processor. Also, they should be arranged
+            # back into full-prefill latent. Address it to enable MLPSpeculator.
             if any(seq.is_prompt
                    for seq in execute_model_req.seq_group_metadata_list):
                 hidden_states = hidden_states[
@@ -593,14 +614,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
             execute_model_req=execute_model_req, sampler_output=sampler_output)
                                     if self._disable_logprobs else
-                                    sampler_output)
+                                    [sampler_output])
 
         # Clear device tensors from sampler output. This reduces communication
         # overhead when the engine runs in a different process than the workers.
         sampler_output.sampled_token_probs = None
         sampler_output.sampled_token_ids = None
         sampler_output.logprobs = None
-        return [sampler_output_to_return]
+        return sampler_output_to_return
 
     def _run_non_driver_rank(self) -> bool:
         """Run proposer and verifier model in non-driver workers. This is used
@@ -644,9 +665,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         This invokes the proposer worker to get k speculative tokens for each
         sequence, then scores each speculative token using the scoring worker.
 
+        When `enable_chunked_prefill` is set, scorer will batch decodes and
+        prefills, while proposer will sync its KV-cache by running an extra
+        forward on prefills.
+
         Returns a list of SamplerOutput, each containing a single token per
         sequence.
         """
+        # With prefill chunking, expect requests to have prompts first
+        # so that backend gets prefill|decode.
         assert num_lookahead_slots == execute_model_req.num_lookahead_slots
 
         # Pass last hidden states from target model to proposer
@@ -671,6 +698,25 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             proposals,
         )
 
+        _, (non_spec_seqs, non_spec_indices) = split_batch_by_proposal_len(
+            execute_model_req.seq_group_metadata_list, proposals.proposal_lens)
+        # With prefill chunking enabled, `non_spec_seqs` contains prefills too:
+        # discard decodes that have already been processed by proposer.
+        non_spec_indices = [
+            idx for idx in non_spec_indices
+            if execute_model_req.seq_group_metadata_list[idx].is_prompt
+        ]
+        if len(non_spec_indices):
+            all_hidden_states = proposal_scores.hidden_states
+            # TODO fix `return_hidden_states`, same as in `_run_no_spec`
+            if all_hidden_states is not None:
+                prefill_hidden_states = all_hidden_states[non_spec_indices]
+                execute_model_req.previous_hidden_states = \
+                    prepare_prefill_hidden_states(prefill_hidden_states)
+            # Sync proposer KV cache for prefills.
+            prefill_req = execute_model_req.clone(non_spec_seqs)
+            self.proposer_worker.execute_model(prefill_req)
+
         with Timer() as verification_timer:
             accepted_token_ids, target_logprobs = self._verify_tokens(
                 execute_model_req.seq_group_metadata_list, proposal_scores,
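Here the worker keeps the draft model's KV-cache in sync: among the sequence groups that received no proposals, only the prefill chunks still need a forward pass on the proposer, so the indices are filtered by `is_prompt` before the extra call. A standalone sketch of that filtering over a toy batch description (the tuples below are illustrative, not vLLM objects):

# Toy batch: (is_prompt, proposal_len) per sequence group, in batch order.
batch = [(True, 0), (True, 0), (False, 0), (False, 5)]

# Groups with no proposals (proposal_len == 0) ...
non_spec_indices = [i for i, (_, plen) in enumerate(batch) if plen == 0]
# ... restricted to prefill chunks; decodes with 0 proposals are skipped
# because the proposer has already seen their tokens.
prefill_indices = [i for i in non_spec_indices if batch[i][0]]
print(non_spec_indices)  # [0, 1, 2]
print(prefill_indices)   # [0, 1]
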
@@ -769,7 +815,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             self.previous_hidden_states = HiddenStates(
                 hidden_states, seq_group_metadata_list,
                 second_last_token_hidden_states)
-
         return accepted_token_ids, logprobs
 
     def _create_output_sampler_list(
@@ -819,6 +864,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
 
         # Construct the output on a per-step, per-sequence basis.
+        # Non-terminal prefill chunks will end up here as rows with just -1s
+        # i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]]
         sampler_output_list: List[SamplerOutput] = []
         for step_index in range(num_steps):
             if all(token_id == -1
@@ -861,7 +908,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             # This is periodic because the rejection sampler emits metrics
             # periodically.
             self._maybe_log_stage_times(*stage_times)
-
         return sampler_output_list
 
     def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
@@ -109,7 +109,6 @@ class Top1Proposer(SpeculativeProposer):
             proposal_probs=proposal_probs,
             proposal_lens=proposal_lens,
             no_proposals=maybe_sampler_output is None)
-
         return proposals
 
     def _split_by_proposal_len(
@@ -127,9 +126,10 @@ class Top1Proposer(SpeculativeProposer):
         nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
         nonzero_proposal_len_indices: List[int] = []
         for i, seq_group_metadata in enumerate(seq_group_metadata_list):
-            # The speculative decoding for this request has been disabled
-            # (e.g. due to high traffic).
-            if seq_group_metadata.num_speculative_tokens == 0:
+            # The speculative decoding for this request has either been disabled
+            # (e.g. due to high traffic) or this is a prompt request.
+            if (seq_group_metadata.is_prompt
+                    or seq_group_metadata.num_speculative_tokens == 0):
                 proposal_lens.append(0)
                 continue
 