diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index b9cb3858c006..5cb982a0811c 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -2,6 +2,7 @@ from itertools import cycle from typing import List, Optional, Sequence, Tuple, Union import pytest +import torch from vllm import LLM, SamplingParams from vllm.distributed import cleanup_dist_env_and_memory @@ -154,6 +155,8 @@ def _check_logprobs_when_output_disabled( spec_pos_logprob) = next(iter(spec_pos_logprobs.items())) assert spec_pos_logprob.rank == -1 assert spec_pos_logprob.logprob == 0.0 + if isinstance(spec_pos_logprob_token_id, torch.Tensor): + spec_pos_logprob_token_id = spec_pos_logprob_token_id.item() assert spec_pos_logprob_token_id in baseline_pos_logprobs @@ -244,7 +247,8 @@ def run_equality_correctness_test_tp(model, batch_size: int, max_output_len: int, seed: int = 0, - temperature: float = 0.0): + temperature: float = 0.0, + logprobs: Optional[int] = None): """Helper method that compares the outputs of both the baseline LLM and the test LLM. It asserts greedy equality, e.g. that the outputs are exactly the same when temperature is zero. @@ -257,7 +261,6 @@ def run_equality_correctness_test_tp(model, results = [] prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))] - for args, env in ((arg1, env1), (arg2, env2)): with RemoteOpenAIServer(model, args, @@ -269,12 +272,14 @@ def run_equality_correctness_test_tp(model, prompt=prompts, max_tokens=max_output_len, seed=seed, - temperature=temperature) + temperature=temperature, + logprobs=logprobs) results.append({ "test": "seeded_sampling", "text": [choice.text for choice in completion.choices], + "logprobs": [choice.logprobs for choice in completion.choices], "finish_reason": [choice.finish_reason for choice in completion.choices], "usage": @@ -284,7 +289,15 @@ def run_equality_correctness_test_tp(model, n = len(results) // 2 arg1_results = results[:n] arg2_results = results[n:] + # Separate logprobs to avoid asserting exact equality. + arg1_logprobs = [r.pop("logprobs") for r in arg1_results] + arg2_logprobs = [r.pop("logprobs") for r in arg2_results] + for arg1_result, arg2_result in zip(arg1_results, arg2_results): assert arg1_result == arg2_result, ( f"Results for {model=} are not the same with {arg1=} and {arg2=}. " f"{arg1_result=} != {arg2_result=}") + if logprobs: + for logs1, logs2 in zip(arg1_logprobs, arg2_logprobs): + for l1, l2 in zip(logs1, logs2): + assert l1.tokens == l2.tokens diff --git a/tests/spec_decode/e2e/test_integration_dist_tp2.py b/tests/spec_decode/e2e/test_integration_dist_tp2.py index 02cba9279514..7001ee4c007f 100644 --- a/tests/spec_decode/e2e/test_integration_dist_tp2.py +++ b/tests/spec_decode/e2e/test_integration_dist_tp2.py @@ -2,6 +2,8 @@ tensor parallelism. """ +from typing import Optional + import pytest import torch @@ -154,15 +156,20 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs, "--speculative-draft-tensor-parallel-size", "1", ])]) +@pytest.mark.parametrize("logprobs", [None, 2]) @pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize("seed", [1]) def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, + logprobs: Optional[int], batch_size: int, seed: int): """Verify spec decode works well with same and different TP size for the draft model with chunked prefill. """ + if logprobs: + test_llm_kwargs.extend( + ["--disable_logprobs_during_spec_decoding", "False"]) run_equality_correctness_test_tp(model, common_llm_kwargs, per_test_common_llm_kwargs, @@ -171,4 +178,5 @@ def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs, batch_size, max_output_len=32, seed=seed, - temperature=0.0) + temperature=0.0, + logprobs=logprobs) diff --git a/tests/spec_decode/e2e/test_logprobs.py b/tests/spec_decode/e2e/test_logprobs.py index 4cfca8b78e79..1a543606cb3f 100644 --- a/tests/spec_decode/e2e/test_logprobs.py +++ b/tests/spec_decode/e2e/test_logprobs.py @@ -4,26 +4,27 @@ import pytest from vllm import SamplingParams +from ..utils import maybe_enable_chunked_prefill from .conftest import run_equality_correctness_test @pytest.mark.parametrize( "common_llm_kwargs", [{ - "model_name": "JackFram/llama-68m", + "model_name": "JackFram/llama-160m", # Skip cuda graph recording for fast test. - "enforce_eager": True, + "enforce_eager": True }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [{ - "speculative_model": "JackFram/llama-160m", + "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 3, "disable_logprobs_during_spec_decoding": False, }, { - "speculative_model": "JackFram/llama-160m", + "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 3, "disable_logprobs_during_spec_decoding": True, }]) @@ -36,12 +37,15 @@ from .conftest import run_equality_correctness_test ]) @pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("logprobs", [1, 6]) +@pytest.mark.parametrize("prefill_chunk_size", [-1, 4, 12]) def test_logprobs_equality(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, - seed: int, logprobs: int): - """Verify output logprobs are equal with and without speculative decoding. + seed: int, logprobs: int, prefill_chunk_size: int): + """Verify output logprobs are equal with and without speculative decoding, + as well as with and without chunked prefill. """ + maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs) run_equality_correctness_test(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, diff --git a/tests/spec_decode/e2e/test_medusa_correctness.py b/tests/spec_decode/e2e/test_medusa_correctness.py index b8965606b3d0..dbcbc0db1088 100644 --- a/tests/spec_decode/e2e/test_medusa_correctness.py +++ b/tests/spec_decode/e2e/test_medusa_correctness.py @@ -21,6 +21,7 @@ correctess for the target model outputs. import pytest +from ..utils import maybe_enable_chunked_prefill from .conftest import run_equality_correctness_test # main model @@ -67,12 +68,14 @@ PRECISION = "float32" ]) @pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("prefill_chunk_size", [-1, 32]) def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, - seed: int): + seed: int, prefill_chunk_size: int): """Verify greedy equality with different batch size.""" + maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) run_equality_correctness_test(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, @@ -119,12 +122,15 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("logprobs", [1, 6]) +@pytest.mark.parametrize("prefill_chunk_size", [-1, 32]) def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, - seed: int, logprobs: int): + seed: int, logprobs: int, + prefill_chunk_size: int): """Verify greedy equality with different batch size.""" + maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) run_equality_correctness_test(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, @@ -167,12 +173,14 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ]) @pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("prefill_chunk_size", [-1, 32]) def test_medusa_e2e_greedy_correctness_cuda_graph( vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, - seed: int): + seed: int, prefill_chunk_size: int): """Verify greedy equality with cuda graph enabled and different batch sizes.""" + maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) run_equality_correctness_test(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, @@ -217,13 +225,15 @@ def test_medusa_e2e_greedy_correctness_cuda_graph( ]) @pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("prefill_chunk_size", [-1, 32]) def test_medusa_e2e_greedy_correctness_with_preemption( vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, - seed: int): + seed: int, prefill_chunk_size: int): """Verify greedy equality, even when some sequences are preempted mid- generation. """ + maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) run_equality_correctness_test(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, @@ -267,13 +277,15 @@ def test_medusa_e2e_greedy_correctness_with_preemption( 32, ]) @pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("prefill_chunk_size", [-1, 32]) def test_medusa_different_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, - seed: int): + seed: int, prefill_chunk_size: int): """Verify that medusa speculative decoding produces exact equality to without spec decode with different values of num_speculative_tokens. """ + maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) run_equality_correctness_test(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, @@ -313,14 +325,17 @@ def test_medusa_different_k(vllm_runner, common_llm_kwargs, 32, ]) @pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("prefill_chunk_size", [-1, 32]) def test_medusa_disable_queue(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, - output_len: int, seed: int): + output_len: int, seed: int, + prefill_chunk_size: int): """Verify that medusa speculative decoding produces exact equality to without spec decode when speculation is disabled for large batch sizes. """ + maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) run_equality_correctness_test(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, @@ -361,12 +376,14 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs, 32, ]) @pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("prefill_chunk_size", [-1, 32]) def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, - output_len: int, seed: int): + output_len: int, seed: int, prefill_chunk_size: int): """Verify that speculative decoding generates the same output with batch expansion scorer and mqa scorer. """ + maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) run_equality_correctness_test(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py index 183ff2f5db27..1fa1104f5d3a 100644 --- a/tests/spec_decode/e2e/test_mlp_correctness.py +++ b/tests/spec_decode/e2e/test_mlp_correctness.py @@ -25,6 +25,7 @@ import pytest from vllm.model_executor.layers.vocab_parallel_embedding import pad_vocab_size +from ..utils import maybe_enable_chunked_prefill from .conftest import run_equality_correctness_test # main model @@ -66,14 +67,16 @@ PRECISION = "float32" @pytest.mark.parametrize("output_len", [ 128, ]) -@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("batch_size", [4, 32]) @pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("prefill_chunk_size", [-1, 32]) def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, - seed: int): + seed: int, prefill_chunk_size: int): """Verify greedy equality with different batch size.""" + maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) run_equality_correctness_test(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, @@ -116,12 +119,19 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("logprobs", [1, 6]) +@pytest.mark.parametrize("prefill_chunk_size", [-1, 4]) def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, seed: int, - logprobs: int): + logprobs: int, prefill_chunk_size: int): """Verify greedy equality with different batch size.""" + maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) + # NOTE Test is sensitive enough st if we don't enable chunked prefill + # scheduling on baseline too, we get slightly different logprobs, ending + # up sampling different tokens at the tail (ie top tokens don't change). + # TL;DR: sd+cp == org+cp but sd+cp != org..is this expected? + maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs) run_equality_correctness_test(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, @@ -162,12 +172,15 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("output_len", [2048]) @pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("prefill_chunk_size", [-1, 4]) def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, - batch_size: int, output_len: int, seed: int): + batch_size: int, output_len: int, + prefill_chunk_size: int, seed: int): """Verify acceptance rate with different batch size and large output length.""" + maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) run_equality_correctness_test(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, @@ -204,13 +217,17 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("output_len", [64]) @pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("temperature", [1.0]) +@pytest.mark.parametrize("prefill_chunk_size", [-1, 4]) @pytest.mark.parametrize("seed", [1]) def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, - temperature: float, seed: int): + temperature: float, + prefill_chunk_size: int, seed: int): """Verify seeded runs produce the same output.""" + maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) + maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs) run_equality_correctness_test(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, @@ -266,14 +283,16 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs, 128, ]) @pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("prefill_chunk_size", [-1, 4]) @pytest.mark.parametrize("seed", [1]) def test_mlp_e2e_greedy_correctness_with_preemption( vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, - seed: int): + prefill_chunk_size: int, seed: int): """Verify greedy equality, even when some sequences are preempted mid- generation. """ + maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) run_equality_correctness_test(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, @@ -317,12 +336,14 @@ def test_mlp_e2e_greedy_correctness_with_preemption( ]) @pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("prefill_chunk_size", [-1, 4]) def test_mlp_e2e_greedy_correctness_with_padding( vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, - seed: int): + prefill_chunk_size: int, seed: int): """Verify greedy equality when the vocab dimension is padded """ + maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) # Default pad_to is 64, test model has vocab_size of 32000 def patched_pad_vocab_size(vocab_size, pad_to=None): @@ -373,14 +394,16 @@ def test_mlp_e2e_greedy_correctness_with_padding( # Use smaller output len for fast test. 32, ]) +@pytest.mark.parametrize("prefill_chunk_size", [-1, 4]) @pytest.mark.parametrize("seed", [1]) def test_mlp_different_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, - test_llm_kwargs, batch_size: int, seed: int, - output_len: int): + test_llm_kwargs, batch_size: int, + prefill_chunk_size: int, seed: int, output_len: int): """Verify that mlp speculative decoding produces exact equality to without spec decode with different values of num_speculative_tokens. """ + maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) run_equality_correctness_test(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, @@ -418,15 +441,21 @@ def test_mlp_different_k(vllm_runner, common_llm_kwargs, # Use smaller output len for fast test. 32, ]) +# Speculative decoding is disabled when sequences reach decoding and the batch +# consists of single-token requests. Hence we set `max_num_seqs` +# >= `speculative_disable_by_batch_size` to test feature interaction. +@pytest.mark.parametrize("prefill_chunk_size", [-1, 4]) @pytest.mark.parametrize("seed", [1]) def test_mlp_disable_queue(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, - test_llm_kwargs, batch_size: int, seed: int, + test_llm_kwargs, batch_size: int, + prefill_chunk_size: int, seed: int, output_len: int): """Verify that mlp speculative decoding produces exact equality to without spec decode when speculation is disabled for large batch sizes. """ + maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) run_equality_correctness_test(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, @@ -460,13 +489,15 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs, # Use smaller output len for fast test. 32, ]) +@pytest.mark.parametrize("prefill_chunk_size", [-1, 4]) @pytest.mark.parametrize("seed", [1]) def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, - output_len: int, seed: int): + output_len: int, prefill_chunk_size: int, seed: int): """Verify that speculative decoding generates the same output with batch expansion scorer and mqa scorer. """ + maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) run_equality_correctness_test(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index a13cca41f99e..05ad468dd8bc 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -147,20 +147,20 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator, }, ]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - "enable_chunked_prefill": False, - }, - { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - "enable_chunked_prefill": True, - "max_num_batched_tokens": 4, - "max_num_seqs": 4, - }, -]) +@pytest.mark.parametrize("test_llm_kwargs", + [{ + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "enable_chunked_prefill": False, + "disable_logprobs_during_spec_decoding": False + }, { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 3, + "enable_chunked_prefill": True, + "max_num_batched_tokens": 4, + "max_num_seqs": 4, + "disable_logprobs_during_spec_decoding": False + }]) @pytest.mark.parametrize( "output_len", [ @@ -192,6 +192,9 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( batch_size, max_output_len=output_len, seed=seed, + prompt_logprobs=2, + logprobs=2, + disable_logprobs=False, temperature=0.0, ensure_all_accepted=ensure_all_accepted) diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py index e53d169a8fcc..77f8b8998c8d 100644 --- a/tests/spec_decode/e2e/test_ngram_correctness.py +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -26,6 +26,7 @@ for the target model outputs. import pytest +from ..utils import maybe_enable_chunked_prefill from .conftest import run_equality_correctness_test @@ -49,11 +50,13 @@ from .conftest import run_equality_correctness_test "speculative_model": "[ngram]", "num_speculative_tokens": 5, "ngram_prompt_lookup_max": 3, + "speculative_disable_mqa_scorer": False, }, { "speculative_model": "[ngram]", "num_speculative_tokens": 5, "ngram_prompt_lookup_max": 3, + "speculative_disable_mqa_scorer": True, }, ]) @pytest.mark.parametrize("output_len", [ @@ -68,15 +71,7 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, batch_size: int, output_len: int, prefill_chunk_size: int, seed: int): """Verify greedy equality on a tiny model with different batch size.""" - if prefill_chunk_size > 0: - common_llm_kwargs.update( - **{ - "enable_chunked_prefill": True, - "max_num_batched_tokens": prefill_chunk_size, - "max_num_seqs": prefill_chunk_size - }) - else: - common_llm_kwargs["enable_chunked_prefill"] = False + maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs) run_equality_correctness_test(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, diff --git a/tests/spec_decode/test_scorer.py b/tests/spec_decode/test_scorer.py index 0b1509d8b778..5a093dea16d4 100644 --- a/tests/spec_decode/test_scorer.py +++ b/tests/spec_decode/test_scorer.py @@ -60,6 +60,7 @@ def test_scorer(model_name: str, batch_size: int, max_propose_len: int, num_gpu_blocks = 2048 // block_size scorer_worker = create_worker(Worker, model_name, block_size, num_gpu_blocks, seed) + scorer_worker.model_runner.disable_logprobs = True # accessed by mqa_scorer scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor = True scorer_worker.model_runner.model.sampler.\ should_modify_greedy_probs_inplace = True diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index caf7a7e625b4..d8c3af4c1cd1 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -754,6 +754,7 @@ def test_populate_seq_ids_with_bonus_tokens(): seq_group_metadata_list=seq_group_metadata_list, accepted_token_ids=accepted_token_ids, target_logprobs=target_token_logprobs, + prompt_logprobs=None, k=k, stage_times=(0, 0, 0)) # Verify that _seq_with_bonus_token_in_last_step contains the following: diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index a4bfa6b2f384..2f883c2ff9b7 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -274,3 +274,15 @@ def create_batch(batch_size, prompts, num_gpu_blocks, block_size, final_prompt_lens, prev_output_tokens, seq_ids) return seq_group_metadata_list, prompts, prev_output_tokens + + +def maybe_enable_chunked_prefill(prefill_chunk_size, llm_kwargs): + if prefill_chunk_size > 0: + llm_kwargs.update( + **{ + "enable_chunked_prefill": True, + "max_num_batched_tokens": prefill_chunk_size, + "max_num_seqs": prefill_chunk_size + }) + else: + llm_kwargs["enable_chunked_prefill"] = False diff --git a/vllm/config.py b/vllm/config.py index dc1d61111548..7ab632d7e366 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1685,7 +1685,8 @@ class SpeculativeConfig: raise ValueError("Expect the batch size threshold of disabling " "speculative decoding is > 1, but got " f"{speculative_disable_by_batch_size=}") - + if (enable_chunked_prefill and speculative_model == "eagle"): + raise ValueError("Chunked prefill and EAGLE are not compatible.") # TODO: The user should be able to specify revision/max model len # for the draft model. It is not currently supported. draft_revision = None @@ -1752,12 +1753,6 @@ class SpeculativeConfig: f"num_speculative_tokens={n_predict}, but " f"{num_speculative_tokens=} was provided.") - if enable_chunked_prefill and draft_hf_config.model_type in ( - "medusa", "mlp_speculator", "eagle"): - raise ValueError( - "Chunked prefill and hidden-state based draft models are " - "not compatible.") - speculative_draft_tensor_parallel_size = \ SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size( target_parallel_config, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 7da18d5f7d2e..ab67ae29723c 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1010,8 +1010,23 @@ class LLMEngine: self.speculative_config # Organize outputs by [step][sequence group] instead of # [sequence group][step]. - outputs_by_sequence_group = create_output_by_sequence_group( - outputs, num_seq_groups=len(seq_group_metadata_list)) + if self.scheduler_config.is_multi_step: + outputs_by_sequence_group = create_output_by_sequence_group( + outputs, len(seq_group_metadata_list)) + elif self.speculative_config: + # Decodes are multi-steps while prefills are not, outputting at + # most 1 token. Separate them so that we can trigger chunk + # processing without having to pad or copy over prompts K times + # to match decodes structure (costly with prompt_logprobs). + num_prefills = sum(sg.is_prompt + for sg in seq_group_metadata_list) + prefills, decodes = outputs[:num_prefills], outputs[ + num_prefills:] + outputs_by_sequence_group = create_output_by_sequence_group( + decodes, + num_seq_groups=len(seq_group_metadata_list) - num_prefills) + outputs_by_sequence_group = [p.outputs for p in prefills + ] + outputs_by_sequence_group # We have outputs for multiple steps submitted in a single burst, # so invalidate is_first_step_output. is_first_step_output = None diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index 01b9cdad963d..56fb9ba506a4 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -83,13 +83,13 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): if not non_spec_indices: # All sequence groups in batch have spec decoding enabled - contracted = self._contract_batch_all_spec( + return self._contract_batch_all_spec( target_sampler_output=target_sampler_output, proposals=proposals, ) else: # Batch has a mix of spec decode enabled and disabled seq groups - contracted = self._contract_batch( + return self._contract_batch( execute_model_req.seq_group_metadata_list, target_sampler_output=target_sampler_output, proposals=proposals, @@ -99,14 +99,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): k=execute_model_req.num_lookahead_slots, ) - all_tokens, all_probs, spec_logprobs, all_hidden_states = contracted - return SpeculativeScores( - probs=all_probs, - token_ids=all_tokens, - logprobs=spec_logprobs, - hidden_states=all_hidden_states, - ) - def _expand_batch( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -143,13 +135,57 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): return (spec_indices, non_spec_indices, target_seq_group_metadata_list, num_scoring_tokens) + def _contract_non_speculative( + self, scores: SpeculativeScores, + seq_group_metadata_list: List[SequenceGroupMetadata], + non_spec_indices: List[int], non_spec_outputs: SpeculativeScores, + has_prompt_log: bool) -> SpeculativeScores: + """ + Augment input `scores` with non-speculative requests outputs. + This includes decode requests with speculation turned off, as well + as prefill requests when `enable_chunked_prefill` is set. + For the latter, prefills are further separated into terminal and + non-terminal chunks (from which no token is sampled). + """ + if not non_spec_indices: + return scores + + if has_prompt_log: + # When prompt_logprobs is enabled, prefills yield output token + # (and respective prob) in the last entry (prompt|out): + # [.|.|.|prefill0_out|.|prefill1_out|decode0_out|..]. + # With chunked prefill, non-terminal chunks have -1 on each + # position: they're still picked, but they're discarded later. + seq_meta = seq_group_metadata_list + nospec_sizes = torch.tensor([ + seq_meta[i].token_chunk_size if seq_meta[i].is_prompt else 1 + for i in non_spec_indices + ]) + nospec_sampled_token_idxs = torch.cumsum(nospec_sizes, 0).add_(-1) + else: + # In this case only sampled tokens are returned, select all. + nospec_sampled_token_idxs = list( + range(len(non_spec_outputs.token_ids))) + + scores.token_ids[non_spec_indices, :1] = \ + non_spec_outputs.token_ids[nospec_sampled_token_idxs].unsqueeze(1) + scores.probs[non_spec_indices, :1, :] = \ + non_spec_outputs.probs[nospec_sampled_token_idxs].unsqueeze(1) + scores.logprobs[non_spec_indices, :1, :] = \ + non_spec_outputs.logprobs[nospec_sampled_token_idxs].unsqueeze(1) + if scores.hidden_states is not None: + assert non_spec_outputs.hidden_states is not None + scores.hidden_states[non_spec_indices, :1, :] = \ + non_spec_outputs.hidden_states[nospec_sampled_token_idxs].unsqueeze(1) + return scores + def _contract_batch( - self, contracted_seq_group_metadata_list: List[SequenceGroupMetadata], - target_sampler_output: SamplerOutput, proposals: SpeculativeProposals, - num_scoring_tokens: int, non_spec_indices: List[int], - spec_indices: List[int], k: int - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, - Optional[torch.Tensor]]: + self, + contracted_seq_group_metadata_list: List[SequenceGroupMetadata], + target_sampler_output: SamplerOutput, + proposals: SpeculativeProposals, num_scoring_tokens: int, + non_spec_indices: List[int], spec_indices: List[int], + k: int) -> SpeculativeScores: """Contract the expanded batch back into its original size. This maps the scores of speculative tokens back to their original sequences. @@ -195,23 +231,28 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): else: all_hidden_states = None - # Rule out prefills that produce no tokens. - non_spec_indices = [ - idx for idx in non_spec_indices - if contracted_seq_group_metadata_list[idx].do_sample - ] - if len(non_spec_indices): - all_tokens[non_spec_indices, :1] = \ - non_spec_target_token_ids.unsqueeze(1) - all_probs[non_spec_indices, :1, :] = \ - non_spec_target_probs.unsqueeze(1) - all_logprobs[non_spec_indices, :1, :] = \ - non_spec_target_logprobs.unsqueeze(1) - if all_hidden_states is not None: - assert non_spec_target_hidden_states is not None - all_hidden_states[non_spec_indices, :1, :] = \ - non_spec_target_hidden_states.unsqueeze(1) + has_prompt_log = any((sg.sampling_params.prompt_logprobs + and sg.sampling_params.prompt_logprobs > 0) + for sg in contracted_seq_group_metadata_list) + # When prompt logprobs is enabled, lens of returned tensors go from + # n_sampled (requests with do_sample=True) to n_prompt+n_prefills. + # We adjust stride accordingly to get the generated tokens and + # their probs, but pass on prompt_logprobs as is. + prompt_logprobs = None + if (not self._scorer_worker.model_runner.disable_logprobs\ + and has_prompt_log): + prompt_logprobs = [ + o.prompt_logprobs for o in target_sampler_output.outputs + ] + elif not has_prompt_log: + # When prompt logprobs are not to be returned, + # we can ignore non-terminal chunks (no out token). + non_spec_indices = [ + idx for idx in non_spec_indices + if contracted_seq_group_metadata_list[idx].do_sample + ] + # "Contract" speculative. if spec_indices: all_tokens[spec_indices] = target_token_ids all_probs[spec_indices] = target_probs @@ -219,14 +260,27 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): if all_hidden_states is not None: all_hidden_states[spec_indices] = target_hidden_states - return all_tokens, all_probs, all_logprobs, all_hidden_states + spec_scores = SpeculativeScores(probs=all_probs, + token_ids=all_tokens, + logprobs=all_logprobs, + hidden_states=all_hidden_states, + prompt_logprobs=prompt_logprobs) + + non_spec_outputs = SpeculativeScores( + probs=non_spec_target_probs, + token_ids=non_spec_target_token_ids, + logprobs=non_spec_target_logprobs, + hidden_states=non_spec_target_hidden_states) + # Contract remaining nonspec entries based on non_spec_indices, if any. + return self._contract_non_speculative( + spec_scores, contracted_seq_group_metadata_list, non_spec_indices, + non_spec_outputs, has_prompt_log) def _contract_batch_all_spec( self, target_sampler_output: SamplerOutput, proposals: SpeculativeProposals, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, - Optional[torch.Tensor]]: + ) -> SpeculativeScores: """Contract the expanded batch back into its original size. This maps the scores of speculative tokens back to their original sequences. @@ -250,8 +304,11 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): target_hidden_states = target_hidden_states.reshape( *target_token_ids.shape, target_hidden_states.shape[-1]) - return (target_token_ids, target_probs, target_logprobs, - target_hidden_states) + return SpeculativeScores(probs=target_probs, + token_ids=target_token_ids, + logprobs=target_logprobs, + hidden_states=target_hidden_states, + prompt_logprobs=None) def _create_scoring_model_input( self, diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py index a4fe0f13c8db..c39e98b6cca1 100644 --- a/vllm/spec_decode/interfaces.py +++ b/vllm/spec_decode/interfaces.py @@ -1,10 +1,10 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Optional, Set, Union +from typing import List, Optional, Set, Union import torch -from vllm.sequence import ExecuteModelRequest +from vllm.sequence import ExecuteModelRequest, PromptLogprobs from vllm.worker.worker_base import WorkerBase @@ -54,6 +54,10 @@ class SpeculativeScores: # Optional last hidden states from the scoring model. hidden_states: Optional[torch.Tensor] = None + # Scoring model may also return logprobs for prompt tokens + # for each request, when chunked prefill is enabled. + prompt_logprobs: Optional[List[PromptLogprobs]] = None + def __repr__(self): return (f"SpeculativeScores(" f"probs={self.probs.shape}, " diff --git a/vllm/spec_decode/mqa_scorer.py b/vllm/spec_decode/mqa_scorer.py index cbf793e2043e..3aea2eabb414 100644 --- a/vllm/spec_decode/mqa_scorer.py +++ b/vllm/spec_decode/mqa_scorer.py @@ -72,9 +72,15 @@ class MQAScorer(SpeculativeScorer): target_token_ids = target_sampler_output.sampled_token_ids target_probs = target_sampler_output.sampled_token_probs target_logprobs = target_sampler_output.logprobs + prompt_logprobs = None + # If all requests have the same number of query tokens, we can avoid # the for loop to build output for better performance. if min(all_proposal_lengths) == k: + # Regular decodes only. + assert all(not sg.is_prompt + for sg in target_seq_group_metadata_list + if sg.is_prompt) bs, _ = proposals.proposal_token_ids.shape all_tokens = target_token_ids.reshape(bs, k + 1) all_probs = target_probs.reshape(bs, k + 1, self._vocab_size) @@ -88,19 +94,56 @@ class MQAScorer(SpeculativeScorer): all_logprobs = target_logprobs.new_full(size=all_probs.shape, fill_value=-float("inf")) target_token_ids = target_token_ids.flatten() - start_loc = 0 - for i, (proposed_len, seq_meta) in enumerate( - zip(all_proposal_lengths, target_seq_group_metadata_list)): + + # When prompt logprobs is enabled, lens of returned tensors go from + # n_sampled (requests with do_sample=True) to n_prompt+n_prefills. + # We adjust stride accordingly to get the generated tokens and + # their probs, but pass on prompt_logprobs as is, since it may be + # that n_prompts >> K. + has_prompt_log = any((sg.sampling_params.prompt_logprobs + and sg.sampling_params.prompt_logprobs > 0) + for sg in target_seq_group_metadata_list) + # TODO (NickLucche) we should surface `disable_logprobs` as to not + # break abstraction to get its value. + if (not self._scorer_worker.model_runner.disable_logprobs\ + and has_prompt_log): + prompt_logprobs = [ + o.prompt_logprobs for o in target_sampler_output.outputs + ] + + # Split loop into prefill|decode for readability. + start_loc, i = 0, 0 + while i < len(target_seq_group_metadata_list + ) and target_seq_group_metadata_list[i].is_prompt: + seq_meta = target_seq_group_metadata_list[i] + end_loc = start_loc + if has_prompt_log: + end_loc += seq_meta.token_chunk_size + elif seq_meta.do_sample: + end_loc += 1 + # Skip chunks with no output tokens. if seq_meta.do_sample: - output_len = proposed_len + 1 - end_loc = start_loc + output_len - all_tokens[ - i, :output_len] = target_token_ids[start_loc:end_loc] - all_probs[i, :output_len] = target_probs[start_loc:end_loc] - all_logprobs[ - i, :output_len] = target_logprobs[start_loc:end_loc] - start_loc = end_loc + # Get sampled token (last position in chunk) and its prob. + all_tokens[i, 0] = target_token_ids[end_loc - 1] + all_probs[i, 0] = target_probs[end_loc - 1] + all_logprobs[i, 0] = target_logprobs[end_loc - 1] + + i += 1 + start_loc = end_loc + # Decodes. + while i < len(target_seq_group_metadata_list): + proposed_len, seq_meta = all_proposal_lengths[ + i], target_seq_group_metadata_list[i] + output_len = proposed_len + 1 + end_loc = start_loc + output_len + all_tokens[ + i, :output_len] = target_token_ids[start_loc:end_loc] + all_probs[i, :output_len] = target_probs[start_loc:end_loc] + all_logprobs[ + i, :output_len] = target_logprobs[start_loc:end_loc] + start_loc = end_loc + i += 1 hidden_states = None if target_sampler_output.hidden_states is not None: @@ -110,4 +153,5 @@ class MQAScorer(SpeculativeScorer): return SpeculativeScores(probs=all_probs, token_ids=all_tokens, logprobs=all_logprobs, - hidden_states=hidden_states) + hidden_states=hidden_states, + prompt_logprobs=prompt_logprobs) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 0d66ede3d907..8e9802c7d333 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -563,50 +563,57 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): (seq_id, seq_data) for sg in \ execute_model_req.seq_group_metadata_list \ for seq_id, seq_data in sg.seq_data.items() - if sg.do_sample # ignore empty token sequences ] completion_seq_group_output_list: List[ CompletionSequenceGroupOutput] = [] output_index = 0 # Make sure the non-terminal prefill chunks are still aligned with # their own empty output. - for seq_group_meta in execute_model_req.seq_group_metadata_list: + for idx, seq_group_meta in enumerate( + execute_model_req.seq_group_metadata_list): + needs_prompt_logprobs = seq_output_prompt_logprobs[idx] + seq_id, seq_data = seq_data_entries[idx] + if needs_prompt_logprobs: + prompt_token_ids = seq_data.get_prompt_token_ids() + + # Some of these sequences may belong to non-terminal chunks, + # which may still have to report logprobs for prompts. + start = 1 if seq_data._num_computed_tokens == 0 \ + else seq_data._num_computed_tokens + end = (seq_data._num_computed_tokens + \ + seq_group_meta.token_chunk_size) + prompt_token_ids = prompt_token_ids[start:end] + prompt_logprobs = [ + create_logprobs_output( + token_id=p_token_id, + token_id_logprob_rank=-1, + token_id_logprob=0.0, + topk_token_ids=[], + topk_logprobs=[], + ) for p_token_id in prompt_token_ids + ] + else: + prompt_logprobs = None + # Since we can get chunks here, we dont always have a sampled token # (only on last chunk) but we still have to provide an output. if not seq_group_meta.do_sample: completion_seq_group_output_list.append( - CompletionSequenceGroupOutput(samples=[], - prompt_logprobs=None)) - else: - # Sequence with output. - seq_id, seq_data = seq_data_entries[output_index] - needs_prompt_logprobs = seq_output_prompt_logprobs[ - output_index] - if needs_prompt_logprobs: - prompt_token_ids = seq_data.get_prompt_token_ids() - prompt_logprobs = [ - create_logprobs_output( - token_id=p_token_id, - token_id_logprob_rank=-1, - token_id_logprob=0.0, - topk_token_ids=[], - topk_logprobs=[], - ) - # no prompt logprobs for the first token - for p_token_id in prompt_token_ids[1:] - ] - else: - prompt_logprobs = None - completion_seq_group_output_list.append( - create_sequence_group_output( - token_id=sampled_token_ids_list[output_index][0], - token_id_logprob_rank=-1, - token_id_logprob=0.0, - seq_id=seq_id, - topk_token_ids=[], - topk_logprobs=[], - prompt_logprobs=prompt_logprobs)) - output_index += 1 + CompletionSequenceGroupOutput( + samples=[], prompt_logprobs=prompt_logprobs)) + continue + + # Sequence with output. + completion_seq_group_output_list.append( + create_sequence_group_output( + token_id=sampled_token_ids_list[output_index][0], + token_id_logprob_rank=-1, + token_id_logprob=0.0, + seq_id=seq_id, + topk_token_ids=[], + topk_logprobs=[], + prompt_logprobs=prompt_logprobs)) + output_index += 1 return [SamplerOutput(outputs=completion_seq_group_output_list)] @@ -624,24 +631,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): assert len(sampler_output) == 1 sampler_output = sampler_output[0] - # Store hidden states from target model execution. + # Store hidden states from target model execution, BxD. hidden_states = sampler_output.hidden_states if hidden_states is not None: - # remove hidden_states for prompt tokens - # TODO Enable `return_hidden_states`: prefill chunks hidden states - # are pruned by the logits processor. Also, they should be arranged - # back into full-prefill latent. Address it to enable MLPSpeculator. - if any(seq.is_prompt - for seq in execute_model_req.seq_group_metadata_list): + # Only decodes and prefill terminal chunks need a hidden state. + seq_group_meta_with_hidden = [ + sg for sg in execute_model_req.seq_group_metadata_list + if sg.do_sample + ] + if any(seq.is_prompt for seq in seq_group_meta_with_hidden): + # Drop hidden_states with no prediction (eg non-terminal chunks) hidden_states = hidden_states[ torch.where(sampler_output.sampled_token_ids - VLLM_INVALID_TOKEN_ID)[0]] - if self.previous_hidden_states is None: + if self.previous_hidden_states is None and len( + seq_group_meta_with_hidden): self.previous_hidden_states = HiddenStates( - hidden_states, execute_model_req.seq_group_metadata_list) - else: - self.previous_hidden_states.update( - hidden_states, execute_model_req.seq_group_metadata_list) + hidden_states, seq_group_meta_with_hidden) + elif self.previous_hidden_states and len( + seq_group_meta_with_hidden): + self.previous_hidden_states.update(hidden_states, + seq_group_meta_with_hidden) if not skip_proposer: # We prepare the prefill hidden states here so that there no @@ -752,13 +762,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ] if len(non_spec_indices): all_hidden_states = proposal_scores.hidden_states - # TODO fix `return_hidden_states`, same as in `_run_no_spec` if all_hidden_states is not None: prefill_hidden_states = all_hidden_states[non_spec_indices] execute_model_req.previous_hidden_states = \ prepare_prefill_hidden_states(prefill_hidden_states) # Sync proposer KV cache for prefills. prefill_req = execute_model_req.clone(non_spec_seqs) + # TODO avoid sampling here? self.proposer_worker.execute_model(prefill_req) with Timer() as verification_timer: @@ -774,6 +784,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): execute_model_req.seq_group_metadata_list, accepted_token_ids, target_logprobs=target_logprobs, + prompt_logprobs=proposal_scores.prompt_logprobs + if not self._disable_logprobs else None, k=execute_model_req.num_lookahead_slots, stage_times=stage_times) @@ -845,19 +857,32 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): # metadata. accepted_token_ids[original_indices] = accepted_token_ids.clone() + # B x K+1 x D hidden_states = proposal_scores.hidden_states if hidden_states is not None: + # Only get terminal hidden states for next step + terminal_metadata = [ + sg for sg in seq_group_metadata_list if sg.do_sample + ] + # Contract hidden states based on accepted tokens hs_size = hidden_states.shape[-1] - accepted_index = accepted_token_ids + 1 # Convert -1 to 0 - accepted_index = accepted_index.count_nonzero(dim=1).add_(-1) - index = accepted_index[:, None, None].expand(-1, 1, hs_size) + accepted_index = accepted_index.count_nonzero(dim=1).add_(-1) # b + # Drop non-terminal prefill chunks hidden states. + hidden_states = hidden_states[ + accepted_index != VLLM_INVALID_TOKEN_ID] + accepted_index = accepted_index[ + accepted_index != VLLM_INVALID_TOKEN_ID] + assert len(accepted_index) == hidden_states.shape[0] == len( + terminal_metadata) + index = accepted_index[:, None, None].expand(-1, 1, + hs_size) # b x 1 x d second_last_token_hidden_states = hidden_states[:, -2] # b x d hidden_states = hidden_states.gather(1, index).squeeze(1) # b x d # Store hidden states from target model for subsequent decode step self.previous_hidden_states = HiddenStates( - hidden_states, seq_group_metadata_list, + hidden_states, terminal_metadata, second_last_token_hidden_states) return accepted_token_ids, logprobs @@ -866,6 +891,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): seq_group_metadata_list: List[SequenceGroupMetadata], accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1] target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size] + prompt_logprobs: Optional[ + torch.Tensor], # shape: [nprompt_tokens, vocab_size] k: int, stage_times: Tuple[float, float, float], ) -> List[SamplerOutput]: @@ -909,15 +936,89 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): # Construct the output on a per-step, per-sequence basis. # Non-terminal prefill chunks will end up here as rows with just -1s - # i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]] + # i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]] while + # terminal chunks will only have one generated token at time 0. sampler_output_list: List[SamplerOutput] = [] + + # Prefills are not multi-step (return at most 1 token), in order to + # avoid padding or repetition to fit decodes, we separate them. + for i, sg in enumerate(seq_group_metadata_list): + if not sg.is_prompt: + # Requests are ordered as prefills|decodes=>no more prefills. + break + num_logprobs = num_logprobs_per_seq[i] + seq_kwargs = dict(token_id=-1, + token_id_logprob_rank=0, + token_id_logprob=-float('inf'), + topk_token_ids=[-1] * num_logprobs, + topk_logprobs=[-float('inf')] * num_logprobs, + seq_id=seq_ids[i]) + # Terminal chunk, has token. + if sg.do_sample: + seq_kwargs.update( + dict( + token_id=accepted_token_ids[i][0].item(), + token_id_logprob_rank=accepted_token_id_ranks_by_step[ + 0][i], + token_id_logprob=accepted_token_id_logprobs_by_step[0] + [i], + topk_token_ids=topk_indices_by_step[0][i] + [:num_logprobs], + # output only so step is 0 + topk_logprobs=topk_logprobs_by_step[0][i] + [:num_logprobs], + )) + needs_plogs = (sg.sampling_params.prompt_logprobs + and sg.sampling_params.prompt_logprobs > 0) + plogs = None + if prompt_logprobs is not None: + # Even non-terminal prompt chunks can have logprobs here. + plogs = prompt_logprobs[i] + elif needs_plogs: + # Prompt logprobs are requested but `_disable_logprobs` is set. + seq_data = next(iter(sg.seq_data.values())) + # Get only the tokens in this chunk! + prompt_token_ids = seq_data.get_prompt_token_ids() + prompt_token_ids = prompt_token_ids[ + seq_data. + _num_computed_tokens:seq_data._num_computed_tokens + + sg.token_chunk_size] + + is_first_chunk = seq_data._num_computed_tokens == 0 + # There's no prob generated for the first token in a sequence. + if is_first_chunk: + prompt_token_ids = prompt_token_ids[1:] + plogs = [ + create_logprobs_output( + token_id=p_token_id, + token_id_logprob_rank=-1, + token_id_logprob=0.0, + topk_token_ids=[], + topk_logprobs=[], + ) for p_token_id in prompt_token_ids + ] + seq_kwargs.update(dict(prompt_logprobs=plogs)) + + sampler_output_list.append( + SamplerOutput( + outputs=[create_sequence_group_output( + **seq_kwargs)])) # type: ignore + + # Decodes, create one SamplerOutput per-step (at most K+1). for step_index in range(num_steps): - if all(token_id == -1 - for token_id in accepted_token_ids_by_step[step_index]): + if all(token_id == -1 for sg, token_id in zip( + seq_group_metadata_list, + accepted_token_ids_by_step[step_index]) + if not sg.is_prompt): break step_output_token_ids: List[CompletionSequenceGroupOutput] = [] for sequence_index in range(batch_size): + seq_meta = seq_group_metadata_list[sequence_index] + # Prompts already processed above. + if seq_meta.is_prompt: + continue + # Each sequence may have a different num_logprobs; retrieve it. num_logprobs = num_logprobs_per_seq[sequence_index] step_output_token_ids.append( @@ -952,6 +1053,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): # This is periodic because the rejection sampler emits metrics # periodically. self._maybe_log_stage_times(*stage_times) + # First `n_prefills` entries will contain prefills SamplerOutput when + # chunked prefill is enabled, the rest is decodes in multi-step format. return sampler_output_list def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,