mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-02 04:57:54 +08:00
[Speculative decoding 7/9] Speculative decoding end-to-end correctness tests. (#3951)
This commit is contained in:
parent
050f285ff6
commit
62b8aebc6f
@ -91,12 +91,16 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
|
|||||||
bonus_token_ids,
|
bonus_token_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Bonus tokens are currently disabled. Verify they're set to -1.
|
||||||
|
# See https://github.com/vllm-project/vllm/issues/4212
|
||||||
|
expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1
|
||||||
|
|
||||||
if which_tokens_accepted == "all_tokens_accepted":
|
if which_tokens_accepted == "all_tokens_accepted":
|
||||||
# Expect all tokens to be equal to draft tokens.
|
# Expect all tokens to be equal to draft tokens.
|
||||||
assert torch.equal(output_token_ids[:, :-1], draft_token_ids)
|
assert torch.equal(output_token_ids[:, :-1], draft_token_ids)
|
||||||
|
|
||||||
# Expect all bonus tokens to be included.
|
# Expect all bonus tokens to be included.
|
||||||
assert torch.equal(output_token_ids[:, -1:], bonus_token_ids)
|
assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids)
|
||||||
elif which_tokens_accepted == "no_tokens_accepted":
|
elif which_tokens_accepted == "no_tokens_accepted":
|
||||||
# Expect first token to be equal to recovered tokens.
|
# Expect first token to be equal to recovered tokens.
|
||||||
assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])
|
assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])
|
||||||
@ -106,7 +110,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
|
|||||||
torch.ones_like(output_token_ids[:, 1:]) * -1)
|
torch.ones_like(output_token_ids[:, 1:]) * -1)
|
||||||
elif which_tokens_accepted == "some_tokens_accepted":
|
elif which_tokens_accepted == "some_tokens_accepted":
|
||||||
recovered_plus_bonus = torch.cat(
|
recovered_plus_bonus = torch.cat(
|
||||||
(recovered_token_ids, bonus_token_ids), dim=-1)
|
(recovered_token_ids, expected_bonus_token_ids), dim=-1)
|
||||||
# Assert first rejected token is a recovered token or bonus token.
|
# Assert first rejected token is a recovered token or bonus token.
|
||||||
assert torch.equal(
|
assert torch.equal(
|
||||||
recovered_plus_bonus[torch.arange(0, batch_size),
|
recovered_plus_bonus[torch.arange(0, batch_size),
|
||||||
|
|||||||
@ -636,7 +636,8 @@ def test_sampler_top_k_top_p(seed: int, device: str):
|
|||||||
def mock_sample(probs, *args, **kwargs):
|
def mock_sample(probs, *args, **kwargs):
|
||||||
nonlocal sample_probs
|
nonlocal sample_probs
|
||||||
sample_probs = probs
|
sample_probs = probs
|
||||||
return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]
|
return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
|
||||||
|
for prob in probs], None)
|
||||||
|
|
||||||
with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
|
with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
|
||||||
sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
|
sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
|
||||||
|
|||||||
0
tests/spec_decode/e2e/__init__.py
Normal file
0
tests/spec_decode/e2e/__init__.py
Normal file
@ -1,3 +1,5 @@
|
|||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tests.conftest import cleanup
|
from tests.conftest import cleanup
|
||||||
@ -6,28 +8,34 @@ from vllm.model_executor.utils import set_random_seed
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
|
def baseline_llm_generator(request, common_llm_kwargs,
|
||||||
baseline_llm_kwargs, seed):
|
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||||
return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
|
seed):
|
||||||
|
return create_llm_generator("baseline", request, common_llm_kwargs,
|
||||||
|
per_test_common_llm_kwargs,
|
||||||
baseline_llm_kwargs, seed)
|
baseline_llm_kwargs, seed)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
|
def test_llm_generator(request, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||||
test_llm_kwargs, seed):
|
test_llm_kwargs, seed):
|
||||||
return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
|
return create_llm_generator("test", request, common_llm_kwargs,
|
||||||
test_llm_kwargs, seed)
|
per_test_common_llm_kwargs, test_llm_kwargs,
|
||||||
|
seed)
|
||||||
|
|
||||||
|
|
||||||
def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
|
def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
|
||||||
distinct_llm_kwargs, seed):
|
per_test_common_llm_kwargs, distinct_llm_kwargs,
|
||||||
|
seed):
|
||||||
kwargs = {
|
kwargs = {
|
||||||
**common_llm_kwargs,
|
**common_llm_kwargs,
|
||||||
**per_test_common_llm_kwargs,
|
**per_test_common_llm_kwargs,
|
||||||
**distinct_llm_kwargs,
|
**distinct_llm_kwargs,
|
||||||
}
|
}
|
||||||
|
test_name = request.node.name
|
||||||
|
|
||||||
def generator_inner():
|
def generator_inner():
|
||||||
|
print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
|
||||||
llm = LLM(**kwargs)
|
llm = LLM(**kwargs)
|
||||||
|
|
||||||
set_random_seed(seed)
|
set_random_seed(seed)
|
||||||
@ -36,6 +44,23 @@ def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
|
|||||||
del llm
|
del llm
|
||||||
cleanup()
|
cleanup()
|
||||||
|
|
||||||
for llm in generator_inner():
|
def generator_outer():
|
||||||
yield llm
|
for llm in generator_inner():
|
||||||
|
yield llm
|
||||||
|
del llm
|
||||||
|
|
||||||
|
return generator_outer
|
||||||
|
|
||||||
|
|
||||||
|
def get_output_from_llm_generator(
|
||||||
|
llm_generator, prompts,
|
||||||
|
sampling_params) -> Tuple[List[str], List[List[int]]]:
|
||||||
|
tokens = []
|
||||||
|
token_ids = []
|
||||||
|
for llm in llm_generator():
|
||||||
|
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
|
||||||
|
token_ids = [output.outputs[0].token_ids for output in outputs]
|
||||||
|
tokens = [output.outputs[0].text for output in outputs]
|
||||||
del llm
|
del llm
|
||||||
|
|
||||||
|
return tokens, token_ids
|
||||||
|
|||||||
169
tests/spec_decode/e2e/test_compatibility.py
Normal file
169
tests/spec_decode/e2e/test_compatibility.py
Normal file
@ -0,0 +1,169 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm import SamplingParams
|
||||||
|
|
||||||
|
from .conftest import get_output_from_llm_generator
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"common_llm_kwargs",
|
||||||
|
[{
|
||||||
|
"model": "JackFram/llama-68m",
|
||||||
|
"speculative_model": "JackFram/llama-68m",
|
||||||
|
"num_speculative_tokens": 5,
|
||||||
|
|
||||||
|
# Required for spec decode.
|
||||||
|
"use_v2_block_manager": True
|
||||||
|
}])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"per_test_common_llm_kwargs",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
# Expect failure as spec decode not supported by
|
||||||
|
# Ray backend.
|
||||||
|
"worker_use_ray": True,
|
||||||
|
},
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("test_llm_kwargs", [{}])
|
||||||
|
@pytest.mark.parametrize("seed", [1])
|
||||||
|
def test_spec_decode_xfail_ray(test_llm_generator):
|
||||||
|
"""Verify that speculative decoding with Ray fails.
|
||||||
|
"""
|
||||||
|
output_len = 128
|
||||||
|
temperature = 0.0
|
||||||
|
|
||||||
|
prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
]
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
max_tokens=output_len,
|
||||||
|
ignore_eos=True,
|
||||||
|
temperature=temperature,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError,
|
||||||
|
match="Speculative decoding not yet supported for "):
|
||||||
|
get_output_from_llm_generator(test_llm_generator, prompts,
|
||||||
|
sampling_params)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"common_llm_kwargs",
|
||||||
|
[{
|
||||||
|
"model": "JackFram/llama-68m",
|
||||||
|
"speculative_model": "JackFram/llama-68m",
|
||||||
|
"num_speculative_tokens": 5,
|
||||||
|
|
||||||
|
# Required for spec decode.
|
||||||
|
"use_v2_block_manager": True
|
||||||
|
}])
|
||||||
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
||||||
|
{
|
||||||
|
"enable_chunked_prefill": True,
|
||||||
|
},
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("test_llm_kwargs", [{}])
|
||||||
|
@pytest.mark.parametrize("seed", [1])
|
||||||
|
def test_spec_decode_xfail_chunked_prefill(test_llm_generator):
|
||||||
|
"""Verify that speculative decoding with chunked prefill fails.
|
||||||
|
"""
|
||||||
|
output_len = 128
|
||||||
|
temperature = 0.0
|
||||||
|
|
||||||
|
prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
]
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
max_tokens=output_len,
|
||||||
|
ignore_eos=True,
|
||||||
|
temperature=temperature,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError,
|
||||||
|
match="Speculative decoding and chunked prefill"):
|
||||||
|
get_output_from_llm_generator(test_llm_generator, prompts,
|
||||||
|
sampling_params)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"common_llm_kwargs",
|
||||||
|
[{
|
||||||
|
"model": "meta-llama/Llama-2-7b-chat-hf",
|
||||||
|
"speculative_model": "JackFram/llama-68m",
|
||||||
|
"num_speculative_tokens": 5,
|
||||||
|
|
||||||
|
# Required for spec decode.
|
||||||
|
"use_v2_block_manager": True
|
||||||
|
}])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"per_test_common_llm_kwargs",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
# Speculative max model len > overridden max model len should raise.
|
||||||
|
"max_model_len": 128,
|
||||||
|
"speculative_max_model_len": 129,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Speculative max model len > draft max model len should raise.
|
||||||
|
# https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12
|
||||||
|
"speculative_max_model_len": 2048 + 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Speculative max model len > target max model len should raise.
|
||||||
|
# https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/f5db02db724555f92da89c216ac04704f23d4590/config.json#L12
|
||||||
|
"speculative_max_model_len": 4096 + 1,
|
||||||
|
},
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("test_llm_kwargs", [{}])
|
||||||
|
@pytest.mark.parametrize("seed", [1])
|
||||||
|
def test_spec_decode_xfail_spec_max_model_len(test_llm_generator):
|
||||||
|
"""Verify that speculative decoding validates speculative_max_model_len.
|
||||||
|
"""
|
||||||
|
output_len = 128
|
||||||
|
temperature = 0.0
|
||||||
|
|
||||||
|
prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
]
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
max_tokens=output_len,
|
||||||
|
ignore_eos=True,
|
||||||
|
temperature=temperature,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="cannot be larger than"):
|
||||||
|
get_output_from_llm_generator(test_llm_generator, prompts,
|
||||||
|
sampling_params)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("common_llm_kwargs", [{
|
||||||
|
"model": "JackFram/llama-68m",
|
||||||
|
"speculative_model": "JackFram/llama-68m",
|
||||||
|
"num_speculative_tokens": 5,
|
||||||
|
}])
|
||||||
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
|
@pytest.mark.parametrize("test_llm_kwargs", [{}])
|
||||||
|
@pytest.mark.parametrize("seed", [1])
|
||||||
|
def test_spec_decode_xfail_block_manager_v1(test_llm_generator):
|
||||||
|
"""Verify that speculative decoding with block manager v1 fails.
|
||||||
|
"""
|
||||||
|
output_len = 128
|
||||||
|
temperature = 0.0
|
||||||
|
|
||||||
|
prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
]
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
max_tokens=output_len,
|
||||||
|
ignore_eos=True,
|
||||||
|
temperature=temperature,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError,
|
||||||
|
match="Speculative decoding requires usage of the V2"):
|
||||||
|
get_output_from_llm_generator(test_llm_generator, prompts,
|
||||||
|
sampling_params)
|
||||||
@ -1,11 +1,42 @@
|
|||||||
|
"""The tests in this file verify end-to-end speculative decoding correctness.
|
||||||
|
|
||||||
|
This docstring details important information on the testing methodology.
|
||||||
|
|
||||||
|
Most of the tests rely on "greedy equality", where we expect the output of
|
||||||
|
speculative decoding on a sequence to exactly match the output of normal non-
|
||||||
|
speculative decoding.
|
||||||
|
|
||||||
|
Since speculative decoding with rejection sampling guarantees that the output
|
||||||
|
distribution matches the target model's output distribution (up to hardware
|
||||||
|
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
|
||||||
|
equality. This gives us good coverage of temp=0.
|
||||||
|
|
||||||
|
For temp>0, we rely on unit tests on the rejection sampler to verify that the
|
||||||
|
output distribution is the same with spec decode vs. no spec decode (this would
|
||||||
|
be prohibitively expensive to run with a real model).
|
||||||
|
|
||||||
|
NOTE: Speculative decoding's distribution equality requires that the measured
|
||||||
|
distributions of the target model and proposal model be deterministic given the
|
||||||
|
same input. vLLM largely guarantees this.
|
||||||
|
|
||||||
|
@cadedaniel has seen cases where the output probabilities of a draft/target
|
||||||
|
model change slightly with certain batch sizes or prompts, even with Torch
|
||||||
|
determinism flags set. It is unclear if this is a bug in vLLM, due to non-
|
||||||
|
determinism in on-device batched operations, a bug in vLLM's spec decode
|
||||||
|
implementation, or the "hardware numerics" limitations. Either way, rejection
|
||||||
|
sampling ensures the output distribution matches the target model, but it breaks
|
||||||
|
greedy-equality tests for those batch sizes/prompts.
|
||||||
|
"""
|
||||||
|
|
||||||
from itertools import cycle
|
from itertools import cycle
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
|
|
||||||
|
from .conftest import get_output_from_llm_generator
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"common_llm_kwargs",
|
"common_llm_kwargs",
|
||||||
@ -14,9 +45,6 @@ from vllm import SamplingParams
|
|||||||
# Note this is repeated in the test body; to initialize a tokenizer.
|
# Note this is repeated in the test body; to initialize a tokenizer.
|
||||||
"model": "JackFram/llama-68m",
|
"model": "JackFram/llama-68m",
|
||||||
|
|
||||||
# Skip real loading for fast test.
|
|
||||||
"load_format": "dummy",
|
|
||||||
|
|
||||||
# Skip cuda graph recording for fast test.
|
# Skip cuda graph recording for fast test.
|
||||||
"enforce_eager": True,
|
"enforce_eager": True,
|
||||||
|
|
||||||
@ -31,22 +59,15 @@ from vllm import SamplingParams
|
|||||||
"num_speculative_tokens": 5,
|
"num_speculative_tokens": 5,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"speculative_model": "JackFram/llama-68m",
|
# Verify the detokenizer assertions in the test work when spec
|
||||||
"num_speculative_tokens": 1,
|
# decode is disabled.
|
||||||
},
|
|
||||||
{
|
|
||||||
# No spec decode.
|
|
||||||
},
|
},
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("test_llm_kwargs", [{}])
|
@pytest.mark.parametrize("test_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("batch_size", [1])
|
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||||
# NOTE: We should run more permutations of this test (more BS, more seeds). But
|
|
||||||
# because our spec decode generates gibberish token ids, the likelihood of
|
|
||||||
# emitting an invalid token combination is nontrivial. This causes divergence in
|
|
||||||
# behavior of vLLM detokenization vs. hf tokenizer, for example when two "utf-
|
|
||||||
# start" bytes are emitted.
|
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
|
def test_spec_decode_e2e_with_detokenization(test_llm_generator,
|
||||||
|
batch_size: int):
|
||||||
"""Run generation with speculative decoding on a batch. Verify the engine
|
"""Run generation with speculative decoding on a batch. Verify the engine
|
||||||
generates the correct number of tokens (via ignore_eos=True), and that the
|
generates the correct number of tokens (via ignore_eos=True), and that the
|
||||||
detokenization matches HF transformers.
|
detokenization matches HF transformers.
|
||||||
@ -67,8 +88,6 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
|
|||||||
max_tokens=output_len,
|
max_tokens=output_len,
|
||||||
ignore_eos=True,
|
ignore_eos=True,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
skip_special_tokens=True,
|
|
||||||
spaces_between_special_tokens=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
batch_tokens, batch_token_ids = get_output_from_llm_generator(
|
batch_tokens, batch_token_ids = get_output_from_llm_generator(
|
||||||
@ -77,9 +96,10 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
|
|||||||
# Expect a generation for each prompt in the batch.
|
# Expect a generation for each prompt in the batch.
|
||||||
assert len(batch_token_ids) == len(prompts)
|
assert len(batch_token_ids) == len(prompts)
|
||||||
|
|
||||||
# Expect each generation to have expected number of tokens (note
|
# Expect each generation to have expected number of tokens (note ignore_eos
|
||||||
# ignore_eos=True).
|
# is True).
|
||||||
assert all(len(token_ids) == output_len for token_ids in batch_token_ids)
|
assert [len(token_ids)
|
||||||
|
for token_ids in batch_token_ids] == ([output_len] * batch_size)
|
||||||
|
|
||||||
# Expect detokenized string to match.
|
# Expect detokenized string to match.
|
||||||
tok = AutoTokenizer.from_pretrained("JackFram/llama-68m")
|
tok = AutoTokenizer.from_pretrained("JackFram/llama-68m")
|
||||||
@ -92,13 +112,293 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"common_llm_kwargs",
|
"common_llm_kwargs",
|
||||||
[{
|
[{
|
||||||
# Use a small model for a fast test.
|
# Skip cuda graph recording for fast test.
|
||||||
"model": "JackFram/llama-68m",
|
"enforce_eager": True,
|
||||||
|
|
||||||
|
# Required for spec decode.
|
||||||
|
"use_v2_block_manager": True,
|
||||||
|
|
||||||
|
# Print spec metrics.
|
||||||
|
"disable_log_stats": False,
|
||||||
|
}])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"per_test_common_llm_kwargs",
|
||||||
|
[
|
||||||
|
# Try two different tiny base models.
|
||||||
|
# Note that one is equal to the draft model, another isn't.
|
||||||
|
{
|
||||||
|
"model": "JackFram/llama-68m",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model": "JackFram/llama-160m",
|
||||||
|
},
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
|
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||||
|
{
|
||||||
"speculative_model": "JackFram/llama-68m",
|
"speculative_model": "JackFram/llama-68m",
|
||||||
"num_speculative_tokens": 5,
|
"num_speculative_tokens": 5,
|
||||||
|
},
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"output_len",
|
||||||
|
[
|
||||||
|
# Use long output len for the small model test.
|
||||||
|
1536,
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("batch_size", [1])
|
||||||
|
@pytest.mark.parametrize("seed", [1])
|
||||||
|
def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
|
||||||
|
baseline_llm_generator, test_llm_generator, batch_size: int,
|
||||||
|
output_len: int):
|
||||||
|
"""Verify greedy equality on a tiny model with batch size of one.
|
||||||
|
|
||||||
# Skip real loading for fast test.
|
Since this test is cheaper than other e2e correctness tests, we generate
|
||||||
"load_format": "dummy",
|
with a higher output_len.
|
||||||
|
"""
|
||||||
|
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||||
|
test_llm_generator,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
force_output_len=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"common_llm_kwargs",
|
||||||
|
[{
|
||||||
|
# Skip cuda graph recording for fast test.
|
||||||
|
"enforce_eager": True,
|
||||||
|
|
||||||
|
# Required for spec decode.
|
||||||
|
"use_v2_block_manager": True,
|
||||||
|
|
||||||
|
# Print spec metrics.
|
||||||
|
"disable_log_stats": False,
|
||||||
|
}])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"per_test_common_llm_kwargs",
|
||||||
|
[
|
||||||
|
# Try two different tiny base models.
|
||||||
|
# Note that one is equal to the draft model, another isn't.
|
||||||
|
{
|
||||||
|
"model": "JackFram/llama-68m",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model": "JackFram/llama-160m",
|
||||||
|
},
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
|
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||||
|
{
|
||||||
|
"speculative_model": "JackFram/llama-68m",
|
||||||
|
"num_speculative_tokens": 5,
|
||||||
|
},
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"output_len",
|
||||||
|
[
|
||||||
|
# Use small output len for fast test.
|
||||||
|
256,
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("batch_size", [64])
|
||||||
|
@pytest.mark.parametrize("seed", [1])
|
||||||
|
def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
|
||||||
|
baseline_llm_generator, test_llm_generator, batch_size: int,
|
||||||
|
output_len: int):
|
||||||
|
"""Verify greedy equality on a tiny model and large batch size.
|
||||||
|
"""
|
||||||
|
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||||
|
test_llm_generator,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
force_output_len=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"common_llm_kwargs",
|
||||||
|
[{
|
||||||
|
# Skip cuda graph recording for fast test.
|
||||||
|
"enforce_eager": True,
|
||||||
|
|
||||||
|
# Required for spec decode.
|
||||||
|
"use_v2_block_manager": True
|
||||||
|
}])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"per_test_common_llm_kwargs",
|
||||||
|
[
|
||||||
|
# Try two different tiny base models.
|
||||||
|
# Note that one is equal to the draft model, another isn't.
|
||||||
|
{
|
||||||
|
"model": "JackFram/llama-68m",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model": "JackFram/llama-160m",
|
||||||
|
},
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
|
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||||
|
{
|
||||||
|
"speculative_model": "JackFram/llama-68m",
|
||||||
|
"num_speculative_tokens": 5,
|
||||||
|
},
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("max_output_len", [
|
||||||
|
256,
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("batch_size", [32])
|
||||||
|
@pytest.mark.parametrize("seed", [1])
|
||||||
|
def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
|
||||||
|
baseline_llm_generator, test_llm_generator, batch_size: int,
|
||||||
|
max_output_len: int):
|
||||||
|
"""Verify greedy equality on a tiny model, with a large batch size, and when
|
||||||
|
sampling respects the EOS token.
|
||||||
|
"""
|
||||||
|
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||||
|
test_llm_generator,
|
||||||
|
batch_size,
|
||||||
|
max_output_len,
|
||||||
|
force_output_len=False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"common_llm_kwargs",
|
||||||
|
[{
|
||||||
|
# A "real" model (not tiny).
|
||||||
|
"model": "meta-llama/Llama-2-7b-chat-hf",
|
||||||
|
|
||||||
|
# Skip cuda graph recording for fast test.
|
||||||
|
"enforce_eager": True,
|
||||||
|
|
||||||
|
# Required for spec decode.
|
||||||
|
"use_v2_block_manager": True,
|
||||||
|
|
||||||
|
# Print spec metrics.
|
||||||
|
"disable_log_stats": False,
|
||||||
|
}])
|
||||||
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
|
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||||
|
{
|
||||||
|
"speculative_model": "JackFram/llama-68m",
|
||||||
|
"num_speculative_tokens": 5,
|
||||||
|
},
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("batch_size", [1])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"output_len",
|
||||||
|
[
|
||||||
|
# Use decently long output len for a high quality test.
|
||||||
|
256,
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("seed", [1])
|
||||||
|
def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
|
||||||
|
baseline_llm_generator, test_llm_generator, batch_size: int,
|
||||||
|
output_len: int):
|
||||||
|
"""Verify greedy equality on a "real" model and batch size of 1. This is
|
||||||
|
separate from large BS tests to make identifying the source of bugs easier.
|
||||||
|
"""
|
||||||
|
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||||
|
test_llm_generator,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
force_output_len=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"common_llm_kwargs",
|
||||||
|
[{
|
||||||
|
# A "real" model (not tiny).
|
||||||
|
"model": "meta-llama/Llama-2-7b-chat-hf",
|
||||||
|
|
||||||
|
# Skip cuda graph recording for fast test.
|
||||||
|
"enforce_eager": True,
|
||||||
|
|
||||||
|
# Required for spec decode.
|
||||||
|
"use_v2_block_manager": True,
|
||||||
|
|
||||||
|
# Print spec metrics.
|
||||||
|
"disable_log_stats": False,
|
||||||
|
}])
|
||||||
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
|
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||||
|
{
|
||||||
|
"speculative_model": "JackFram/llama-68m",
|
||||||
|
"num_speculative_tokens": 5,
|
||||||
|
},
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("batch_size", [32])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"output_len",
|
||||||
|
[
|
||||||
|
# Use smaller output len for fast test.
|
||||||
|
64,
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("seed", [1])
|
||||||
|
def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
|
||||||
|
baseline_llm_generator, test_llm_generator, batch_size: int,
|
||||||
|
output_len: int):
|
||||||
|
"""Verify greedy equality with a "real" model on a nontrivial batch size.
|
||||||
|
This is the closest test to a real production workload.
|
||||||
|
"""
|
||||||
|
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||||
|
test_llm_generator,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
force_output_len=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"common_llm_kwargs",
|
||||||
|
[{
|
||||||
|
"block_size": 8,
|
||||||
|
# 2 for small prompt, 256//8 for generated.
|
||||||
|
"num_gpu_blocks_override": 2 + 256 // 8,
|
||||||
|
"max_model_len": (2 + 256 // 8) * 8,
|
||||||
|
|
||||||
|
# Skip cuda graph recording for fast test.
|
||||||
|
"enforce_eager": True,
|
||||||
|
|
||||||
|
# Required for spec decode.
|
||||||
|
"use_v2_block_manager": True
|
||||||
|
}])
|
||||||
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
||||||
|
{
|
||||||
|
"model": "JackFram/llama-160m",
|
||||||
|
},
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
|
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||||
|
{
|
||||||
|
"speculative_model": "JackFram/llama-68m",
|
||||||
|
"num_speculative_tokens": 5,
|
||||||
|
},
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"output_len",
|
||||||
|
[
|
||||||
|
# Use small output len for fast test.
|
||||||
|
256,
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("batch_size", [4])
|
||||||
|
@pytest.mark.parametrize("seed", [1])
|
||||||
|
def test_spec_decode_e2e_greedy_correctness_with_preemption(
|
||||||
|
baseline_llm_generator, test_llm_generator, batch_size: int,
|
||||||
|
output_len: int):
|
||||||
|
"""Verify greedy equality, even when some sequences are preempted mid-
|
||||||
|
generation.
|
||||||
|
"""
|
||||||
|
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||||
|
test_llm_generator,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
force_output_len=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"common_llm_kwargs",
|
||||||
|
[{
|
||||||
|
"model": "JackFram/llama-160m",
|
||||||
|
|
||||||
# Skip cuda graph recording for fast test.
|
# Skip cuda graph recording for fast test.
|
||||||
"enforce_eager": True,
|
"enforce_eager": True,
|
||||||
@ -109,43 +409,189 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"per_test_common_llm_kwargs",
|
"per_test_common_llm_kwargs",
|
||||||
[
|
[
|
||||||
|
# As of this writing, vLLM only compiles with these 3 block sizes by
|
||||||
|
# default.
|
||||||
{
|
{
|
||||||
# Expect failure as spec decode not supported by
|
"block_size": 8,
|
||||||
# Ray backend.
|
},
|
||||||
"worker_use_ray": True,
|
{
|
||||||
|
"block_size": 16,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"block_size": 32,
|
||||||
},
|
},
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("test_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
|
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||||
|
{
|
||||||
|
"speculative_model": "JackFram/llama-68m",
|
||||||
|
"num_speculative_tokens": 5,
|
||||||
|
},
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("batch_size", [2])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"output_len",
|
||||||
|
[
|
||||||
|
# Use smaller output len for fast test.
|
||||||
|
32,
|
||||||
|
])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_spec_decode_xfail(test_llm_generator):
|
def test_spec_decode_different_block_size(baseline_llm_generator,
|
||||||
"""Verify that speculative decoding with Ray fails.
|
test_llm_generator, batch_size: int,
|
||||||
|
output_len: int):
|
||||||
|
"""Verify greedy equality over different block sizes.
|
||||||
|
"""
|
||||||
|
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||||
|
test_llm_generator,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
force_output_len=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"common_llm_kwargs",
|
||||||
|
[{
|
||||||
|
"model": "JackFram/llama-160m",
|
||||||
|
|
||||||
|
# Skip cuda graph recording for fast test.
|
||||||
|
"enforce_eager": True,
|
||||||
|
|
||||||
|
# Required for spec decode.
|
||||||
|
"use_v2_block_manager": True
|
||||||
|
}])
|
||||||
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"test_llm_kwargs",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"speculative_model": "JackFram/llama-68m",
|
||||||
|
"num_speculative_tokens": 5,
|
||||||
|
|
||||||
|
# Artificially limit the draft model max model len; this forces vLLM
|
||||||
|
# to skip speculation once the sequences grow beyond 32-k tokens.
|
||||||
|
"speculative_max_model_len": 32,
|
||||||
|
},
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("batch_size", [8])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"output_len",
|
||||||
|
[
|
||||||
|
# This must be a good bit larger than speculative_max_model_len so that
|
||||||
|
# we can test the case where all seqs are skipped, but still small to
|
||||||
|
# ensure fast test.
|
||||||
|
64,
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("seed", [1])
|
||||||
|
def test_skip_speculation(baseline_llm_generator, test_llm_generator,
|
||||||
|
batch_size: int, output_len: int):
|
||||||
|
"""Verify greedy equality when some (or all) sequences skip speculation.
|
||||||
|
We do this by setting the max model len of the draft model to an
|
||||||
|
artificially low value, such that when the sequences grow beyond it, they
|
||||||
|
are skipped in speculative decoding.
|
||||||
|
"""
|
||||||
|
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||||
|
test_llm_generator,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
force_output_len=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"common_llm_kwargs",
|
||||||
|
[{
|
||||||
|
"model": "JackFram/llama-68m",
|
||||||
|
|
||||||
|
# Skip cuda graph recording for fast test.
|
||||||
|
"enforce_eager": True,
|
||||||
|
|
||||||
|
# Required for spec decode.
|
||||||
|
"use_v2_block_manager": True
|
||||||
|
}])
|
||||||
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"test_llm_kwargs",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"speculative_model": "JackFram/llama-68m",
|
||||||
|
"num_speculative_tokens": k,
|
||||||
|
}
|
||||||
|
# Try a range of common k, as well as large speculation.
|
||||||
|
for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("batch_size", [2])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"output_len",
|
||||||
|
[
|
||||||
|
# Use smaller output len for fast test.
|
||||||
|
32,
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("seed", [1])
|
||||||
|
def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
|
||||||
|
output_len: int):
|
||||||
|
"""Verify that speculative decoding produces exact equality to without spec
|
||||||
|
decode with many different values of k.
|
||||||
|
"""
|
||||||
|
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||||
|
test_llm_generator,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
force_output_len=True)
|
||||||
|
|
||||||
|
|
||||||
|
def run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||||
|
test_llm_generator,
|
||||||
|
batch_size,
|
||||||
|
max_output_len,
|
||||||
|
force_output_len: bool,
|
||||||
|
print_tokens: bool = False):
|
||||||
|
"""Helper method that compares the outputs of both the baseline LLM and
|
||||||
|
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
|
||||||
|
the same when temperature is zero.
|
||||||
"""
|
"""
|
||||||
output_len = 128
|
|
||||||
temperature = 0.0
|
temperature = 0.0
|
||||||
|
|
||||||
prompts = [
|
prompts = [
|
||||||
"Hello, my name is",
|
"Hello, my name is",
|
||||||
|
"The president of the United States is",
|
||||||
|
"The capital of France is",
|
||||||
|
"The future of AI is",
|
||||||
|
"San Francisco is know for its",
|
||||||
|
"Facebook was created in 2004 by",
|
||||||
|
"Curious George is a",
|
||||||
|
"Python 3.11 brings improvements to its",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
|
||||||
|
|
||||||
|
# If the test requires that we generated max_output_len tokens, then set the
|
||||||
|
# sampling params to ignore eos token.
|
||||||
|
ignore_eos = force_output_len
|
||||||
|
|
||||||
sampling_params = SamplingParams(
|
sampling_params = SamplingParams(
|
||||||
max_tokens=output_len,
|
max_tokens=max_output_len,
|
||||||
ignore_eos=True,
|
ignore_eos=ignore_eos,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(AssertionError,
|
spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
|
||||||
match="Speculative decoding not yet supported for "):
|
test_llm_generator, prompts, sampling_params)
|
||||||
get_output_from_llm_generator(test_llm_generator, prompts,
|
|
||||||
sampling_params)
|
|
||||||
|
|
||||||
|
(baseline_batch_tokens,
|
||||||
|
baseline_batch_token_ids) = get_output_from_llm_generator(
|
||||||
|
baseline_llm_generator, prompts, sampling_params)
|
||||||
|
|
||||||
def get_output_from_llm_generator(
|
assert len(baseline_batch_token_ids) == len(prompts)
|
||||||
llm_generator, prompts,
|
assert len(spec_batch_token_ids) == len(prompts)
|
||||||
sampling_params) -> Tuple[List[str], List[List[int]]]:
|
|
||||||
for llm in llm_generator:
|
|
||||||
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
|
|
||||||
token_ids = [output.outputs[0].token_ids for output in outputs]
|
|
||||||
tokens = [output.outputs[0].text for output in outputs]
|
|
||||||
del llm
|
|
||||||
|
|
||||||
return tokens, token_ids
|
for i, (baseline_token_ids, baseline_tokens, spec_token_ids,
|
||||||
|
spec_tokens) in enumerate(
|
||||||
|
zip(baseline_batch_token_ids, baseline_batch_tokens,
|
||||||
|
spec_batch_token_ids, spec_batch_tokens)):
|
||||||
|
if print_tokens:
|
||||||
|
print(f'{i=} {baseline_tokens=}')
|
||||||
|
print(f'{i=} {spec_tokens=}')
|
||||||
|
print(f'{i=} {baseline_token_ids=}')
|
||||||
|
print(f'{i=} {spec_token_ids=}')
|
||||||
|
assert baseline_token_ids == spec_token_ids
|
||||||
|
|||||||
@ -119,7 +119,7 @@ def test_initial_metrics_has_correct_values(has_data: bool):
|
|||||||
num_draft_tokens = 0
|
num_draft_tokens = 0
|
||||||
k = 5
|
k = 5
|
||||||
|
|
||||||
num_possible_tokens = AsyncMetricsCollector.get_max_num_accepted_tokens(
|
max_num_emitted_tokens = AsyncMetricsCollector.get_max_num_emitted_tokens(
|
||||||
num_draft_tokens, k)
|
num_draft_tokens, k)
|
||||||
|
|
||||||
rej_sampler = MagicMock()
|
rej_sampler = MagicMock()
|
||||||
@ -153,7 +153,7 @@ def test_initial_metrics_has_correct_values(has_data: bool):
|
|||||||
assert (metrics.draft_acceptance_rate == num_accepted_tokens /
|
assert (metrics.draft_acceptance_rate == num_accepted_tokens /
|
||||||
num_draft_tokens)
|
num_draft_tokens)
|
||||||
assert (metrics.system_efficiency == num_emitted_tokens /
|
assert (metrics.system_efficiency == num_emitted_tokens /
|
||||||
num_possible_tokens)
|
max_num_emitted_tokens)
|
||||||
else:
|
else:
|
||||||
assert math.isnan(metrics.draft_acceptance_rate)
|
assert math.isnan(metrics.draft_acceptance_rate)
|
||||||
assert math.isnan(metrics.system_efficiency)
|
assert math.isnan(metrics.system_efficiency)
|
||||||
|
|||||||
@ -344,8 +344,8 @@ def test_draft_proposals_no_speculations():
|
|||||||
assert torch.is_tensor(proposals.proposal_token_ids)
|
assert torch.is_tensor(proposals.proposal_token_ids)
|
||||||
assert torch.is_tensor(proposals.proposal_probs)
|
assert torch.is_tensor(proposals.proposal_probs)
|
||||||
|
|
||||||
assert proposals.proposal_token_ids.shape == torch.Size([0, k])
|
assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
|
||||||
assert proposals.proposal_probs.shape[:-1] == torch.Size([0, k])
|
assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])
|
||||||
|
|
||||||
assert proposals.proposal_lens.shape == torch.Size([batch_size])
|
assert proposals.proposal_lens.shape == torch.Size([batch_size])
|
||||||
assert proposals.proposal_lens.tolist() == [0 for _ in range(batch_size)]
|
assert proposals.proposal_lens.tolist() == [0 for _ in range(batch_size)]
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import random
|
import random
|
||||||
|
from types import SimpleNamespace
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -62,8 +63,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
|
|||||||
"""Verify SpecDecodeWorker calls the target model with correct
|
"""Verify SpecDecodeWorker calls the target model with correct
|
||||||
inputs. Everything else is mocked out.
|
inputs. Everything else is mocked out.
|
||||||
"""
|
"""
|
||||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
|
||||||
target_worker = mock_worker()
|
target_worker = mock_worker(use_spec=False)
|
||||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||||
rejection_sampler.token_id_dtype = torch.int64
|
rejection_sampler.token_id_dtype = torch.int64
|
||||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||||
@ -144,8 +145,10 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
|
|||||||
"""
|
"""
|
||||||
vocab_size = 32_000
|
vocab_size = 32_000
|
||||||
|
|
||||||
draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
|
draft_worker = mock_worker(cls=MultiStepWorker,
|
||||||
target_worker = mock_worker(vocab_size=vocab_size)
|
vocab_size=vocab_size,
|
||||||
|
use_spec=False)
|
||||||
|
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
|
||||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||||
rejection_sampler.token_id_dtype = torch.int64
|
rejection_sampler.token_id_dtype = torch.int64
|
||||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||||
@ -202,17 +205,16 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
|
|||||||
num_lookahead_slots=k)
|
num_lookahead_slots=k)
|
||||||
|
|
||||||
assert len(rejection_sampler.call_args_list) == 1
|
assert len(rejection_sampler.call_args_list) == 1
|
||||||
args, _ = rejection_sampler.call_args_list[0]
|
_, kwargs = rejection_sampler.call_args_list[0]
|
||||||
(actual_proposal_scores, actual_bonus_token_ids, actual_proposal_probs,
|
actual = SimpleNamespace(**kwargs)
|
||||||
actual_proposal_token_ids) = args
|
|
||||||
|
|
||||||
assert torch.equal(actual_bonus_token_ids,
|
assert torch.equal(actual.bonus_token_ids,
|
||||||
target_token_ids.reshape(batch_size, k + 1)[:, -1:])
|
target_token_ids.reshape(batch_size, k + 1)[:, -1:])
|
||||||
assert torch.equal(
|
assert torch.equal(
|
||||||
actual_proposal_scores,
|
actual.target_probs,
|
||||||
target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1])
|
target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1])
|
||||||
assert torch.equal(actual_proposal_token_ids, proposal_token_ids)
|
assert torch.equal(actual.draft_token_ids, proposal_token_ids)
|
||||||
assert torch.equal(actual_proposal_probs, proposal_probs)
|
assert torch.equal(actual.draft_probs, proposal_probs)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('k', [1, 2, 6])
|
@pytest.mark.parametrize('k', [1, 2, 6])
|
||||||
@ -224,8 +226,10 @@ def test_correctly_formats_output(k: int, batch_size: int):
|
|||||||
"""
|
"""
|
||||||
vocab_size = 32_000
|
vocab_size = 32_000
|
||||||
|
|
||||||
draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
|
draft_worker = mock_worker(cls=MultiStepWorker,
|
||||||
target_worker = mock_worker(vocab_size=vocab_size)
|
vocab_size=vocab_size,
|
||||||
|
use_spec=False)
|
||||||
|
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
|
||||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||||
rejection_sampler.token_id_dtype = torch.int64
|
rejection_sampler.token_id_dtype = torch.int64
|
||||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||||
@ -336,8 +340,10 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
|
|||||||
"""
|
"""
|
||||||
vocab_size = 32_000
|
vocab_size = 32_000
|
||||||
|
|
||||||
draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
|
draft_worker = mock_worker(cls=MultiStepWorker,
|
||||||
target_worker = mock_worker(vocab_size=vocab_size)
|
vocab_size=vocab_size,
|
||||||
|
use_spec=False)
|
||||||
|
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
|
||||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||||
rejection_sampler.token_id_dtype = torch.int64
|
rejection_sampler.token_id_dtype = torch.int64
|
||||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||||
@ -500,8 +506,8 @@ def test_init_device():
|
|||||||
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
|
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
|
||||||
well as other GPU initialization.
|
well as other GPU initialization.
|
||||||
"""
|
"""
|
||||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
|
||||||
target_worker = mock_worker()
|
target_worker = mock_worker(use_spec=False)
|
||||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||||
rejection_sampler.token_id_dtype = torch.int64
|
rejection_sampler.token_id_dtype = torch.int64
|
||||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||||
|
|||||||
@ -63,11 +63,14 @@ def create_execute_model_data(
|
|||||||
def mock_worker(cls=None,
|
def mock_worker(cls=None,
|
||||||
vocab_size: int = 30_000,
|
vocab_size: int = 30_000,
|
||||||
max_model_len: int = 2048,
|
max_model_len: int = 2048,
|
||||||
rank: int = 0) -> MagicMock:
|
rank: int = 0,
|
||||||
|
use_spec: bool = True) -> MagicMock:
|
||||||
if cls is None:
|
if cls is None:
|
||||||
cls = Worker
|
cls = Worker
|
||||||
|
|
||||||
worker = MagicMock(spec=cls)
|
spec = cls if use_spec else None
|
||||||
|
|
||||||
|
worker = MagicMock(spec=spec)
|
||||||
worker.vocab_size = vocab_size
|
worker.vocab_size = vocab_size
|
||||||
worker.max_model_len = max_model_len
|
worker.max_model_len = max_model_len
|
||||||
worker.rank = rank
|
worker.rank = rank
|
||||||
|
|||||||
@ -655,6 +655,9 @@ class SpeculativeConfig:
|
|||||||
target_dtype: str,
|
target_dtype: str,
|
||||||
speculative_model: Optional[str],
|
speculative_model: Optional[str],
|
||||||
num_speculative_tokens: Optional[int],
|
num_speculative_tokens: Optional[int],
|
||||||
|
speculative_max_model_len: Optional[int],
|
||||||
|
enable_chunked_prefill: bool,
|
||||||
|
use_v2_block_manager: bool,
|
||||||
) -> Optional["SpeculativeConfig"]:
|
) -> Optional["SpeculativeConfig"]:
|
||||||
"""Create a SpeculativeConfig if possible, else return None.
|
"""Create a SpeculativeConfig if possible, else return None.
|
||||||
|
|
||||||
@ -672,6 +675,15 @@ class SpeculativeConfig:
|
|||||||
model, if provided.
|
model, if provided.
|
||||||
num_speculative_tokens (Optional[int]): The number of speculative
|
num_speculative_tokens (Optional[int]): The number of speculative
|
||||||
tokens, if provided.
|
tokens, if provided.
|
||||||
|
speculative_max_model_len (Optional[int]): The maximum model len of
|
||||||
|
the speculative model. Used when testing the ability to skip
|
||||||
|
speculation for some sequences.
|
||||||
|
enable_chunked_prefill (bool): Whether vLLM is configured to use
|
||||||
|
chunked prefill or not. Used for raising an error since its not
|
||||||
|
yet compatible with spec decode.
|
||||||
|
use_v2_block_manager (bool): Whether vLLM is configured to use the
|
||||||
|
v2 block manager or not. Used for raising an error since the v2
|
||||||
|
block manager is required with spec decode.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
|
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
|
||||||
@ -690,12 +702,21 @@ class SpeculativeConfig:
|
|||||||
assert (speculative_model is not None
|
assert (speculative_model is not None
|
||||||
and num_speculative_tokens is not None)
|
and num_speculative_tokens is not None)
|
||||||
|
|
||||||
|
if enable_chunked_prefill:
|
||||||
|
raise ValueError(
|
||||||
|
"Speculative decoding and chunked prefill are "
|
||||||
|
f"currently mutually exclusive ({enable_chunked_prefill=}).")
|
||||||
|
|
||||||
|
if not use_v2_block_manager:
|
||||||
|
raise ValueError(
|
||||||
|
"Speculative decoding requires usage of the V2 "
|
||||||
|
"block manager. Enable it with --use-v2-block-manager.")
|
||||||
|
|
||||||
# TODO: The user should be able to specify revision/quantization/max
|
# TODO: The user should be able to specify revision/quantization/max
|
||||||
# model len for the draft model. It is not currently supported.
|
# model len for the draft model. It is not currently supported.
|
||||||
draft_revision = None
|
draft_revision = None
|
||||||
draft_code_revision = None
|
draft_code_revision = None
|
||||||
draft_quantization = None
|
draft_quantization = None
|
||||||
draft_max_model_len = None
|
|
||||||
|
|
||||||
draft_model_config = ModelConfig(
|
draft_model_config = ModelConfig(
|
||||||
model=speculative_model,
|
model=speculative_model,
|
||||||
@ -707,7 +728,7 @@ class SpeculativeConfig:
|
|||||||
revision=draft_revision,
|
revision=draft_revision,
|
||||||
code_revision=draft_code_revision,
|
code_revision=draft_code_revision,
|
||||||
tokenizer_revision=target_model_config.tokenizer_revision,
|
tokenizer_revision=target_model_config.tokenizer_revision,
|
||||||
max_model_len=draft_max_model_len,
|
max_model_len=None,
|
||||||
quantization=draft_quantization,
|
quantization=draft_quantization,
|
||||||
enforce_eager=target_model_config.enforce_eager,
|
enforce_eager=target_model_config.enforce_eager,
|
||||||
max_context_len_to_capture=target_model_config.
|
max_context_len_to_capture=target_model_config.
|
||||||
@ -715,6 +736,13 @@ class SpeculativeConfig:
|
|||||||
max_logprobs=target_model_config.max_logprobs,
|
max_logprobs=target_model_config.max_logprobs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
draft_model_config.max_model_len = (
|
||||||
|
SpeculativeConfig._maybe_override_draft_max_model_len(
|
||||||
|
speculative_max_model_len,
|
||||||
|
draft_model_config.max_model_len,
|
||||||
|
target_model_config.max_model_len,
|
||||||
|
))
|
||||||
|
|
||||||
draft_parallel_config = (
|
draft_parallel_config = (
|
||||||
SpeculativeConfig.create_draft_parallel_config(
|
SpeculativeConfig.create_draft_parallel_config(
|
||||||
target_parallel_config))
|
target_parallel_config))
|
||||||
@ -725,6 +753,41 @@ class SpeculativeConfig:
|
|||||||
num_speculative_tokens,
|
num_speculative_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _maybe_override_draft_max_model_len(
|
||||||
|
speculative_max_model_len: Optional[int],
|
||||||
|
draft_max_model_len: int,
|
||||||
|
target_max_model_len: int,
|
||||||
|
) -> int:
|
||||||
|
"""Determine the max sequence len for the draft model. This is usually
|
||||||
|
the draft_max_model_len, but may be the target_max_model_len if it is
|
||||||
|
less than the draft_max_model_len, or may be speculative_max_model_len
|
||||||
|
if it is specified.
|
||||||
|
|
||||||
|
This is necessary so that sequences do not exceed the capacity of the
|
||||||
|
draft model or the target model.
|
||||||
|
|
||||||
|
speculative_max_model_len is mainly used for testing that sequences can
|
||||||
|
skip speculation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if speculative_max_model_len is not None:
|
||||||
|
|
||||||
|
if speculative_max_model_len > draft_max_model_len:
|
||||||
|
raise ValueError(f"{speculative_max_model_len=} cannot be "
|
||||||
|
f"larger than {draft_max_model_len=}")
|
||||||
|
|
||||||
|
if speculative_max_model_len > target_max_model_len:
|
||||||
|
raise ValueError(f"{speculative_max_model_len=} cannot be "
|
||||||
|
f"larger than {target_max_model_len=}")
|
||||||
|
|
||||||
|
return speculative_max_model_len
|
||||||
|
|
||||||
|
return min(
|
||||||
|
draft_max_model_len,
|
||||||
|
target_max_model_len,
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_draft_parallel_config(
|
def create_draft_parallel_config(
|
||||||
target_parallel_config: ParallelConfig) -> ParallelConfig:
|
target_parallel_config: ParallelConfig) -> ParallelConfig:
|
||||||
|
|||||||
@ -73,6 +73,7 @@ class EngineArgs:
|
|||||||
# Speculative decoding configuration.
|
# Speculative decoding configuration.
|
||||||
speculative_model: Optional[str] = None
|
speculative_model: Optional[str] = None
|
||||||
num_speculative_tokens: Optional[int] = None
|
num_speculative_tokens: Optional[int] = None
|
||||||
|
speculative_max_model_len: Optional[int] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.tokenizer is None:
|
if self.tokenizer is None:
|
||||||
@ -237,7 +238,7 @@ class EngineArgs:
|
|||||||
parser.add_argument('--block-size',
|
parser.add_argument('--block-size',
|
||||||
type=int,
|
type=int,
|
||||||
default=EngineArgs.block_size,
|
default=EngineArgs.block_size,
|
||||||
choices=[8, 16, 32, 128],
|
choices=[8, 16, 32],
|
||||||
help='Token block size for contiguous chunks of '
|
help='Token block size for contiguous chunks of '
|
||||||
'tokens.')
|
'tokens.')
|
||||||
|
|
||||||
@ -420,17 +421,25 @@ class EngineArgs:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--speculative-model',
|
'--speculative-model',
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=EngineArgs.speculative_model,
|
||||||
help=
|
help=
|
||||||
'The name of the draft model to be used in speculative decoding.')
|
'The name of the draft model to be used in speculative decoding.')
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--num-speculative-tokens',
|
'--num-speculative-tokens',
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=EngineArgs.num_speculative_tokens,
|
||||||
help='The number of speculative tokens to sample from '
|
help='The number of speculative tokens to sample from '
|
||||||
'the draft model in speculative decoding.')
|
'the draft model in speculative decoding.')
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--speculative-max-model-len',
|
||||||
|
type=str,
|
||||||
|
default=EngineArgs.speculative_max_model_len,
|
||||||
|
help='The maximum sequence length supported by the '
|
||||||
|
'draft model. Sequences over this length will skip '
|
||||||
|
'speculation.')
|
||||||
|
|
||||||
parser.add_argument('--model-loader-extra-config',
|
parser.add_argument('--model-loader-extra-config',
|
||||||
type=str,
|
type=str,
|
||||||
default=EngineArgs.model_loader_extra_config,
|
default=EngineArgs.model_loader_extra_config,
|
||||||
@ -481,6 +490,9 @@ class EngineArgs:
|
|||||||
target_dtype=self.dtype,
|
target_dtype=self.dtype,
|
||||||
speculative_model=self.speculative_model,
|
speculative_model=self.speculative_model,
|
||||||
num_speculative_tokens=self.num_speculative_tokens,
|
num_speculative_tokens=self.num_speculative_tokens,
|
||||||
|
speculative_max_model_len=self.speculative_max_model_len,
|
||||||
|
enable_chunked_prefill=self.enable_chunked_prefill,
|
||||||
|
use_v2_block_manager=self.use_v2_block_manager,
|
||||||
)
|
)
|
||||||
|
|
||||||
scheduler_config = SchedulerConfig(
|
scheduler_config = SchedulerConfig(
|
||||||
|
|||||||
@ -22,7 +22,7 @@ from vllm.lora.request import LoRARequest
|
|||||||
from vllm.outputs import RequestOutput
|
from vllm.outputs import RequestOutput
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
|
from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
|
||||||
SequenceGroup)
|
SequenceGroup, SequenceStage)
|
||||||
from vllm.transformers_utils.detokenizer import Detokenizer
|
from vllm.transformers_utils.detokenizer import Detokenizer
|
||||||
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
|
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
|
||||||
get_tokenizer_group)
|
get_tokenizer_group)
|
||||||
@ -480,9 +480,12 @@ class LLMEngine:
|
|||||||
seq_group = scheduled_seq_group.seq_group
|
seq_group = scheduled_seq_group.seq_group
|
||||||
seq_group.update_num_computed_tokens(
|
seq_group.update_num_computed_tokens(
|
||||||
scheduled_seq_group.token_chunk_size)
|
scheduled_seq_group.token_chunk_size)
|
||||||
# If uncomputed tokens > 0, it means prefill is chunked.
|
|
||||||
# We don't need to process outputs in that case.
|
# If all sequences in the sequence group are in DECODE, then we can
|
||||||
if seq_group.get_num_uncomputed_tokens() == 0:
|
# process the output tokens. Otherwise, they are (chunked) prefill
|
||||||
|
# samples and should not be processed.
|
||||||
|
stages = [seq.data._stage for seq in seq_group.seqs_dict.values()]
|
||||||
|
if all(stage == SequenceStage.DECODE for stage in stages):
|
||||||
self.output_processor.process_outputs(seq_group, outputs)
|
self.output_processor.process_outputs(seq_group, outputs)
|
||||||
|
|
||||||
# Free the finished sequence groups.
|
# Free the finished sequence groups.
|
||||||
@ -569,7 +572,8 @@ class LLMEngine:
|
|||||||
|
|
||||||
# Log stats.
|
# Log stats.
|
||||||
if self.log_stats:
|
if self.log_stats:
|
||||||
self.stat_logger.log(self._get_stats(scheduler_outputs))
|
self.stat_logger.log(
|
||||||
|
self._get_stats(scheduler_outputs, model_output=output))
|
||||||
|
|
||||||
return request_outputs
|
return request_outputs
|
||||||
|
|
||||||
@ -578,9 +582,18 @@ class LLMEngine:
|
|||||||
if self.log_stats:
|
if self.log_stats:
|
||||||
self.stat_logger.log(self._get_stats(scheduler_outputs=None))
|
self.stat_logger.log(self._get_stats(scheduler_outputs=None))
|
||||||
|
|
||||||
def _get_stats(self,
|
def _get_stats(
|
||||||
scheduler_outputs: Optional[SchedulerOutputs]) -> Stats:
|
self,
|
||||||
"""Get Stats to be Logged to Prometheus."""
|
scheduler_outputs: Optional[SchedulerOutputs],
|
||||||
|
model_output: Optional[List[SamplerOutput]] = None) -> Stats:
|
||||||
|
"""Get Stats to be Logged to Prometheus.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scheduler_outputs: Optional, used to populate metrics related to
|
||||||
|
the scheduled batch,
|
||||||
|
model_output: Optional, used to emit speculative decoding metrics
|
||||||
|
which are created by the workers.
|
||||||
|
"""
|
||||||
now = time.time()
|
now = time.time()
|
||||||
|
|
||||||
# KV Cache Usage in %.
|
# KV Cache Usage in %.
|
||||||
@ -637,6 +650,14 @@ class LLMEngine:
|
|||||||
time_to_first_tokens = time_last_iters if prompt_run else []
|
time_to_first_tokens = time_last_iters if prompt_run else []
|
||||||
time_per_output_tokens = [] if prompt_run else time_last_iters
|
time_per_output_tokens = [] if prompt_run else time_last_iters
|
||||||
|
|
||||||
|
# Spec decode, if enabled, emits specialized metrics from the worker in
|
||||||
|
# sampler output.
|
||||||
|
if model_output and (model_output[0].spec_decode_worker_metrics
|
||||||
|
is not None):
|
||||||
|
spec_decode_metrics = model_output[0].spec_decode_worker_metrics
|
||||||
|
else:
|
||||||
|
spec_decode_metrics = None
|
||||||
|
|
||||||
return Stats(
|
return Stats(
|
||||||
now=now,
|
now=now,
|
||||||
num_running=num_running,
|
num_running=num_running,
|
||||||
@ -649,6 +670,7 @@ class LLMEngine:
|
|||||||
time_to_first_tokens=time_to_first_tokens,
|
time_to_first_tokens=time_to_first_tokens,
|
||||||
time_per_output_tokens=time_per_output_tokens,
|
time_per_output_tokens=time_per_output_tokens,
|
||||||
time_e2e_requests=time_e2e_requests,
|
time_e2e_requests=time_e2e_requests,
|
||||||
|
spec_decode_metrics=spec_decode_metrics,
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Protocol
|
from typing import TYPE_CHECKING, Dict, List, Optional, Protocol
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
|
from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
|
||||||
@ -8,6 +8,9 @@ from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
|
|||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
disable_created_metrics()
|
disable_created_metrics()
|
||||||
@ -118,6 +121,8 @@ class Stats:
|
|||||||
time_per_output_tokens: List[float]
|
time_per_output_tokens: List[float]
|
||||||
time_e2e_requests: List[float]
|
time_e2e_requests: List[float]
|
||||||
|
|
||||||
|
spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
|
||||||
|
|
||||||
|
|
||||||
class SupportsMetricsInfo(Protocol):
|
class SupportsMetricsInfo(Protocol):
|
||||||
|
|
||||||
@ -235,3 +240,19 @@ class StatLogger:
|
|||||||
self.num_prompt_tokens = []
|
self.num_prompt_tokens = []
|
||||||
self.num_generation_tokens = []
|
self.num_generation_tokens = []
|
||||||
self.last_local_log = stats.now
|
self.last_local_log = stats.now
|
||||||
|
|
||||||
|
if stats.spec_decode_metrics is not None:
|
||||||
|
logger.info(
|
||||||
|
self._format_spec_decode_metrics_str(
|
||||||
|
stats.spec_decode_metrics))
|
||||||
|
|
||||||
|
def _format_spec_decode_metrics_str(
|
||||||
|
self, metrics: "SpecDecodeWorkerMetrics") -> str:
|
||||||
|
|
||||||
|
return ("Speculative metrics: "
|
||||||
|
f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
|
||||||
|
f"System efficiency: {metrics.system_efficiency:.3f}, "
|
||||||
|
f"Number of speculative tokens: {metrics.num_spec_tokens}, "
|
||||||
|
f"Number of accepted tokens: {metrics.accepted_tokens}, "
|
||||||
|
f"Number of draft tokens tokens: {metrics.draft_tokens}, "
|
||||||
|
f"Number of emitted tokens tokens: {metrics.emitted_tokens}.")
|
||||||
|
|||||||
@ -83,6 +83,7 @@ class GPUExecutor(ExecutorBase):
|
|||||||
scheduler_config=self.scheduler_config,
|
scheduler_config=self.scheduler_config,
|
||||||
device_config=self.device_config,
|
device_config=self.device_config,
|
||||||
cache_config=self.cache_config,
|
cache_config=self.cache_config,
|
||||||
|
# TODO allow draft-model specific load config.
|
||||||
load_config=self.load_config,
|
load_config=self.load_config,
|
||||||
local_rank=0,
|
local_rank=0,
|
||||||
rank=0,
|
rank=0,
|
||||||
|
|||||||
@ -144,6 +144,7 @@ class RejectionSampler(nn.Module):
|
|||||||
recovered_probs = self._get_recovered_probs(
|
recovered_probs = self._get_recovered_probs(
|
||||||
target_probs, draft_probs).reshape(batch_size * k, vocab_size)
|
target_probs, draft_probs).reshape(batch_size * k, vocab_size)
|
||||||
|
|
||||||
|
# NOTE: the recovered_probs are overwritten by this method.
|
||||||
recovered_token_ids = _multinomial(recovered_probs,
|
recovered_token_ids = _multinomial(recovered_probs,
|
||||||
num_samples=1).reshape(
|
num_samples=1).reshape(
|
||||||
batch_size, k)
|
batch_size, k)
|
||||||
@ -307,6 +308,12 @@ class RejectionSampler(nn.Module):
|
|||||||
output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
|
output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
|
||||||
bonus_token_ids, -1)
|
bonus_token_ids, -1)
|
||||||
|
|
||||||
|
# We disable bonus tokens because it causes corrupt KV cache for
|
||||||
|
# proposal methods that require KV cache. We can fix it by "prefilling"
|
||||||
|
# the bonus token in the proposer. The following issue tracks the fix.
|
||||||
|
# https://github.com/vllm-project/vllm/issues/4212
|
||||||
|
output_with_bonus_tokens[:, -1] = -1
|
||||||
|
|
||||||
# Fill the recovered token ids.
|
# Fill the recovered token ids.
|
||||||
output.mul_(~after_false_mask).add_(
|
output.mul_(~after_false_mask).add_(
|
||||||
recovered_token_ids.mul(after_false_mask))
|
recovered_token_ids.mul(after_false_mask))
|
||||||
|
|||||||
@ -35,6 +35,14 @@ class Sampler(nn.Module):
|
|||||||
in logits for each token in the input prompt.
|
in logits for each token in the input prompt.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Whether or not the SamplerOutput should have on-device tensors
|
||||||
|
# containing the sampled token ids and probabilities. This is used by
|
||||||
|
# speculative decoding.
|
||||||
|
self.include_gpu_probs_tensor = False
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
@ -79,13 +87,45 @@ class Sampler(nn.Module):
|
|||||||
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
|
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
|
||||||
|
|
||||||
# Sample the next tokens.
|
# Sample the next tokens.
|
||||||
sample_results = _sample(probs, logprobs, sampling_metadata,
|
sample_results, maybe_sampled_tokens_tensor = _sample(
|
||||||
sampling_tensors)
|
probs,
|
||||||
|
logprobs,
|
||||||
|
sampling_metadata,
|
||||||
|
sampling_tensors,
|
||||||
|
include_gpu_probs_tensor=self.include_gpu_probs_tensor,
|
||||||
|
modify_greedy_probs=self._should_modify_greedy_probs_inplace,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.include_gpu_probs_tensor:
|
||||||
|
assert maybe_sampled_tokens_tensor is not None
|
||||||
|
sampled_tokens_tensor = maybe_sampled_tokens_tensor
|
||||||
|
on_device_tensors = (probs, sampled_tokens_tensor)
|
||||||
|
else:
|
||||||
|
on_device_tensors = None
|
||||||
|
|
||||||
# Get the logprobs query results.
|
# Get the logprobs query results.
|
||||||
prompt_logprobs, sample_logprobs = _get_logprobs(
|
prompt_logprobs, sample_logprobs = _get_logprobs(
|
||||||
logprobs, sampling_metadata, sample_results)
|
logprobs, sampling_metadata, sample_results)
|
||||||
return _build_sampler_output(sample_results, sampling_metadata,
|
return _build_sampler_output(sample_results,
|
||||||
prompt_logprobs, sample_logprobs)
|
sampling_metadata,
|
||||||
|
prompt_logprobs,
|
||||||
|
sample_logprobs,
|
||||||
|
on_device_tensors=on_device_tensors)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _should_modify_greedy_probs_inplace(self) -> bool:
|
||||||
|
"""Whether or not the sampler should modify the probability distribution
|
||||||
|
of greedily-sampled tokens such that multinomial sampling would sample
|
||||||
|
the greedily-sampled token.
|
||||||
|
|
||||||
|
In other words, if True then we set the probability of the greedily-
|
||||||
|
sampled token to 1.
|
||||||
|
|
||||||
|
This is used by speculative decoding, which requires that the sampling
|
||||||
|
method be encoded into the probability distribution.
|
||||||
|
"""
|
||||||
|
# Modify greedy probs if include_gpu_probs_tensor is set.
|
||||||
|
return self.include_gpu_probs_tensor
|
||||||
|
|
||||||
|
|
||||||
def _get_bin_counts_and_mask(
|
def _get_bin_counts_and_mask(
|
||||||
@ -359,7 +399,9 @@ def _sample_with_torch(
|
|||||||
probs: torch.Tensor,
|
probs: torch.Tensor,
|
||||||
logprobs: torch.Tensor,
|
logprobs: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> List[Tuple[List[int], List[int]]]:
|
include_gpu_probs_tensor: bool,
|
||||||
|
modify_greedy_probs: bool,
|
||||||
|
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
|
||||||
categorized_seq_group_ids = {t: [] for t in SamplingType}
|
categorized_seq_group_ids = {t: [] for t in SamplingType}
|
||||||
categorized_sample_indices = sampling_metadata.categorized_sample_indices
|
categorized_sample_indices = sampling_metadata.categorized_sample_indices
|
||||||
for i, seq_group in enumerate(sampling_metadata.seq_groups):
|
for i, seq_group in enumerate(sampling_metadata.seq_groups):
|
||||||
@ -371,6 +413,15 @@ def _sample_with_torch(
|
|||||||
sample_metadata = {}
|
sample_metadata = {}
|
||||||
multinomial_samples = {}
|
multinomial_samples = {}
|
||||||
|
|
||||||
|
# Create output tensor for sampled token ids.
|
||||||
|
if include_gpu_probs_tensor:
|
||||||
|
sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
|
||||||
|
1,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=logprobs.device)
|
||||||
|
else:
|
||||||
|
sampled_token_ids_tensor = None
|
||||||
|
|
||||||
# Counterintiutively, having two loops here is actually faster.
|
# Counterintiutively, having two loops here is actually faster.
|
||||||
# The first loop can run without waiting on GPU<->CPU sync.
|
# The first loop can run without waiting on GPU<->CPU sync.
|
||||||
for sampling_type in SamplingType:
|
for sampling_type in SamplingType:
|
||||||
@ -383,9 +434,25 @@ def _sample_with_torch(
|
|||||||
is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
|
is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
|
||||||
sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
|
sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
|
||||||
is_prompts, sample_indices)
|
is_prompts, sample_indices)
|
||||||
|
long_sample_indices = sample_indices.long()
|
||||||
|
|
||||||
if sampling_type == SamplingType.GREEDY:
|
if sampling_type == SamplingType.GREEDY:
|
||||||
greedy_samples = torch.argmax(logprobs[sample_indices.long()],
|
greedy_samples = torch.argmax(logprobs[long_sample_indices],
|
||||||
dim=-1)
|
dim=-1)
|
||||||
|
|
||||||
|
if include_gpu_probs_tensor:
|
||||||
|
# Store sampled tokens in output tensor.
|
||||||
|
sampled_token_ids_tensor[
|
||||||
|
long_sample_indices] = greedy_samples.unsqueeze(-1)
|
||||||
|
|
||||||
|
if modify_greedy_probs:
|
||||||
|
# If required, modify the probabilities such that sampling from
|
||||||
|
# the modified distribution would always sample the argmax
|
||||||
|
# token id.
|
||||||
|
_modify_greedy_probs_inplace(logprobs, probs,
|
||||||
|
long_sample_indices,
|
||||||
|
greedy_samples)
|
||||||
|
|
||||||
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
|
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
|
||||||
max_best_of_in_batch = 1
|
max_best_of_in_batch = 1
|
||||||
for seq_group, is_prompt in zip(seq_groups, is_prompts):
|
for seq_group, is_prompt in zip(seq_groups, is_prompts):
|
||||||
@ -397,15 +464,23 @@ def _sample_with_torch(
|
|||||||
"seq_groups": seq_groups,
|
"seq_groups": seq_groups,
|
||||||
"generators": sampling_metadata.generators,
|
"generators": sampling_metadata.generators,
|
||||||
}
|
}
|
||||||
|
|
||||||
multinomial_samples[sampling_type] = _multinomial(
|
multinomial_samples[sampling_type] = _multinomial(
|
||||||
probs[sample_indices.long()], max_best_of_in_batch,
|
probs[long_sample_indices], max_best_of_in_batch,
|
||||||
**seeded_args)
|
**seeded_args)
|
||||||
|
|
||||||
|
if include_gpu_probs_tensor:
|
||||||
|
# Store sampled tokens in output tensor.
|
||||||
|
sampled_token_ids_tensor[
|
||||||
|
long_sample_indices] = multinomial_samples[sampling_type]
|
||||||
|
|
||||||
elif sampling_type == SamplingType.BEAM:
|
elif sampling_type == SamplingType.BEAM:
|
||||||
beam_search_logprobs = logprobs[sample_indices]
|
beam_search_logprobs = logprobs[sample_indices]
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported sampling type: {sampling_type}")
|
raise ValueError(f"Unsupported sampling type: {sampling_type}")
|
||||||
|
|
||||||
# GPU<->CPU sync happens in the loop below.
|
# GPU<->CPU sync happens in the loop below.
|
||||||
|
# This also converts the sample output to Python objects.
|
||||||
|
|
||||||
for sampling_type in SamplingType:
|
for sampling_type in SamplingType:
|
||||||
if sampling_type not in sample_metadata:
|
if sampling_type not in sample_metadata:
|
||||||
@ -427,7 +502,7 @@ def _sample_with_torch(
|
|||||||
sample_results_dict[i]
|
sample_results_dict[i]
|
||||||
for i in range(len(sampling_metadata.seq_groups))
|
for i in range(len(sampling_metadata.seq_groups))
|
||||||
]
|
]
|
||||||
return sample_results
|
return sample_results, sampled_token_ids_tensor
|
||||||
|
|
||||||
|
|
||||||
def _sample_with_triton_kernel(
|
def _sample_with_triton_kernel(
|
||||||
@ -511,12 +586,17 @@ def _sample_with_triton_kernel(
|
|||||||
|
|
||||||
|
|
||||||
def _sample(
|
def _sample(
|
||||||
probs: torch.Tensor,
|
probs: torch.Tensor, logprobs: torch.Tensor,
|
||||||
logprobs: torch.Tensor,
|
sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
|
||||||
sampling_metadata: SamplingMetadata,
|
include_gpu_probs_tensor: bool, modify_greedy_probs: bool
|
||||||
sampling_tensors: SamplingTensors,
|
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
|
||||||
) -> List[Tuple[List[int], List[int]]]:
|
return _sample_with_torch(
|
||||||
return _sample_with_torch(probs, logprobs, sampling_metadata)
|
probs,
|
||||||
|
logprobs,
|
||||||
|
sampling_metadata,
|
||||||
|
include_gpu_probs_tensor=include_gpu_probs_tensor,
|
||||||
|
modify_greedy_probs=modify_greedy_probs,
|
||||||
|
)
|
||||||
|
|
||||||
# TODO: Enable once Triton kernel & associated code is faster.
|
# TODO: Enable once Triton kernel & associated code is faster.
|
||||||
# return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
|
# return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
|
||||||
@ -680,12 +760,73 @@ def _get_logprobs(
|
|||||||
return result_prompt_logprobs, result_sample_logprobs
|
return result_prompt_logprobs, result_sample_logprobs
|
||||||
|
|
||||||
|
|
||||||
|
def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
|
||||||
|
sample_indices: torch.Tensor,
|
||||||
|
greedy_samples: torch.Tensor) -> None:
|
||||||
|
"""Modify the probability distributions of the greedily-sampled tokens such
|
||||||
|
that each sampled token has a "probability" of 1.0. This is required by
|
||||||
|
speculative decoding, which depends on the sampling method being encoded
|
||||||
|
within the probability distribution for correctness.
|
||||||
|
|
||||||
|
# Why do we only need to do this for greedy sampling?
|
||||||
|
|
||||||
|
vLLM's sampler performs the following steps for greedy or multinomial
|
||||||
|
(random) sampling:
|
||||||
|
1. Get logits from model.
|
||||||
|
2. Modify logits according to per-sequence sampling parameters.
|
||||||
|
- Multiply by temperature, top-k and top-p masking, penalize tokens
|
||||||
|
according to their frequency, etc.
|
||||||
|
3. Sample a token.
|
||||||
|
- Random sampling simply samples from the modified probability
|
||||||
|
distribution.
|
||||||
|
- Greedy sampling performs `argmax` to obtain the token with the
|
||||||
|
highest likelihood.
|
||||||
|
|
||||||
|
Ignoring greedy sampling for a moment, we find that the computed probability
|
||||||
|
distribution has the following property: we can sample from it independently
|
||||||
|
and find that the token sampled by the Sampler has a frequency corresponding
|
||||||
|
to how often we see it in our sampling. In other words, for tokens sampled
|
||||||
|
with vLLM's random SamplingType, the computed probability distribution
|
||||||
|
encodes the sampling methodology completely.
|
||||||
|
|
||||||
|
Greedy sampling does not normally have this property. vLLM modifies logits
|
||||||
|
according to sampling params, then performs `argmax`, then returns the
|
||||||
|
sampled token and the computed probability distribution. If we sample from
|
||||||
|
the distribution, we'll find the likelihood of the greedily-sampled token
|
||||||
|
is not always 1.0.
|
||||||
|
|
||||||
|
Since lossless speculative decoding requires that the sampling methodology
|
||||||
|
be encoded within the probability distribution, we are motivated to modify
|
||||||
|
the probability distribution such that the sampled token has probability 1
|
||||||
|
when speculative decoding is used.
|
||||||
|
|
||||||
|
NOTE: Alternatively, we could use an extremely low temperature to achieve
|
||||||
|
greedy sampling using multinomial computation and unite the codepaths. This
|
||||||
|
has implications on the overall design of the sampler, e.g. how to record
|
||||||
|
accurate logprobs for the user, so this improvement is deferred to later.
|
||||||
|
"""
|
||||||
|
logprobs[sample_indices, :] = -float('inf')
|
||||||
|
logprobs[sample_indices, greedy_samples] = 0.0
|
||||||
|
probs[sample_indices, :] = 0
|
||||||
|
probs[sample_indices, greedy_samples] = 1.0
|
||||||
|
|
||||||
|
|
||||||
def _build_sampler_output(
|
def _build_sampler_output(
|
||||||
sample_results: List[Tuple[List[int], List[int]]],
|
sample_results: List[Tuple[List[int], List[int]]],
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
prompt_logprobs: List[Optional[PromptLogprobs]],
|
prompt_logprobs: List[Optional[PromptLogprobs]],
|
||||||
sample_logprobs: List[SampleLogprobs],
|
sample_logprobs: List[SampleLogprobs],
|
||||||
|
on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
) -> SamplerOutput:
|
) -> SamplerOutput:
|
||||||
|
"""Construct Python objects with the output of sampling.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
on_device_tensors: Tuple containing on-device tensors with the
|
||||||
|
probabilities used in sampling and the sampled token ids. This
|
||||||
|
allows post-processing without copies to CPU/serialization, e.g. in
|
||||||
|
speculative decoding rejection sampling.
|
||||||
|
"""
|
||||||
|
|
||||||
sampler_output = []
|
sampler_output = []
|
||||||
for (seq_group, sample_result, group_prompt_logprobs,
|
for (seq_group, sample_result, group_prompt_logprobs,
|
||||||
group_sample_logprobs) in zip(sampling_metadata.seq_groups,
|
group_sample_logprobs) in zip(sampling_metadata.seq_groups,
|
||||||
@ -701,4 +842,15 @@ def _build_sampler_output(
|
|||||||
SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
|
SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
|
||||||
sampler_output.append(
|
sampler_output.append(
|
||||||
SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
|
SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
|
||||||
return SamplerOutput(outputs=sampler_output)
|
|
||||||
|
# If not specified, store None values in SamplerOutput.
|
||||||
|
if on_device_tensors is not None:
|
||||||
|
sampled_token_probs, sampled_token_ids = on_device_tensors
|
||||||
|
else:
|
||||||
|
sampled_token_probs, sampled_token_ids = (None, None)
|
||||||
|
|
||||||
|
return SamplerOutput(
|
||||||
|
outputs=sampler_output,
|
||||||
|
sampled_token_probs=sampled_token_probs,
|
||||||
|
sampled_token_ids=sampled_token_ids,
|
||||||
|
)
|
||||||
|
|||||||
@ -6,8 +6,8 @@ import torch
|
|||||||
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
|
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
|
||||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||||
SpeculativeScorer, SpeculativeScores)
|
SpeculativeScorer, SpeculativeScores)
|
||||||
from vllm.spec_decode.util import (get_all_seq_ids, maybe_mock_device_tensors,
|
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
|
||||||
nvtx_range, sampler_output_to_torch,
|
sampler_output_to_torch,
|
||||||
split_batch_by_proposal_len)
|
split_batch_by_proposal_len)
|
||||||
from vllm.worker.worker_base import WorkerBase
|
from vllm.worker.worker_base import WorkerBase
|
||||||
|
|
||||||
@ -72,10 +72,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
|||||||
proposal_lens_list = proposals.proposal_lens.tolist()
|
proposal_lens_list = proposals.proposal_lens.tolist()
|
||||||
proposal_token_ids_list = proposals.proposal_token_ids.tolist()
|
proposal_token_ids_list = proposals.proposal_token_ids.tolist()
|
||||||
|
|
||||||
|
# Filter the list to ignore -1 proposals.
|
||||||
|
proposal_token_ids_list_without_skips = [
|
||||||
|
proposals for proposals in proposal_token_ids_list
|
||||||
|
if -1 not in proposals
|
||||||
|
]
|
||||||
|
|
||||||
(spec_indices, non_spec_indices, target_seq_group_metadata_list,
|
(spec_indices, non_spec_indices, target_seq_group_metadata_list,
|
||||||
num_scoring_tokens) = self._expand_batch(
|
num_scoring_tokens) = self._expand_batch(
|
||||||
seq_group_metadata_list=seq_group_metadata_list,
|
seq_group_metadata_list=seq_group_metadata_list,
|
||||||
proposal_token_ids_list=proposal_token_ids_list,
|
proposal_token_ids_list=proposal_token_ids_list_without_skips,
|
||||||
proposal_lens_list=proposal_lens_list,
|
proposal_lens_list=proposal_lens_list,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -89,7 +95,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
|||||||
target_sampler_output = target_sampler_output[0]
|
target_sampler_output = target_sampler_output[0]
|
||||||
|
|
||||||
all_tokens, all_probs = self._contract_batch(
|
all_tokens, all_probs = self._contract_batch(
|
||||||
original_bs=len(seq_group_metadata_list),
|
contracted_bs=len(seq_group_metadata_list),
|
||||||
target_sampler_output=target_sampler_output,
|
target_sampler_output=target_sampler_output,
|
||||||
proposals=proposals,
|
proposals=proposals,
|
||||||
num_scoring_tokens=num_scoring_tokens,
|
num_scoring_tokens=num_scoring_tokens,
|
||||||
@ -128,14 +134,21 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
|||||||
select_proposal_len_zero=True)
|
select_proposal_len_zero=True)
|
||||||
|
|
||||||
target_seq_group_metadata_list = self._create_scoring_model_input(
|
target_seq_group_metadata_list = self._create_scoring_model_input(
|
||||||
spec_seqs, proposal_token_ids_list)
|
seq_group_metadata_list=spec_seqs,
|
||||||
|
proposal_token_ids=proposal_token_ids_list,
|
||||||
|
# NOTE: We determine the seq ids in the expanded batch using the
|
||||||
|
# full seq_group_metadata_list, instead of only spec_seqs.
|
||||||
|
target_seq_ids_iter=self._create_target_seq_id_iterator(
|
||||||
|
seq_ids=get_all_seq_ids(seq_group_metadata_list)),
|
||||||
|
)
|
||||||
|
|
||||||
num_scoring_tokens = len(target_seq_group_metadata_list)
|
num_scoring_tokens = len(target_seq_group_metadata_list)
|
||||||
target_seq_group_metadata_list.extend(non_spec_seqs)
|
target_seq_group_metadata_list.extend(non_spec_seqs)
|
||||||
|
|
||||||
return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
|
return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
|
||||||
num_scoring_tokens)
|
num_scoring_tokens)
|
||||||
|
|
||||||
def _contract_batch(self, original_bs: int,
|
def _contract_batch(self, contracted_bs: int,
|
||||||
target_sampler_output: List[SamplerOutput],
|
target_sampler_output: List[SamplerOutput],
|
||||||
proposals: SpeculativeProposals,
|
proposals: SpeculativeProposals,
|
||||||
num_scoring_tokens: int, non_spec_indices: List[int],
|
num_scoring_tokens: int, non_spec_indices: List[int],
|
||||||
@ -144,42 +157,41 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
|||||||
"""Contract the expanded batch back into its original size.
|
"""Contract the expanded batch back into its original size.
|
||||||
This maps the scores of speculative tokens back to their original
|
This maps the scores of speculative tokens back to their original
|
||||||
sequences.
|
sequences.
|
||||||
|
|
||||||
|
contracted_bs is the original batch size, and the batch size that the
|
||||||
|
target_sampler_output will be contracted to.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# We mock the device tensors until PR 7/9 is merged (e2e correctness).
|
|
||||||
# https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
|
|
||||||
maybe_mock_device_tensors(
|
|
||||||
sampler_output=target_sampler_output,
|
|
||||||
batch_size=len(non_spec_indices) + num_scoring_tokens,
|
|
||||||
vocab_size=self._vocab_size,
|
|
||||||
device=self._device,
|
|
||||||
)
|
|
||||||
|
|
||||||
(target_token_ids, target_probs, non_spec_target_token_ids,
|
(target_token_ids, target_probs, non_spec_target_token_ids,
|
||||||
non_spec_target_probs) = self._split_scoring_output(
|
non_spec_target_probs) = self._split_scoring_output(
|
||||||
target_sampler_output, num_scoring_tokens)
|
target_sampler_output, num_scoring_tokens)
|
||||||
|
|
||||||
# Map distinct sequences used to score each token
|
# Map distinct sequences used to score each token
|
||||||
# of shape [batch_size * k + 1] back to [batch_size, k + 1].
|
# of shape [batch_size * k + 1] back to [batch_size, k + 1].
|
||||||
batch_size, k = proposals.proposal_token_ids.shape
|
expanded_batch_size, k = proposals.proposal_token_ids.shape
|
||||||
|
|
||||||
|
# The number of tokens in the expanded batch used for speculation is
|
||||||
|
# equal to the total expanded batch size minus the number of samples for
|
||||||
|
# non-speculative sequences.
|
||||||
|
non_spec_expanded_bs, _ = non_spec_target_token_ids.shape
|
||||||
|
spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs
|
||||||
|
|
||||||
target_token_ids = target_token_ids.squeeze().reshape(
|
target_token_ids = target_token_ids.squeeze().reshape(
|
||||||
batch_size, k + 1)
|
spec_expanded_bs, k + 1)
|
||||||
target_probs = target_probs.squeeze().reshape(batch_size, k + 1,
|
target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1,
|
||||||
self._vocab_size)
|
self._vocab_size)
|
||||||
|
|
||||||
all_tokens = torch.full(size=(original_bs, k + 1),
|
all_tokens = torch.full(size=(contracted_bs, k + 1),
|
||||||
fill_value=-1,
|
fill_value=-1,
|
||||||
device=self._device,
|
device=self._device,
|
||||||
dtype=torch.long)
|
dtype=torch.long)
|
||||||
all_probs = torch.zeros(original_bs,
|
all_probs = torch.zeros(contracted_bs,
|
||||||
k + 1,
|
k + 1,
|
||||||
self._vocab_size,
|
self._vocab_size,
|
||||||
device=self._device,
|
device=self._device,
|
||||||
dtype=torch.float32)
|
dtype=torch.float32)
|
||||||
|
|
||||||
if non_spec_indices:
|
if non_spec_indices:
|
||||||
all_tokens[non_spec_indices, 0] = non_spec_target_token_ids
|
all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
|
||||||
all_probs[non_spec_indices, :1, :] = non_spec_target_probs
|
all_probs[non_spec_indices, :1, :] = non_spec_target_probs
|
||||||
|
|
||||||
if spec_indices:
|
if spec_indices:
|
||||||
@ -189,20 +201,22 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
|||||||
return all_tokens, all_probs
|
return all_tokens, all_probs
|
||||||
|
|
||||||
def _create_scoring_model_input(
|
def _create_scoring_model_input(
|
||||||
self,
|
self,
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k]
|
proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k]
|
||||||
|
target_seq_ids_iter: Iterator[TargetSeqId],
|
||||||
) -> List[SequenceGroupMetadata]:
|
) -> List[SequenceGroupMetadata]:
|
||||||
"""Given the original input sequences and proposed tokens from the draft
|
"""Given the original input sequences and proposed tokens from the draft
|
||||||
model, create a list of target sequences that can be used for scoring.
|
model, create a list of target sequences that can be used for scoring.
|
||||||
|
|
||||||
|
target_seq_ids_iter provides sequence ids for the expanded batch,
|
||||||
|
fulfilling the requirement that no seq id in the expanded batch is equal
|
||||||
|
to the seq id in the original batch.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not seq_group_metadata_list:
|
if not seq_group_metadata_list:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
target_seq_ids_iter = self._create_target_seq_id_iterator(
|
|
||||||
get_all_seq_ids(seq_group_metadata_list))
|
|
||||||
|
|
||||||
target_seq_group_metadata = list(
|
target_seq_group_metadata = list(
|
||||||
chain.from_iterable(
|
chain.from_iterable(
|
||||||
self._create_target_seq_group_metadata(
|
self._create_target_seq_group_metadata(
|
||||||
|
|||||||
@ -24,9 +24,9 @@ class SpeculativeProposals:
|
|||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return (f"SpeculativeProposals("
|
return (f"SpeculativeProposals("
|
||||||
f"proposal_token_ids={self.proposal_token_ids.shape}, "
|
f"proposal_token_ids={self.proposal_token_ids}, "
|
||||||
f"proposal_probs={self.proposal_probs.shape}, "
|
f"proposal_probs={self.proposal_probs.shape}, "
|
||||||
f"proposal_lens={self.proposal_lens.shape})")
|
f"proposal_lens={self.proposal_lens})")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@ -147,15 +147,16 @@ class AsyncMetricsCollector:
|
|||||||
emitted_tokens = self._aggregate_num_emitted_tokens.item()
|
emitted_tokens = self._aggregate_num_emitted_tokens.item()
|
||||||
draft_tokens = self._aggregate_num_draft_tokens
|
draft_tokens = self._aggregate_num_draft_tokens
|
||||||
|
|
||||||
num_possible_tokens = self.get_max_num_accepted_tokens(draft_tokens, k)
|
max_num_emitted_tokens = self.get_max_num_emitted_tokens(
|
||||||
|
draft_tokens, k)
|
||||||
|
|
||||||
if draft_tokens > 0:
|
if draft_tokens > 0:
|
||||||
draft_acceptance_rate = accepted_tokens / draft_tokens
|
draft_acceptance_rate = accepted_tokens / draft_tokens
|
||||||
else:
|
else:
|
||||||
draft_acceptance_rate = float("nan")
|
draft_acceptance_rate = float("nan")
|
||||||
|
|
||||||
if num_possible_tokens > 0:
|
if max_num_emitted_tokens > 0:
|
||||||
system_efficiency = emitted_tokens / num_possible_tokens
|
system_efficiency = emitted_tokens / max_num_emitted_tokens
|
||||||
else:
|
else:
|
||||||
system_efficiency = float("nan")
|
system_efficiency = float("nan")
|
||||||
|
|
||||||
@ -169,8 +170,22 @@ class AsyncMetricsCollector:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_max_num_accepted_tokens(draft_tokens: int, k: int) -> int:
|
def get_max_num_emitted_tokens(draft_tokens: int, k: int) -> int:
|
||||||
# Divide by k since batch size can be variable.
|
"""Calculate the number of emitted tokens, assuming all tokens are
|
||||||
total_num_spec_seqs = draft_tokens / k
|
accepted.
|
||||||
num_accepted_per_seq_if_all_accepted = k + 1
|
|
||||||
return int(total_num_spec_seqs / num_accepted_per_seq_if_all_accepted)
|
This is equal to the number of sequences that have been speculated on,
|
||||||
|
times (speculation len + 1). The +1 comes from the bonus token.
|
||||||
|
"""
|
||||||
|
# Determine the number of sequences that have been speculated on. Since
|
||||||
|
# the batch size can be variable, we divide by k.
|
||||||
|
assert draft_tokens % k == 0
|
||||||
|
total_num_spec_seqs = draft_tokens // k
|
||||||
|
|
||||||
|
# A single sequence may emit k accepted tokens and one bonus token in
|
||||||
|
# the best case.
|
||||||
|
num_emitted_per_seq_if_all_accepted = k + 1
|
||||||
|
|
||||||
|
# The max num of emitted tokens is the number of speculated sequences
|
||||||
|
# times the max emitted per seq.
|
||||||
|
return total_num_spec_seqs * num_emitted_per_seq_if_all_accepted
|
||||||
|
|||||||
@ -6,8 +6,7 @@ import torch
|
|||||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||||
SpeculativeProposer)
|
SpeculativeProposer)
|
||||||
from vllm.spec_decode.util import (maybe_mock_device_tensors,
|
from vllm.spec_decode.util import sampler_output_to_torch
|
||||||
sampler_output_to_torch)
|
|
||||||
from vllm.worker.worker import Worker
|
from vllm.worker.worker import Worker
|
||||||
|
|
||||||
|
|
||||||
@ -329,12 +328,15 @@ class DraftModelTop1Proposer(SpeculativeProposer):
|
|||||||
"""
|
"""
|
||||||
if maybe_sampler_output is None:
|
if maybe_sampler_output is None:
|
||||||
# If no speculative tokens, the sampler output will be None.
|
# If no speculative tokens, the sampler output will be None.
|
||||||
# In this case we return empty tensors.
|
# In this case we return empty proposals.
|
||||||
proposal_tokens = torch.zeros(0,
|
proposal_tokens = torch.full(size=(
|
||||||
max_proposal_len,
|
batch_size,
|
||||||
dtype=torch.long,
|
max_proposal_len,
|
||||||
device=self._device)
|
),
|
||||||
proposal_probs = torch.zeros(0,
|
fill_value=-1,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=self._device)
|
||||||
|
proposal_probs = torch.zeros(batch_size,
|
||||||
max_proposal_len,
|
max_proposal_len,
|
||||||
self._vocab_size,
|
self._vocab_size,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
@ -345,17 +347,6 @@ class DraftModelTop1Proposer(SpeculativeProposer):
|
|||||||
return proposal_tokens, proposal_probs, proposal_lens_tensor
|
return proposal_tokens, proposal_probs, proposal_lens_tensor
|
||||||
|
|
||||||
sampler_output = maybe_sampler_output
|
sampler_output = maybe_sampler_output
|
||||||
|
|
||||||
# We mock the device tensors until PR 7/9 is merged (e2e correctness).
|
|
||||||
# https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
|
|
||||||
for step_output in sampler_output:
|
|
||||||
maybe_mock_device_tensors(
|
|
||||||
sampler_output=step_output,
|
|
||||||
batch_size=len(proposal_lens),
|
|
||||||
vocab_size=self._vocab_size,
|
|
||||||
device=self._device,
|
|
||||||
)
|
|
||||||
|
|
||||||
proposal_tokens, proposal_probs = sampler_output_to_torch(
|
proposal_tokens, proposal_probs = sampler_output_to_torch(
|
||||||
sampler_output)
|
sampler_output)
|
||||||
|
|
||||||
|
|||||||
@ -111,6 +111,32 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||||||
device=self.device,
|
device=self.device,
|
||||||
vocab_size=self._vocab_size)
|
vocab_size=self._vocab_size)
|
||||||
|
|
||||||
|
self._configure_model_sampler_for_spec_decode()
|
||||||
|
|
||||||
|
def _configure_model_sampler_for_spec_decode(self):
|
||||||
|
"""Configure model sampler to emit GPU tensors. This allows spec decode
|
||||||
|
to keep data on device without transferring to CPU and serializing,
|
||||||
|
which significantly reduces overhead of rejection sampling.
|
||||||
|
|
||||||
|
NOTE(cade): This breaks abstraction boundaries pretty badly. The better
|
||||||
|
design is to have the "move to CPU and serialize" sampling decision be
|
||||||
|
done outside of the model/sampler; this way the "last-mile" worker
|
||||||
|
object which interfaces with the scheduler can serialize and incur the
|
||||||
|
performance hit as necessary. This allows us to run the worker several
|
||||||
|
iterations in a row without incurring the "move to CPU and serialize"
|
||||||
|
performance penalty.
|
||||||
|
|
||||||
|
Since this requires a large change to vLLM, we defer it to later and
|
||||||
|
temporarily accept this broken abstraction boundary.
|
||||||
|
|
||||||
|
NOTE(cade): This will require a special check if the proposer worker
|
||||||
|
does not have a sampler (e.g. ngram speculation).
|
||||||
|
"""
|
||||||
|
(self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
|
||||||
|
) = True
|
||||||
|
(self.proposer_worker.model_runner.model.sampler.
|
||||||
|
include_gpu_probs_tensor) = True
|
||||||
|
|
||||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||||
"""Determine the number of cache blocks to use.
|
"""Determine the number of cache blocks to use.
|
||||||
|
|
||||||
@ -286,15 +312,26 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||||||
select_proposal_len_zero=True)
|
select_proposal_len_zero=True)
|
||||||
original_indices = spec_indices + non_spec_indices
|
original_indices = spec_indices + non_spec_indices
|
||||||
|
|
||||||
proposal_probs = proposal_scores.probs[spec_indices, :-1]
|
# Get probabilities of target model, excluding bonus token.
|
||||||
bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
|
proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1]
|
||||||
|
|
||||||
|
# Get non-speculative sampled tokens from target model.
|
||||||
non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]
|
non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]
|
||||||
|
|
||||||
|
# Get bonus tokens from target model.
|
||||||
|
bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
|
||||||
|
|
||||||
|
# Get probabilities according to proposal method.
|
||||||
|
proposal_probs = proposals.proposal_probs[spec_indices]
|
||||||
|
|
||||||
|
# Get proposed tokens.
|
||||||
|
proposal_token_ids = proposals.proposal_token_ids[spec_indices]
|
||||||
|
|
||||||
accepted_token_ids = self.rejection_sampler(
|
accepted_token_ids = self.rejection_sampler(
|
||||||
proposal_probs,
|
target_probs=proposal_verifier_probs,
|
||||||
bonus_token_ids,
|
bonus_token_ids=bonus_token_ids,
|
||||||
proposals.proposal_probs,
|
draft_probs=proposal_probs,
|
||||||
proposals.proposal_token_ids,
|
draft_token_ids=proposal_token_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Append output tokens from non-speculative sequences to
|
# Append output tokens from non-speculative sequences to
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user