Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-11 09:45:42 +08:00)

[Feature] [Spec decode]: Enable MLPSpeculator/Medusa and prompt_logprobs with ChunkedPrefill (#10132)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: wallashss <wallashss@ibm.com>
Co-authored-by: wallashss <wallashss@ibm.com>

parent 2bc3fbba0c
commit 6116ca8cd7
@@ -2,6 +2,7 @@ from itertools import cycle
 from typing import List, Optional, Sequence, Tuple, Union

 import pytest
+import torch

 from vllm import LLM, SamplingParams
 from vllm.distributed import cleanup_dist_env_and_memory
@@ -154,6 +155,8 @@ def _check_logprobs_when_output_disabled(
             spec_pos_logprob) = next(iter(spec_pos_logprobs.items()))
         assert spec_pos_logprob.rank == -1
         assert spec_pos_logprob.logprob == 0.0
+        if isinstance(spec_pos_logprob_token_id, torch.Tensor):
+            spec_pos_logprob_token_id = spec_pos_logprob_token_id.item()
         assert spec_pos_logprob_token_id in baseline_pos_logprobs


@@ -244,7 +247,8 @@ def run_equality_correctness_test_tp(model,
                                      batch_size: int,
                                      max_output_len: int,
                                      seed: int = 0,
-                                     temperature: float = 0.0):
+                                     temperature: float = 0.0,
+                                     logprobs: Optional[int] = None):
     """Helper method that compares the outputs of both the baseline LLM and
     the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
     the same when temperature is zero.
@@ -257,7 +261,6 @@ def run_equality_correctness_test_tp(model,
     results = []

     prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
-
     for args, env in ((arg1, env1), (arg2, env2)):
         with RemoteOpenAIServer(model,
                                 args,
@@ -269,12 +272,14 @@ def run_equality_correctness_test_tp(model,
                 prompt=prompts,
                 max_tokens=max_output_len,
                 seed=seed,
-                temperature=temperature)
+                temperature=temperature,
+                logprobs=logprobs)

             results.append({
                 "test":
                 "seeded_sampling",
                 "text": [choice.text for choice in completion.choices],
+                "logprobs": [choice.logprobs for choice in completion.choices],
                 "finish_reason":
                 [choice.finish_reason for choice in completion.choices],
                 "usage":
@@ -284,7 +289,15 @@ def run_equality_correctness_test_tp(model,
     n = len(results) // 2
     arg1_results = results[:n]
     arg2_results = results[n:]
+    # Separate logprobs to avoid asserting exact equality.
+    arg1_logprobs = [r.pop("logprobs") for r in arg1_results]
+    arg2_logprobs = [r.pop("logprobs") for r in arg2_results]
+
     for arg1_result, arg2_result in zip(arg1_results, arg2_results):
         assert arg1_result == arg2_result, (
             f"Results for {model=} are not the same with {arg1=} and {arg2=}. "
             f"{arg1_result=} != {arg2_result=}")
+    if logprobs:
+        for logs1, logs2 in zip(arg1_logprobs, arg2_logprobs):
+            for l1, l2 in zip(logs1, logs2):
+                assert l1.tokens == l2.tokens
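The run_equality_correctness_test_tp hunks above pop the logprobs out of each result dict before asserting exact equality, and afterwards only compare the returned tokens. A minimal, self-contained sketch of that comparison pattern; FakeLogprobs below is an illustrative stand-in for the OpenAI client logprobs objects the test actually receives, not part of the diff:

from dataclasses import dataclass
from typing import List


@dataclass
class FakeLogprobs:  # illustrative stand-in for the client-side logprobs object
    tokens: List[str]


def assert_same_tokens(logs1, logs2):
    # Only the decoded tokens are compared; the float logprob values are not
    # expected to match bit-for-bit across the two server configurations.
    for l1, l2 in zip(logs1, logs2):
        assert l1.tokens == l2.tokens


assert_same_tokens([FakeLogprobs(["a", "b"])], [FakeLogprobs(["a", "b"])])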
@@ -2,6 +2,8 @@
 tensor parallelism.
 """

+from typing import Optional
+
 import pytest
 import torch

@@ -154,15 +156,20 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
         "--speculative-draft-tensor-parallel-size",
         "1",
     ])])
+@pytest.mark.parametrize("logprobs", [None, 2])
 @pytest.mark.parametrize("batch_size", [2])
 @pytest.mark.parametrize("seed", [1])
 def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
                                          per_test_common_llm_kwargs,
                                          baseline_llm_kwargs, test_llm_kwargs,
+                                         logprobs: Optional[int],
                                          batch_size: int, seed: int):
     """Verify spec decode works well with same and different TP size for
     the draft model with chunked prefill.
     """
+    if logprobs:
+        test_llm_kwargs.extend(
+            ["--disable_logprobs_during_spec_decoding", "False"])
     run_equality_correctness_test_tp(model,
                                      common_llm_kwargs,
                                      per_test_common_llm_kwargs,
@@ -171,4 +178,5 @@ def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
                                      batch_size,
                                      max_output_len=32,
                                      seed=seed,
-                                     temperature=0.0)
+                                     temperature=0.0,
+                                     logprobs=logprobs)
@@ -4,26 +4,27 @@ import pytest

 from vllm import SamplingParams

+from ..utils import maybe_enable_chunked_prefill
 from .conftest import run_equality_correctness_test


 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
-        "model_name": "JackFram/llama-68m",
+        "model_name": "JackFram/llama-160m",

         # Skip cuda graph recording for fast test.
-        "enforce_eager": True,
+        "enforce_eager": True
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs",
                          [{
-                             "speculative_model": "JackFram/llama-160m",
+                             "speculative_model": "JackFram/llama-68m",
                              "num_speculative_tokens": 3,
                              "disable_logprobs_during_spec_decoding": False,
                          }, {
-                             "speculative_model": "JackFram/llama-160m",
+                             "speculative_model": "JackFram/llama-68m",
                              "num_speculative_tokens": 3,
                              "disable_logprobs_during_spec_decoding": True,
                          }])
@@ -36,12 +37,15 @@ from .conftest import run_equality_correctness_test
 ])
 @pytest.mark.parametrize("seed", [1])
 @pytest.mark.parametrize("logprobs", [1, 6])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 4, 12])
 def test_logprobs_equality(vllm_runner, common_llm_kwargs,
                            per_test_common_llm_kwargs, baseline_llm_kwargs,
                            test_llm_kwargs, batch_size: int, output_len: int,
-                           seed: int, logprobs: int):
-    """Verify output logprobs are equal with and without speculative decoding.
+                           seed: int, logprobs: int, prefill_chunk_size: int):
+    """Verify output logprobs are equal with and without speculative decoding,
+    as well as with and without chunked prefill.
     """
+    maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
@@ -21,6 +21,7 @@ correctess for the target model outputs.

 import pytest

+from ..utils import maybe_enable_chunked_prefill
 from .conftest import run_equality_correctness_test

 # main model
@@ -67,12 +68,14 @@ PRECISION = "float32"
 ])
 @pytest.mark.parametrize("batch_size", [1, 32])
 @pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
 def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                        per_test_common_llm_kwargs,
                                        baseline_llm_kwargs, test_llm_kwargs,
                                        batch_size: int, output_len: int,
-                                       seed: int):
+                                       seed: int, prefill_chunk_size: int):
     """Verify greedy equality with different batch size."""
+    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
@@ -119,12 +122,15 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize("seed", [1])
 @pytest.mark.parametrize("logprobs", [1, 6])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
 def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
                                     per_test_common_llm_kwargs,
                                     baseline_llm_kwargs, test_llm_kwargs,
                                     batch_size: int, output_len: int,
-                                    seed: int, logprobs: int):
+                                    seed: int, logprobs: int,
+                                    prefill_chunk_size: int):
     """Verify greedy equality with different batch size."""
+    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
@@ -167,12 +173,14 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
 ])
 @pytest.mark.parametrize("batch_size", [1, 32])
 @pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
 def test_medusa_e2e_greedy_correctness_cuda_graph(
         vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
         baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
-        seed: int):
+        seed: int, prefill_chunk_size: int):
     """Verify greedy equality with cuda graph enabled and different
     batch sizes."""
+    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
@@ -217,13 +225,15 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(
 ])
 @pytest.mark.parametrize("batch_size", [4])
 @pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
 def test_medusa_e2e_greedy_correctness_with_preemption(
         vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
         baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
-        seed: int):
+        seed: int, prefill_chunk_size: int):
     """Verify greedy equality, even when some sequences are preempted mid-
     generation.
     """
+    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
@@ -267,13 +277,15 @@ def test_medusa_e2e_greedy_correctness_with_preemption(
     32,
 ])
 @pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
 def test_medusa_different_k(vllm_runner, common_llm_kwargs,
                             per_test_common_llm_kwargs, baseline_llm_kwargs,
                             test_llm_kwargs, batch_size: int, output_len: int,
-                            seed: int):
+                            seed: int, prefill_chunk_size: int):
     """Verify that medusa speculative decoding produces exact equality
     to without spec decode with different values of num_speculative_tokens.
     """
+    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
@@ -313,14 +325,17 @@ def test_medusa_different_k(vllm_runner, common_llm_kwargs,
     32,
 ])
 @pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
 def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
                               per_test_common_llm_kwargs, baseline_llm_kwargs,
                               test_llm_kwargs, batch_size: int,
-                              output_len: int, seed: int):
+                              output_len: int, seed: int,
+                              prefill_chunk_size: int):
     """Verify that medusa speculative decoding produces exact equality
     to without spec decode when speculation is disabled for large
     batch sizes.
     """
+    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
@@ -361,12 +376,14 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
     32,
 ])
 @pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
 def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
                     baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
-                    output_len: int, seed: int):
+                    output_len: int, seed: int, prefill_chunk_size: int):
     """Verify that speculative decoding generates the same output
     with batch expansion scorer and mqa scorer.
     """
+    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
@@ -25,6 +25,7 @@ import pytest

 from vllm.model_executor.layers.vocab_parallel_embedding import pad_vocab_size

+from ..utils import maybe_enable_chunked_prefill
 from .conftest import run_equality_correctness_test

 # main model
@@ -66,14 +67,16 @@ PRECISION = "float32"
 @pytest.mark.parametrize("output_len", [
     128,
 ])
-@pytest.mark.parametrize("batch_size", [1, 32])
+@pytest.mark.parametrize("batch_size", [4, 32])
 @pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
 def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                     per_test_common_llm_kwargs,
                                     baseline_llm_kwargs, test_llm_kwargs,
                                     batch_size: int, output_len: int,
-                                    seed: int):
+                                    seed: int, prefill_chunk_size: int):
     """Verify greedy equality with different batch size."""
+    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
@@ -116,12 +119,19 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize("seed", [1])
 @pytest.mark.parametrize("logprobs", [1, 6])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
 def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size: int, output_len: int, seed: int,
-                                 logprobs: int):
+                                 logprobs: int, prefill_chunk_size: int):
     """Verify greedy equality with different batch size."""
+    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
+    # NOTE Test is sensitive enough st if we don't enable chunked prefill
+    # scheduling on baseline too, we get slightly different logprobs, ending
+    # up sampling different tokens at the tail (ie top tokens don't change).
+    # TL;DR: sd+cp == org+cp but sd+cp != org..is this expected?
+    maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
@@ -162,12 +172,15 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize("output_len", [2048])
 @pytest.mark.parametrize("batch_size", [1, 32])
 @pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
 def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
-                                 batch_size: int, output_len: int, seed: int):
+                                 batch_size: int, output_len: int,
+                                 prefill_chunk_size: int, seed: int):
     """Verify acceptance rate with different batch size and large output
     length."""
+    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
@@ -204,13 +217,17 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize("output_len", [64])
 @pytest.mark.parametrize("batch_size", [1, 32])
 @pytest.mark.parametrize("temperature", [1.0])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
 @pytest.mark.parametrize("seed", [1])
 def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
                                     per_test_common_llm_kwargs,
                                     baseline_llm_kwargs, test_llm_kwargs,
                                     batch_size: int, output_len: int,
-                                    temperature: float, seed: int):
+                                    temperature: float,
+                                    prefill_chunk_size: int, seed: int):
     """Verify seeded runs produce the same output."""
+    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
+    maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
@@ -266,14 +283,16 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
     128,
 ])
 @pytest.mark.parametrize("batch_size", [4])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
 @pytest.mark.parametrize("seed", [1])
 def test_mlp_e2e_greedy_correctness_with_preemption(
         vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
         baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
-        seed: int):
+        prefill_chunk_size: int, seed: int):
     """Verify greedy equality, even when some sequences are preempted mid-
     generation.
     """
+    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
@@ -317,12 +336,14 @@ def test_mlp_e2e_greedy_correctness_with_preemption(
 ])
 @pytest.mark.parametrize("batch_size", [4])
 @pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
 def test_mlp_e2e_greedy_correctness_with_padding(
         vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
         baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
-        seed: int):
+        prefill_chunk_size: int, seed: int):
     """Verify greedy equality when the vocab dimension is padded
     """
+    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)

     # Default pad_to is 64, test model has vocab_size of 32000
     def patched_pad_vocab_size(vocab_size, pad_to=None):
@@ -373,14 +394,16 @@ def test_mlp_e2e_greedy_correctness_with_padding(
     # Use smaller output len for fast test.
     32,
 ])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
 @pytest.mark.parametrize("seed", [1])
 def test_mlp_different_k(vllm_runner, common_llm_kwargs,
                          per_test_common_llm_kwargs, baseline_llm_kwargs,
-                         test_llm_kwargs, batch_size: int, seed: int,
-                         output_len: int):
+                         test_llm_kwargs, batch_size: int,
+                         prefill_chunk_size: int, seed: int, output_len: int):
     """Verify that mlp speculative decoding produces exact equality
     to without spec decode with different values of num_speculative_tokens.
     """
+    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
@@ -418,15 +441,21 @@ def test_mlp_different_k(vllm_runner, common_llm_kwargs,
     # Use smaller output len for fast test.
     32,
 ])
+# Speculative decoding is disabled when sequences reach decoding and the batch
+# consists of single-token requests. Hence we set `max_num_seqs`
+# >= `speculative_disable_by_batch_size` to test feature interaction.
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
 @pytest.mark.parametrize("seed", [1])
 def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
                            per_test_common_llm_kwargs, baseline_llm_kwargs,
-                           test_llm_kwargs, batch_size: int, seed: int,
+                           test_llm_kwargs, batch_size: int,
+                           prefill_chunk_size: int, seed: int,
                            output_len: int):
     """Verify that mlp speculative decoding produces exact equality
     to without spec decode when speculation is disabled for large
     batch sizes.
     """
+    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
@@ -460,13 +489,15 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
     # Use smaller output len for fast test.
     32,
 ])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
 @pytest.mark.parametrize("seed", [1])
 def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
                     baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
-                    output_len: int, seed: int):
+                    output_len: int, prefill_chunk_size: int, seed: int):
     """Verify that speculative decoding generates the same output
     with batch expansion scorer and mqa scorer.
     """
+    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
@@ -147,20 +147,20 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
     },
 ])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs", [
-    {
-        "speculative_model": "JackFram/llama-68m",
-        "num_speculative_tokens": 5,
-        "enable_chunked_prefill": False,
-    },
-    {
-        "speculative_model": "JackFram/llama-68m",
-        "num_speculative_tokens": 5,
-        "enable_chunked_prefill": True,
-        "max_num_batched_tokens": 4,
-        "max_num_seqs": 4,
-    },
-])
+@pytest.mark.parametrize("test_llm_kwargs",
+                         [{
+                             "speculative_model": "JackFram/llama-68m",
+                             "num_speculative_tokens": 5,
+                             "enable_chunked_prefill": False,
+                             "disable_logprobs_during_spec_decoding": False
+                         }, {
+                             "speculative_model": "JackFram/llama-68m",
+                             "num_speculative_tokens": 3,
+                             "enable_chunked_prefill": True,
+                             "max_num_batched_tokens": 4,
+                             "max_num_seqs": 4,
+                             "disable_logprobs_during_spec_decoding": False
+                         }])
 @pytest.mark.parametrize(
     "output_len",
     [
@@ -192,6 +192,9 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
                                   batch_size,
                                   max_output_len=output_len,
                                   seed=seed,
+                                  prompt_logprobs=2,
+                                  logprobs=2,
+                                  disable_logprobs=False,
                                   temperature=0.0,
                                   ensure_all_accepted=ensure_all_accepted)

@@ -26,6 +26,7 @@ for the target model outputs.

 import pytest

+from ..utils import maybe_enable_chunked_prefill
 from .conftest import run_equality_correctness_test


@@ -49,11 +50,13 @@ from .conftest import run_equality_correctness_test
         "speculative_model": "[ngram]",
         "num_speculative_tokens": 5,
         "ngram_prompt_lookup_max": 3,
+        "speculative_disable_mqa_scorer": False,
     },
     {
         "speculative_model": "[ngram]",
         "num_speculative_tokens": 5,
         "ngram_prompt_lookup_max": 3,
+        "speculative_disable_mqa_scorer": True,
     },
 ])
 @pytest.mark.parametrize("output_len", [
@@ -68,15 +71,7 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                       batch_size: int, output_len: int,
                                       prefill_chunk_size: int, seed: int):
     """Verify greedy equality on a tiny model with different batch size."""
-    if prefill_chunk_size > 0:
-        common_llm_kwargs.update(
-            **{
-                "enable_chunked_prefill": True,
-                "max_num_batched_tokens": prefill_chunk_size,
-                "max_num_seqs": prefill_chunk_size
-            })
-    else:
-        common_llm_kwargs["enable_chunked_prefill"] = False
+    maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
@@ -60,6 +60,7 @@ def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
     num_gpu_blocks = 2048 // block_size
     scorer_worker = create_worker(Worker, model_name, block_size,
                                   num_gpu_blocks, seed)
+    scorer_worker.model_runner.disable_logprobs = True  # accessed by mqa_scorer
     scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor = True
     scorer_worker.model_runner.model.sampler.\
         should_modify_greedy_probs_inplace = True
@@ -754,6 +754,7 @@ def test_populate_seq_ids_with_bonus_tokens():
         seq_group_metadata_list=seq_group_metadata_list,
         accepted_token_ids=accepted_token_ids,
         target_logprobs=target_token_logprobs,
+        prompt_logprobs=None,
         k=k,
         stage_times=(0, 0, 0))
     # Verify that _seq_with_bonus_token_in_last_step contains the following:
@@ -274,3 +274,15 @@ def create_batch(batch_size,
         prompts, num_gpu_blocks, block_size, final_prompt_lens,
         prev_output_tokens, seq_ids)
     return seq_group_metadata_list, prompts, prev_output_tokens
+
+
+def maybe_enable_chunked_prefill(prefill_chunk_size, llm_kwargs):
+    if prefill_chunk_size > 0:
+        llm_kwargs.update(
+            **{
+                "enable_chunked_prefill": True,
+                "max_num_batched_tokens": prefill_chunk_size,
+                "max_num_seqs": prefill_chunk_size
+            })
+    else:
+        llm_kwargs["enable_chunked_prefill"] = False
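For reference, this helper is what every test file in this commit now imports via `from ..utils import maybe_enable_chunked_prefill`. A short usage sketch, assuming the helper above is in scope; the kwargs values are illustrative:

# prefill_chunk_size <= 0 leaves chunked prefill disabled; a positive value
# enables it and caps max_num_batched_tokens / max_num_seqs at that size.
llm_kwargs = {"num_speculative_tokens": 3}

maybe_enable_chunked_prefill(4, llm_kwargs)
assert llm_kwargs["enable_chunked_prefill"] is True
assert llm_kwargs["max_num_batched_tokens"] == 4
assert llm_kwargs["max_num_seqs"] == 4

maybe_enable_chunked_prefill(-1, llm_kwargs)
assert llm_kwargs["enable_chunked_prefill"] is False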
@@ -1685,7 +1685,8 @@ class SpeculativeConfig:
             raise ValueError("Expect the batch size threshold of disabling "
                              "speculative decoding is > 1, but got "
                              f"{speculative_disable_by_batch_size=}")
+        if (enable_chunked_prefill and speculative_model == "eagle"):
+            raise ValueError("Chunked prefill and EAGLE are not compatible.")
         # TODO: The user should be able to specify revision/max model len
         # for the draft model. It is not currently supported.
         draft_revision = None
@@ -1752,12 +1753,6 @@ class SpeculativeConfig:
                         f"num_speculative_tokens={n_predict}, but "
                         f"{num_speculative_tokens=} was provided.")

-        if enable_chunked_prefill and draft_hf_config.model_type in (
-                "medusa", "mlp_speculator", "eagle"):
-            raise ValueError(
-                "Chunked prefill and hidden-state based draft models are "
-                "not compatible.")
-
         speculative_draft_tensor_parallel_size = \
             SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size(
                 target_parallel_config,
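The net effect of these two config hunks: the blanket rejection of chunked prefill with medusa/mlp_speculator/eagle drafts is dropped, and only the EAGLE case is still refused. A hedged sketch of the resulting validation behaviour as a standalone function, not the actual SpeculativeConfig code:

def check_chunked_prefill_compat(enable_chunked_prefill: bool,
                                 speculative_model: str) -> None:
    # Mirrors the remaining guard above: only the "eagle" draft is still
    # rejected when chunked prefill is on; Medusa / MLPSpeculator drafts pass.
    if enable_chunked_prefill and speculative_model == "eagle":
        raise ValueError("Chunked prefill and EAGLE are not compatible.")


check_chunked_prefill_compat(True, "a-medusa-or-mlp-speculator-checkpoint")  # ok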
@@ -1010,8 +1010,23 @@ class LLMEngine:
                 self.speculative_config
             # Organize outputs by [step][sequence group] instead of
             # [sequence group][step].
-            outputs_by_sequence_group = create_output_by_sequence_group(
-                outputs, num_seq_groups=len(seq_group_metadata_list))
+            if self.scheduler_config.is_multi_step:
+                outputs_by_sequence_group = create_output_by_sequence_group(
+                    outputs, len(seq_group_metadata_list))
+            elif self.speculative_config:
+                # Decodes are multi-steps while prefills are not, outputting at
+                # most 1 token. Separate them so that we can trigger chunk
+                # processing without having to pad or copy over prompts K times
+                # to match decodes structure (costly with prompt_logprobs).
+                num_prefills = sum(sg.is_prompt
+                                   for sg in seq_group_metadata_list)
+                prefills, decodes = outputs[:num_prefills], outputs[
+                    num_prefills:]
+                outputs_by_sequence_group = create_output_by_sequence_group(
+                    decodes,
+                    num_seq_groups=len(seq_group_metadata_list) - num_prefills)
+                outputs_by_sequence_group = [p.outputs for p in prefills
+                                             ] + outputs_by_sequence_group
             # We have outputs for multiple steps submitted in a single burst,
             # so invalidate is_first_step_output.
             is_first_step_output = None
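The speculative branch added here leans on the scheduler placing prefill sequence groups before decode groups, so the flat sampler output can be split with a simple prefill count and only the decode tail regrouped per step. A small sketch of that split, with plain Python objects standing in for the engine's metadata and output types:

# Two prefill groups (at most one sampled output each) followed by decode
# groups whose outputs are multi-step; only the decode tail needs regrouping.
seq_group_is_prompt = [True, True, False, False]
outputs = ["prefill0_out", "prefill1_out", "decode_step0", "decode_step1"]

num_prefills = sum(seq_group_is_prompt)
prefills, decodes = outputs[:num_prefills], outputs[num_prefills:]
assert prefills == ["prefill0_out", "prefill1_out"]
assert decodes == ["decode_step0", "decode_step1"]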
@@ -83,13 +83,13 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):

         if not non_spec_indices:
             # All sequence groups in batch have spec decoding enabled
-            contracted = self._contract_batch_all_spec(
+            return self._contract_batch_all_spec(
                 target_sampler_output=target_sampler_output,
                 proposals=proposals,
             )
         else:
             # Batch has a mix of spec decode enabled and disabled seq groups
-            contracted = self._contract_batch(
+            return self._contract_batch(
                 execute_model_req.seq_group_metadata_list,
                 target_sampler_output=target_sampler_output,
                 proposals=proposals,
@@ -99,14 +99,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
                 k=execute_model_req.num_lookahead_slots,
             )

-        all_tokens, all_probs, spec_logprobs, all_hidden_states = contracted
-        return SpeculativeScores(
-            probs=all_probs,
-            token_ids=all_tokens,
-            logprobs=spec_logprobs,
-            hidden_states=all_hidden_states,
-        )
-
     def _expand_batch(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
@@ -143,13 +135,57 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
                 num_scoring_tokens)

+    def _contract_non_speculative(
+            self, scores: SpeculativeScores,
+            seq_group_metadata_list: List[SequenceGroupMetadata],
+            non_spec_indices: List[int], non_spec_outputs: SpeculativeScores,
+            has_prompt_log: bool) -> SpeculativeScores:
+        """
+        Augment input `scores` with non-speculative requests outputs.
+        This includes decode requests with speculation turned off, as well
+        as prefill requests when `enable_chunked_prefill` is set.
+        For the latter, prefills are further separated into terminal and
+        non-terminal chunks (from which no token is sampled).
+        """
+        if not non_spec_indices:
+            return scores
+
+        if has_prompt_log:
+            # When prompt_logprobs is enabled, prefills yield output token
+            # (and respective prob) in the last entry (prompt|out):
+            # [.|.|.|prefill0_out|.|prefill1_out|decode0_out|..].
+            # With chunked prefill, non-terminal chunks have -1 on each
+            # position: they're still picked, but they're discarded later.
+            seq_meta = seq_group_metadata_list
+            nospec_sizes = torch.tensor([
+                seq_meta[i].token_chunk_size if seq_meta[i].is_prompt else 1
+                for i in non_spec_indices
+            ])
+            nospec_sampled_token_idxs = torch.cumsum(nospec_sizes, 0).add_(-1)
+        else:
+            # In this case only sampled tokens are returned, select all.
+            nospec_sampled_token_idxs = list(
+                range(len(non_spec_outputs.token_ids)))
+
+        scores.token_ids[non_spec_indices, :1] = \
+            non_spec_outputs.token_ids[nospec_sampled_token_idxs].unsqueeze(1)
+        scores.probs[non_spec_indices, :1, :] = \
+            non_spec_outputs.probs[nospec_sampled_token_idxs].unsqueeze(1)
+        scores.logprobs[non_spec_indices, :1, :] = \
+            non_spec_outputs.logprobs[nospec_sampled_token_idxs].unsqueeze(1)
+        if scores.hidden_states is not None:
+            assert non_spec_outputs.hidden_states is not None
+            scores.hidden_states[non_spec_indices, :1, :] = \
+                non_spec_outputs.hidden_states[nospec_sampled_token_idxs].unsqueeze(1)
+        return scores
+
     def _contract_batch(
-        self, contracted_seq_group_metadata_list: List[SequenceGroupMetadata],
-        target_sampler_output: SamplerOutput, proposals: SpeculativeProposals,
-        num_scoring_tokens: int, non_spec_indices: List[int],
-        spec_indices: List[int], k: int
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
-               Optional[torch.Tensor]]:
+            self,
+            contracted_seq_group_metadata_list: List[SequenceGroupMetadata],
+            target_sampler_output: SamplerOutput,
+            proposals: SpeculativeProposals, num_scoring_tokens: int,
+            non_spec_indices: List[int], spec_indices: List[int],
+            k: int) -> SpeculativeScores:
         """Contract the expanded batch back into its original size.
         This maps the scores of speculative tokens back to their original
         sequences.
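The cumulative-sum index trick in `_contract_non_speculative` above is the subtle part: with prompt logprobs enabled, each prefill contributes `token_chunk_size` rows to the flat non-speculative output and only the last row of each chunk carries the sampled token, while a decode contributes a single row. A worked sketch with illustrative sizes:

import torch

# Two prefill chunks of 4 and 3 prompt tokens plus one decode request.
nospec_sizes = torch.tensor([4, 3, 1])
nospec_sampled_token_idxs = torch.cumsum(nospec_sizes, 0).add_(-1)

# Rows 3, 6 and 7 of the flattened output hold the sampled tokens.
assert nospec_sampled_token_idxs.tolist() == [3, 6, 7]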
@@ -195,23 +231,28 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         else:
             all_hidden_states = None

-        # Rule out prefills that produce no tokens.
-        non_spec_indices = [
-            idx for idx in non_spec_indices
-            if contracted_seq_group_metadata_list[idx].do_sample
-        ]
-        if len(non_spec_indices):
-            all_tokens[non_spec_indices, :1] = \
-                non_spec_target_token_ids.unsqueeze(1)
-            all_probs[non_spec_indices, :1, :] = \
-                non_spec_target_probs.unsqueeze(1)
-            all_logprobs[non_spec_indices, :1, :] = \
-                non_spec_target_logprobs.unsqueeze(1)
-            if all_hidden_states is not None:
-                assert non_spec_target_hidden_states is not None
-                all_hidden_states[non_spec_indices, :1, :] = \
-                    non_spec_target_hidden_states.unsqueeze(1)
+        has_prompt_log = any((sg.sampling_params.prompt_logprobs
+                              and sg.sampling_params.prompt_logprobs > 0)
+                             for sg in contracted_seq_group_metadata_list)
+        # When prompt logprobs is enabled, lens of returned tensors go from
+        # n_sampled (requests with do_sample=True) to n_prompt+n_prefills.
+        # We adjust stride accordingly to get the generated tokens and
+        # their probs, but pass on prompt_logprobs as is.
+        prompt_logprobs = None
+        if (not self._scorer_worker.model_runner.disable_logprobs\
+                and has_prompt_log):
+            prompt_logprobs = [
+                o.prompt_logprobs for o in target_sampler_output.outputs
+            ]
+        elif not has_prompt_log:
+            # When prompt logprobs are not to be returned,
+            # we can ignore non-terminal chunks (no out token).
+            non_spec_indices = [
+                idx for idx in non_spec_indices
+                if contracted_seq_group_metadata_list[idx].do_sample
+            ]

+        # "Contract" speculative.
         if spec_indices:
             all_tokens[spec_indices] = target_token_ids
             all_probs[spec_indices] = target_probs
@@ -219,14 +260,27 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
             if all_hidden_states is not None:
                 all_hidden_states[spec_indices] = target_hidden_states

-        return all_tokens, all_probs, all_logprobs, all_hidden_states
+        spec_scores = SpeculativeScores(probs=all_probs,
+                                        token_ids=all_tokens,
+                                        logprobs=all_logprobs,
+                                        hidden_states=all_hidden_states,
+                                        prompt_logprobs=prompt_logprobs)
+
+        non_spec_outputs = SpeculativeScores(
+            probs=non_spec_target_probs,
+            token_ids=non_spec_target_token_ids,
+            logprobs=non_spec_target_logprobs,
+            hidden_states=non_spec_target_hidden_states)
+        # Contract remaining nonspec entries based on non_spec_indices, if any.
+        return self._contract_non_speculative(
+            spec_scores, contracted_seq_group_metadata_list, non_spec_indices,
+            non_spec_outputs, has_prompt_log)

     def _contract_batch_all_spec(
         self,
         target_sampler_output: SamplerOutput,
         proposals: SpeculativeProposals,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
-               Optional[torch.Tensor]]:
+    ) -> SpeculativeScores:
         """Contract the expanded batch back into its original size.
         This maps the scores of speculative tokens back to their original
         sequences.
@@ -250,8 +304,11 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
             target_hidden_states = target_hidden_states.reshape(
                 *target_token_ids.shape, target_hidden_states.shape[-1])

-        return (target_token_ids, target_probs, target_logprobs,
-                target_hidden_states)
+        return SpeculativeScores(probs=target_probs,
+                                 token_ids=target_token_ids,
+                                 logprobs=target_logprobs,
+                                 hidden_states=target_hidden_states,
+                                 prompt_logprobs=None)

     def _create_scoring_model_input(
         self,
@@ -1,10 +1,10 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Optional, Set, Union
+from typing import List, Optional, Set, Union

 import torch

-from vllm.sequence import ExecuteModelRequest
+from vllm.sequence import ExecuteModelRequest, PromptLogprobs
 from vllm.worker.worker_base import WorkerBase


@@ -54,6 +54,10 @@ class SpeculativeScores:
     # Optional last hidden states from the scoring model.
     hidden_states: Optional[torch.Tensor] = None

+    # Scoring model may also return logprobs for prompt tokens
+    # for each request, when chunked prefill is enabled.
+    prompt_logprobs: Optional[List[PromptLogprobs]] = None
+
     def __repr__(self):
         return (f"SpeculativeScores("
                 f"probs={self.probs.shape}, "
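Because the new field defaults to None, scorers that never touch prompt logprobs keep constructing SpeculativeScores exactly as before; only the chunked-prefill paths fill it in. A minimal construction sketch; the tensor shapes are illustrative and the import path assumes the class remains in vllm's spec_decode interfaces module:

import torch

from vllm.spec_decode.interfaces import SpeculativeScores

# Batch of 2 sequences scored over k+1 = 4 positions with a vocab of 8.
scores = SpeculativeScores(probs=torch.zeros(2, 4, 8),
                           token_ids=torch.zeros(2, 4, dtype=torch.long),
                           logprobs=torch.zeros(2, 4, 8),
                           prompt_logprobs=None)  # new, optional field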
@ -72,9 +72,15 @@ class MQAScorer(SpeculativeScorer):
|
|||||||
target_token_ids = target_sampler_output.sampled_token_ids
|
target_token_ids = target_sampler_output.sampled_token_ids
|
||||||
target_probs = target_sampler_output.sampled_token_probs
|
target_probs = target_sampler_output.sampled_token_probs
|
||||||
target_logprobs = target_sampler_output.logprobs
|
target_logprobs = target_sampler_output.logprobs
|
||||||
|
prompt_logprobs = None
|
||||||
|
|
||||||
# If all requests have the same number of query tokens, we can avoid
|
# If all requests have the same number of query tokens, we can avoid
|
||||||
# the for loop to build output for better performance.
|
# the for loop to build output for better performance.
|
||||||
if min(all_proposal_lengths) == k:
|
if min(all_proposal_lengths) == k:
|
||||||
|
# Regular decodes only.
|
||||||
|
assert all(not sg.is_prompt
|
||||||
|
for sg in target_seq_group_metadata_list
|
||||||
|
if sg.is_prompt)
|
||||||
bs, _ = proposals.proposal_token_ids.shape
|
bs, _ = proposals.proposal_token_ids.shape
|
||||||
all_tokens = target_token_ids.reshape(bs, k + 1)
|
all_tokens = target_token_ids.reshape(bs, k + 1)
|
||||||
all_probs = target_probs.reshape(bs, k + 1, self._vocab_size)
|
all_probs = target_probs.reshape(bs, k + 1, self._vocab_size)
|
||||||
@@ -88,11 +94,47 @@ class MQAScorer(SpeculativeScorer):
             all_logprobs = target_logprobs.new_full(size=all_probs.shape,
                                                     fill_value=-float("inf"))
             target_token_ids = target_token_ids.flatten()
-            start_loc = 0
-            for i, (proposed_len, seq_meta) in enumerate(
-                    zip(all_proposal_lengths, target_seq_group_metadata_list)):
+
+            # When prompt logprobs is enabled, lens of returned tensors go from
+            # n_sampled (requests with do_sample=True) to n_prompt+n_prefills.
+            # We adjust stride accordingly to get the generated tokens and
+            # their probs, but pass on prompt_logprobs as is, since it may be
+            # that n_prompts >> K.
+            has_prompt_log = any((sg.sampling_params.prompt_logprobs
+                                  and sg.sampling_params.prompt_logprobs > 0)
+                                 for sg in target_seq_group_metadata_list)
+            # TODO (NickLucche) we should surface `disable_logprobs` as to not
+            # break abstraction to get its value.
+            if (not self._scorer_worker.model_runner.disable_logprobs\
+                    and has_prompt_log):
+                prompt_logprobs = [
+                    o.prompt_logprobs for o in target_sampler_output.outputs
+                ]
+
+            # Split loop into prefill|decode for readability.
+            start_loc, i = 0, 0
+            while i < len(target_seq_group_metadata_list
+                          ) and target_seq_group_metadata_list[i].is_prompt:
+                seq_meta = target_seq_group_metadata_list[i]
+                end_loc = start_loc
+                if has_prompt_log:
+                    end_loc += seq_meta.token_chunk_size
+                elif seq_meta.do_sample:
+                    end_loc += 1
+
                 # Skip chunks with no output tokens.
                 if seq_meta.do_sample:
+                    # Get sampled token (last position in chunk) and its prob.
+                    all_tokens[i, 0] = target_token_ids[end_loc - 1]
+                    all_probs[i, 0] = target_probs[end_loc - 1]
+                    all_logprobs[i, 0] = target_logprobs[end_loc - 1]
+
+                i += 1
+                start_loc = end_loc
+            # Decodes.
+            while i < len(target_seq_group_metadata_list):
+                proposed_len, seq_meta = all_proposal_lengths[
+                    i], target_seq_group_metadata_list[i]
                 output_len = proposed_len + 1
                 end_loc = start_loc + output_len
                 all_tokens[
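Note on the hunk above: the new prefill/decode split changes how the scorer strides through the flattened target outputs. A small standalone sketch of just that indexing, with toy values and simplified types that are not part of this change:

# Illustrative only: how many rows of the flattened target output each request
# consumes. With prompt logprobs on, a prefill chunk consumes one row per
# prompt token; otherwise only a terminal (do_sample) chunk consumes one row.
def locate_outputs(chunks, has_prompt_log):
    """chunks: list of (is_prompt, do_sample, token_chunk_size, proposed_len)."""
    out, start_loc = [], 0
    for is_prompt, do_sample, chunk_size, proposed_len in chunks:
        if is_prompt:
            end_loc = start_loc + (chunk_size if has_prompt_log else int(do_sample))
            if do_sample:
                # Sampled token of a terminal chunk sits at the last position.
                out.append(("prefill_sample", end_loc - 1))
        else:
            # Decodes always score proposed_len + 1 positions.
            end_loc = start_loc + proposed_len + 1
            out.append(("decode", (start_loc, end_loc)))
        start_loc = end_loc
    return out

# Non-terminal chunk (4 prompt tokens), terminal chunk (4 tokens), one decode (k=3).
print(locate_outputs([(True, False, 4, 0), (True, True, 4, 0), (False, True, 1, 3)],
                     has_prompt_log=True))
# -> [('prefill_sample', 7), ('decode', (8, 12))]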
@@ -101,6 +143,7 @@ class MQAScorer(SpeculativeScorer):
                 all_logprobs[
                     i, :output_len] = target_logprobs[start_loc:end_loc]
                 start_loc = end_loc
+                i += 1
 
         hidden_states = None
         if target_sampler_output.hidden_states is not None:
@@ -110,4 +153,5 @@ class MQAScorer(SpeculativeScorer):
         return SpeculativeScores(probs=all_probs,
                                  token_ids=all_tokens,
                                  logprobs=all_logprobs,
-                                 hidden_states=hidden_states)
+                                 hidden_states=hidden_states,
+                                 prompt_logprobs=prompt_logprobs)
@@ -563,27 +563,26 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             (seq_id, seq_data) for sg in \
             execute_model_req.seq_group_metadata_list \
             for seq_id, seq_data in sg.seq_data.items()
-            if sg.do_sample  # ignore empty token sequences
         ]
         completion_seq_group_output_list: List[
             CompletionSequenceGroupOutput] = []
         output_index = 0
         # Make sure the non-terminal prefill chunks are still aligned with
         # their own empty output.
-        for seq_group_meta in execute_model_req.seq_group_metadata_list:
-            # Since we can get chunks here, we dont always have a sampled token
-            # (only on last chunk) but we still have to provide an output.
-            if not seq_group_meta.do_sample:
-                completion_seq_group_output_list.append(
-                    CompletionSequenceGroupOutput(samples=[],
-                                                  prompt_logprobs=None))
-            else:
-                # Sequence with output.
-                seq_id, seq_data = seq_data_entries[output_index]
-                needs_prompt_logprobs = seq_output_prompt_logprobs[
-                    output_index]
+        for idx, seq_group_meta in enumerate(
+                execute_model_req.seq_group_metadata_list):
+            needs_prompt_logprobs = seq_output_prompt_logprobs[idx]
+            seq_id, seq_data = seq_data_entries[idx]
             if needs_prompt_logprobs:
                 prompt_token_ids = seq_data.get_prompt_token_ids()
+
+                # Some of these sequences may belong to non-terminal chunks,
+                # which may still have to report logprobs for prompts.
+                start = 1 if seq_data._num_computed_tokens == 0 \
+                    else seq_data._num_computed_tokens
+                end = (seq_data._num_computed_tokens + \
+                       seq_group_meta.token_chunk_size)
+                prompt_token_ids = prompt_token_ids[start:end]
                 prompt_logprobs = [
                     create_logprobs_output(
                         token_id=p_token_id,
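The start/end arithmetic above selects only the prompt tokens covered by the current chunk, skipping position 0 because the first prompt token never receives a logprob. A toy sketch of just that slice (values and helper name are made up, not vLLM code):

def chunk_prompt_slice(num_computed_tokens, token_chunk_size):
    # Mirror of the arithmetic in the hunk above.
    start = 1 if num_computed_tokens == 0 else num_computed_tokens
    end = num_computed_tokens + token_chunk_size
    return start, end

prompt = list(range(100, 110))                         # 10 hypothetical prompt token ids
for computed, chunk_size in ((0, 4), (4, 4), (8, 2)):  # three prefill chunks
    s, e = chunk_prompt_slice(computed, chunk_size)
    print(prompt[s:e])
# -> [101, 102, 103] / [104, 105, 106, 107] / [108, 109]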
@@ -591,12 +590,20 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
                         token_id_logprob=0.0,
                         topk_token_ids=[],
                         topk_logprobs=[],
-                    )
-                    # no prompt logprobs for the first token
-                    for p_token_id in prompt_token_ids[1:]
+                    ) for p_token_id in prompt_token_ids
                 ]
             else:
                 prompt_logprobs = None
+
+            # Since we can get chunks here, we dont always have a sampled token
+            # (only on last chunk) but we still have to provide an output.
+            if not seq_group_meta.do_sample:
+                completion_seq_group_output_list.append(
+                    CompletionSequenceGroupOutput(
+                        samples=[], prompt_logprobs=prompt_logprobs))
+                continue
+
+            # Sequence with output.
             completion_seq_group_output_list.append(
                 create_sequence_group_output(
                     token_id=sampled_token_ids_list[output_index][0],
@@ -624,24 +631,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         assert len(sampler_output) == 1
         sampler_output = sampler_output[0]
 
-        # Store hidden states from target model execution.
+        # Store hidden states from target model execution, BxD.
         hidden_states = sampler_output.hidden_states
         if hidden_states is not None:
-            # remove hidden_states for prompt tokens
-            # TODO Enable `return_hidden_states`: prefill chunks hidden states
-            # are pruned by the logits processor. Also, they should be arranged
-            # back into full-prefill latent. Address it to enable MLPSpeculator.
-            if any(seq.is_prompt
-                   for seq in execute_model_req.seq_group_metadata_list):
+            # Only decodes and prefill terminal chunks need a hidden state.
+            seq_group_meta_with_hidden = [
+                sg for sg in execute_model_req.seq_group_metadata_list
+                if sg.do_sample
+            ]
+            if any(seq.is_prompt for seq in seq_group_meta_with_hidden):
+                # Drop hidden_states with no prediction (eg non-terminal chunks)
                 hidden_states = hidden_states[
                     torch.where(sampler_output.sampled_token_ids -
                                 VLLM_INVALID_TOKEN_ID)[0]]
-            if self.previous_hidden_states is None:
+            if self.previous_hidden_states is None and len(
+                    seq_group_meta_with_hidden):
                 self.previous_hidden_states = HiddenStates(
-                    hidden_states, execute_model_req.seq_group_metadata_list)
-            else:
-                self.previous_hidden_states.update(
-                    hidden_states, execute_model_req.seq_group_metadata_list)
+                    hidden_states, seq_group_meta_with_hidden)
+            elif self.previous_hidden_states and len(
+                    seq_group_meta_with_hidden):
+                self.previous_hidden_states.update(hidden_states,
+                                                   seq_group_meta_with_hidden)
 
         if not skip_proposer:
             # We prepare the prefill hidden states here so that there no
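The torch.where filter above keeps only rows whose sampled token id differs from the invalid-token sentinel, i.e. rows that actually produced a prediction. A standalone sketch with toy tensors (the -1 sentinel is an assumption here, standing in for VLLM_INVALID_TOKEN_ID):

import torch

INVALID = -1  # stand-in for VLLM_INVALID_TOKEN_ID
sampled_token_ids = torch.tensor([[INVALID], [1576], [INVALID], [29884]])
hidden_states = torch.arange(8.0).reshape(4, 2)   # B x D, toy values

# A nonzero difference from the sentinel marks rows with a real prediction.
keep_rows = torch.where(sampled_token_ids - INVALID)[0]
print(hidden_states[keep_rows])                   # rows 1 and 3 survive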
@@ -752,13 +762,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             ]
             if len(non_spec_indices):
                 all_hidden_states = proposal_scores.hidden_states
-                # TODO fix `return_hidden_states`, same as in `_run_no_spec`
                 if all_hidden_states is not None:
                     prefill_hidden_states = all_hidden_states[non_spec_indices]
                     execute_model_req.previous_hidden_states = \
                         prepare_prefill_hidden_states(prefill_hidden_states)
                 # Sync proposer KV cache for prefills.
                 prefill_req = execute_model_req.clone(non_spec_seqs)
+                # TODO avoid sampling here?
                 self.proposer_worker.execute_model(prefill_req)
 
         with Timer() as verification_timer:
@@ -774,6 +784,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             execute_model_req.seq_group_metadata_list,
             accepted_token_ids,
             target_logprobs=target_logprobs,
+            prompt_logprobs=proposal_scores.prompt_logprobs
+            if not self._disable_logprobs else None,
             k=execute_model_req.num_lookahead_slots,
             stage_times=stage_times)
 
@@ -845,19 +857,32 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         # metadata.
         accepted_token_ids[original_indices] = accepted_token_ids.clone()
 
+        # B x K+1 x D
         hidden_states = proposal_scores.hidden_states
         if hidden_states is not None:
+            # Only get terminal hidden states for next step
+            terminal_metadata = [
+                sg for sg in seq_group_metadata_list if sg.do_sample
+            ]
+
             # Contract hidden states based on accepted tokens
             hs_size = hidden_states.shape[-1]
 
             accepted_index = accepted_token_ids + 1  # Convert -1 to 0
-            accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)
-            index = accepted_index[:, None, None].expand(-1, 1, hs_size)
+            accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)  # b
+            # Drop non-terminal prefill chunks hidden states.
+            hidden_states = hidden_states[
+                accepted_index != VLLM_INVALID_TOKEN_ID]
+            accepted_index = accepted_index[
+                accepted_index != VLLM_INVALID_TOKEN_ID]
+            assert len(accepted_index) == hidden_states.shape[0] == len(
+                terminal_metadata)
+            index = accepted_index[:, None, None].expand(-1, 1,
+                                                         hs_size)  # b x 1 x d
             second_last_token_hidden_states = hidden_states[:, -2]  # b x d
             hidden_states = hidden_states.gather(1, index).squeeze(1)  # b x d
             # Store hidden states from target model for subsequent decode step
             self.previous_hidden_states = HiddenStates(
-                hidden_states, seq_group_metadata_list,
+                hidden_states, terminal_metadata,
                 second_last_token_hidden_states)
         return accepted_token_ids, logprobs
 
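The contraction above picks, for every batch row, the hidden state of its last accepted token, after dropping rows that accepted nothing (non-terminal prefill chunks). A toy walk-through of the same tensor ops (the -1 sentinel is an assumption, mirroring the comparison in the hunk):

import torch

accepted_token_ids = torch.tensor([[11, 12, -1],   # 2 tokens accepted
                                   [-1, -1, -1],   # non-terminal prefill chunk
                                   [21, -1, -1]])  # 1 token accepted
hidden = torch.arange(18.0).reshape(3, 3, 2)       # B x (K+1) x D, toy values

accepted_index = (accepted_token_ids + 1).count_nonzero(dim=1).add_(-1)  # b
hidden = hidden[accepted_index != -1]              # drop rows with no prediction
accepted_index = accepted_index[accepted_index != -1]
index = accepted_index[:, None, None].expand(-1, 1, hidden.shape[-1])
print(hidden.gather(1, index).squeeze(1))          # b x D, last accepted position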
@@ -866,6 +891,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         seq_group_metadata_list: List[SequenceGroupMetadata],
         accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
         target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
+        prompt_logprobs: Optional[
+            torch.Tensor],  # shape: [nprompt_tokens, vocab_size]
         k: int,
         stage_times: Tuple[float, float, float],
     ) -> List[SamplerOutput]:
@@ -909,15 +936,89 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
 
         # Construct the output on a per-step, per-sequence basis.
         # Non-terminal prefill chunks will end up here as rows with just -1s
-        # i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]]
+        # i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]] while
+        # terminal chunks will only have one generated token at time 0.
         sampler_output_list: List[SamplerOutput] = []
+
+        # Prefills are not multi-step (return at most 1 token), in order to
+        # avoid padding or repetition to fit decodes, we separate them.
+        for i, sg in enumerate(seq_group_metadata_list):
+            if not sg.is_prompt:
+                # Requests are ordered as prefills|decodes=>no more prefills.
+                break
+            num_logprobs = num_logprobs_per_seq[i]
+            seq_kwargs = dict(token_id=-1,
+                              token_id_logprob_rank=0,
+                              token_id_logprob=-float('inf'),
+                              topk_token_ids=[-1] * num_logprobs,
+                              topk_logprobs=[-float('inf')] * num_logprobs,
+                              seq_id=seq_ids[i])
+            # Terminal chunk, has token.
+            if sg.do_sample:
+                seq_kwargs.update(
+                    dict(
+                        token_id=accepted_token_ids[i][0].item(),
+                        token_id_logprob_rank=accepted_token_id_ranks_by_step[
+                            0][i],
+                        token_id_logprob=accepted_token_id_logprobs_by_step[0]
+                        [i],
+                        topk_token_ids=topk_indices_by_step[0][i]
+                        [:num_logprobs],
+                        # output only so step is 0
+                        topk_logprobs=topk_logprobs_by_step[0][i]
+                        [:num_logprobs],
+                    ))
+            needs_plogs = (sg.sampling_params.prompt_logprobs
+                           and sg.sampling_params.prompt_logprobs > 0)
+            plogs = None
+            if prompt_logprobs is not None:
+                # Even non-terminal prompt chunks can have logprobs here.
+                plogs = prompt_logprobs[i]
+            elif needs_plogs:
+                # Prompt logprobs are requested but `_disable_logprobs` is set.
+                seq_data = next(iter(sg.seq_data.values()))
+                # Get only the tokens in this chunk!
+                prompt_token_ids = seq_data.get_prompt_token_ids()
+                prompt_token_ids = prompt_token_ids[
+                    seq_data.
+                    _num_computed_tokens:seq_data._num_computed_tokens +
+                    sg.token_chunk_size]
+
+                is_first_chunk = seq_data._num_computed_tokens == 0
+                # There's no prob generated for the first token in a sequence.
+                if is_first_chunk:
+                    prompt_token_ids = prompt_token_ids[1:]
+                plogs = [
+                    create_logprobs_output(
+                        token_id=p_token_id,
+                        token_id_logprob_rank=-1,
+                        token_id_logprob=0.0,
+                        topk_token_ids=[],
+                        topk_logprobs=[],
+                    ) for p_token_id in prompt_token_ids
+                ]
+            seq_kwargs.update(dict(prompt_logprobs=plogs))
+
+            sampler_output_list.append(
+                SamplerOutput(
+                    outputs=[create_sequence_group_output(
+                        **seq_kwargs)]))  # type: ignore
+
+        # Decodes, create one SamplerOutput per-step (at most K+1).
         for step_index in range(num_steps):
-            if all(token_id == -1
-                   for token_id in accepted_token_ids_by_step[step_index]):
+            if all(token_id == -1 for sg, token_id in zip(
+                    seq_group_metadata_list,
+                    accepted_token_ids_by_step[step_index])
+                   if not sg.is_prompt):
                 break
 
             step_output_token_ids: List[CompletionSequenceGroupOutput] = []
             for sequence_index in range(batch_size):
+                seq_meta = seq_group_metadata_list[sequence_index]
+                # Prompts already processed above.
+                if seq_meta.is_prompt:
+                    continue
+
                 # Each sequence may have a different num_logprobs; retrieve it.
                 num_logprobs = num_logprobs_per_seq[sequence_index]
                 step_output_token_ids.append(
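With the loop above, the returned list is ordered prefills first (one entry per prefill chunk, even when nothing was sampled) followed by one entry per decode step. A sketch of the placeholder versus terminal prefill entries, using plain dicts rather than vLLM's output classes:

# Plain-dict sketch, not vLLM objects: a non-terminal chunk gets a placeholder
# entry, a terminal (do_sample) chunk carries its single sampled token.
def prefill_entry(do_sample, token_id=None, num_logprobs=1):
    entry = dict(token_id=-1,
                 token_id_logprob_rank=0,
                 token_id_logprob=-float("inf"),
                 topk_token_ids=[-1] * num_logprobs,
                 topk_logprobs=[-float("inf")] * num_logprobs)
    if do_sample:
        entry.update(token_id=token_id, token_id_logprob_rank=1,
                     token_id_logprob=0.0)
    return entry

print([prefill_entry(False), prefill_entry(True, token_id=1576)])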
@@ -952,6 +1053,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             # This is periodic because the rejection sampler emits metrics
             # periodically.
             self._maybe_log_stage_times(*stage_times)
+        # First `n_prefills` entries will contain prefills SamplerOutput when
+        # chunked prefill is enabled, the rest is decodes in multi-step format.
         return sampler_output_list
 
     def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,