[Spec Decode] Integrate Suffix Decoding from Arctic Inference (#25784)
Co-authored-by: Aurick Qiao <aurick.qiao@snowflake.com>
This commit is contained in:
parent 4bc400f47e
commit 2c19d96777
@ -130,6 +130,46 @@ matching n-grams in the prompt. For more information read [this thread.](https:/
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    ```

## Speculating using Suffix Decoding

The following code configures vLLM to use speculative decoding where proposals are generated using Suffix Decoding ([technical report](https://arxiv.org/abs/2411.04975)).

Like n-gram, Suffix Decoding generates draft tokens by pattern-matching on the last `n` generated tokens. Unlike n-gram, Suffix Decoding (1) can pattern-match against both the prompt and previous generations, (2) uses frequency counts to propose the most likely continuations, and (3) speculates an adaptive number of tokens for each request at each iteration, yielding better acceptance rates.

Suffix Decoding can achieve better performance on tasks with high repetition, such as code editing, agentic loops (e.g. self-reflection, self-consistency), and RL rollouts.

!!! tip "Install Arctic Inference"

    Suffix Decoding requires [Arctic Inference](https://github.com/snowflakedb/ArcticInference). You can install it with `pip install arctic-inference`.

!!! tip "Suffix Decoding Speculative Tokens"

    Suffix Decoding speculates a dynamic number of tokens for each request at each decoding step, so the `num_speculative_tokens` configuration specifies the *maximum* number of speculative tokens. It is recommended to use a high value such as `16` or `32` (the default).

??? code

    ```python
    from vllm import LLM, SamplingParams

    prompts = [
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    llm = LLM(
        model="facebook/opt-6.7b",
        tensor_parallel_size=1,
        speculative_config={
            "method": "suffix",
            "num_speculative_tokens": 32,
        },
    )
    outputs = llm.generate(prompts, sampling_params)

    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    ```
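
Beyond `num_speculative_tokens`, the `speculative_config` dict also accepts the suffix-decoding-specific fields that this commit adds to `SpeculativeConfig` (see the config diff further down). A minimal sketch (not part of this diff) with illustrative values; the defaults noted in the comments come from that config:

```python
from vllm import LLM

# Illustrative tuning of the suffix decoding knobs added in this commit.
llm = LLM(
    model="facebook/opt-6.7b",
    speculative_config={
        "method": "suffix",
        "num_speculative_tokens": 32,
        "suffix_decoding_max_tree_depth": 24,  # default 24; bounds prefix match + speculation length
        "suffix_decoding_max_cached_requests": 10000,  # default 10000; 0 disables the global tree
        "suffix_decoding_max_spec_factor": 2.0,  # default 1.0; spec length <= factor * prefix match
        "suffix_decoding_min_token_prob": 0.1,  # default 0.1; frequency-based probability cutoff
    },
)
```
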
## Speculating using MLP speculators

The following code configures vLLM to use speculative decoding where proposals are generated by
@ -48,6 +48,7 @@ buildkite-test-collector==0.1.9
genai_perf==0.0.8
tritonclient==2.51.0

arctic-inference == 0.1.0 # Required for suffix decoding test
numba == 0.61.2 # Required for N-gram speculative decoding
numpy
runai-model-streamer[s3,gcs]==0.15.0
@ -40,6 +40,8 @@ anyio==4.6.2.post1
    # via
    #   httpx
    #   starlette
arctic-inference==0.1.0
    # via -r requirements/test.in
argcomplete==3.5.1
    # via datamodel-code-generator
arrow==1.3.0
@ -75,7 +75,23 @@ def model_name():
    return "meta-llama/Llama-3.1-8B-Instruct"


def test_ngram_correctness(
@pytest.mark.parametrize(
    "speculative_config",
    [
        {
            "method": "ngram",
            "prompt_lookup_max": 5,
            "prompt_lookup_min": 3,
            "num_speculative_tokens": 3,
        },
        {
            "method": "suffix",
            "suffix_decoding_max_spec_factor": 2.0,
        },
    ],
)
def test_ngram_and_suffix_correctness(
    speculative_config: dict,
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_name: str,
@ -94,12 +110,7 @@ def test_ngram_correctness(

    spec_llm = LLM(
        model=model_name,
        speculative_config={
            "method": "ngram",
            "prompt_lookup_max": 5,
            "prompt_lookup_min": 3,
            "num_speculative_tokens": 3,
        },
        speculative_config=speculative_config,
        max_model_len=1024,
    )
    spec_outputs = spec_llm.chat(test_prompts, sampling_config)
@ -121,6 +132,66 @@ def test_ngram_correctness(
    cleanup_dist_env_and_memory()


def test_suffix_decoding_acceptance(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_name: str,
):
    """
    Check that suffix decoding caching takes effect and improves acceptance
    lengths and acceptance rates over multiple runs of the same prompts.
    """
    test_prompts = get_test_prompts(mm_enabled=False)

    spec_llm = LLM(
        model=model_name,
        speculative_config={
            "method": "suffix",
            "suffix_decoding_max_spec_factor": 2.0,
            "suffix_decoding_max_cached_requests": 1000,
        },
        max_model_len=1024,
        disable_log_stats=False,
    )

    # Run several times and check that the accepted tokens increase.
    spec_llm.chat(test_prompts, sampling_config)
    num_draft = []
    num_accept = []
    for i in range(10):  # Run multiple times to warm up the cache.
        spec_llm.chat(test_prompts, sampling_config)
        # Collect draft and acceptance stats.
        metrics = spec_llm.get_metrics()
        for metric in metrics:
            if metric.name == "vllm:spec_decode_num_draft_tokens":
                num_draft.append(metric.value)
            if metric.name == "vllm:spec_decode_num_accepted_tokens":
                num_accept.append(metric.value)

    # Calculate the acceptance rates for the first and last runs.
    first_accept_tokens = num_accept[0]
    first_draft_tokens = num_draft[0]
    first_accept_rate = first_accept_tokens / first_draft_tokens

    # Take the diff since the stats are cumulative.
    last_accept_tokens = num_accept[-1] - num_accept[-2]
    last_draft_tokens = num_draft[-1] - num_draft[-2]
    last_accept_rate = last_accept_tokens / last_draft_tokens

    # Expect the number of accepted tokens to improve.
    assert first_accept_tokens < last_accept_tokens

    # Expect the acceptance rate to improve.
    assert first_accept_rate < last_accept_rate

    # Heuristic: expect at least 85% acceptance rate at the end.
    assert last_accept_rate > 0.85

    del spec_llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()


@pytest.mark.parametrize(
    "model_path",
    [
@ -12,7 +12,7 @@ from typing_extensions import Self
from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.import_utils import LazyLoader
from vllm.utils.import_utils import LazyLoader, has_arctic_inference

if TYPE_CHECKING:
    from transformers import PretrainedConfig
@ -42,6 +42,7 @@ SpeculativeMethod = Literal[
    "mimo_mtp",
    "longcat_flash_mtp",
    "mtp",
    "suffix",
]
MTP_MODEL_TYPES = (
    "deepseek_mtp",
@ -129,6 +130,27 @@ class SpeculativeConfig:
    draft_parallel_config: SkipValidation[ParallelConfig] = None  # type: ignore
    """The parallel configuration for the draft model initialized internal."""

    # Suffix decoding configuration
    suffix_decoding_max_tree_depth: int = 24
    """The maximum depth of the suffix decoding global and prompt trees. The
    tree depth limits the sum of the prefix match and speculation lengths."""

    suffix_decoding_max_cached_requests: int = 10000
    """The maximum number of requests to cache in the global suffix tree. If
    exceeded, will trigger eviction in FIFO order. If set to 0, the global
    suffix tree is disabled and past responses are not cached (prompt trees
    are still used)."""

    suffix_decoding_max_spec_factor: float = 1.0
    """The maximum spec factor for suffix decoding. The spec factor controls
    speculation lengths based on the prefix match length: max_spec_tokens =
    max_spec_factor * prefix_match_length."""

    suffix_decoding_min_token_prob: float = 0.1
    """The minimum token probability for suffix decoding. Will only speculate
    tokens with estimated probability (based on frequency counts) greater than
    or equal to this value."""

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
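
For intuition, here is a minimal sketch (not part of this diff) of how these knobs jointly bound the number of proposed tokens, combining the `max_spec_factor` rule from the docstring above with the per-request cap that `SuffixDecodingProposer.propose()` passes to `speculate()` in the new file below. The helper name and the `prefix_match_len` argument are illustrative, not vLLM API:

```python
def max_suffix_spec_tokens(
    prefix_match_len: int,
    num_tokens: int,
    max_model_len: int,
    num_speculative_tokens: int = 24,
    max_spec_factor: float = 1.0,
) -> int:
    """Illustrative upper bound on draft tokens for one request (not vLLM API)."""
    # Hard cap: the per-step limit, never running past the context window.
    hard_cap = min(num_speculative_tokens, max_model_len - num_tokens - 1)
    # Adaptive cap: longer prefix matches allow longer speculations
    # (max_spec_tokens = max_spec_factor * prefix_match_length).
    adaptive_cap = int(max_spec_factor * prefix_match_len)
    return max(0, min(hard_cap, adaptive_cap))


# e.g. a 4-token prefix match with max_spec_factor=2.0 allows at most 8 draft tokens.
print(max_suffix_spec_tokens(prefix_match_len=4, num_tokens=100,
                             max_model_len=1024, max_spec_factor=2.0))
```
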
@ -235,6 +257,8 @@ class SpeculativeConfig:
                self.quantization = self.target_model_config.quantization
            elif self.method in ("ngram", "[ngram]"):
                self.model = "ngram"
            elif self.method == "suffix":
                self.model = "suffix"
            else:
                raise ValueError(
                    "num_speculative_tokens was provided but without speculative model."
@ -282,6 +306,8 @@ class SpeculativeConfig:
            # draft related config as None here.
            self.draft_model_config = self.target_model_config
            self.draft_parallel_config = self.target_parallel_config
        elif self.method == "suffix":
            self._validate_suffix_decoding()
        else:
            self.prompt_lookup_max = 0
            self.prompt_lookup_min = 0
@ -430,6 +456,42 @@ class SpeculativeConfig:
            )
        return self

    def _validate_suffix_decoding(self):
        if not has_arctic_inference():
            raise ImportError(
                "Arctic Inference is required for suffix decoding. "
                "Install via `pip install arctic-inference==0.1.0`."
            )
        if self.num_speculative_tokens is None:
            # Suffix decoding decides the actual number of speculative tokens
            # dynamically and treats num_speculative_tokens as a maximum limit.
            self.num_speculative_tokens = self.suffix_decoding_max_tree_depth
            logger.warning(
                "Defaulted num_speculative_tokens to %s for suffix decoding.",
                self.num_speculative_tokens,
            )
        # Validate values
        if self.suffix_decoding_max_tree_depth < 1:
            raise ValueError(
                f"suffix_decoding_max_tree_depth="
                f"{self.suffix_decoding_max_tree_depth} must be >= 1"
            )
        if self.suffix_decoding_max_cached_requests < 0:
            raise ValueError(
                f"suffix_decoding_max_cached_requests="
                f"{self.suffix_decoding_max_cached_requests} must be >= 0"
            )
        if self.suffix_decoding_max_spec_factor < 0:
            raise ValueError(
                f"suffix_decoding_max_spec_factor="
                f"{self.suffix_decoding_max_spec_factor} must be >= 0"
            )
        if not 0 <= self.suffix_decoding_min_token_prob <= 1:
            raise ValueError(
                f"suffix_decoding_min_token_prob="
                f"{self.suffix_decoding_min_token_prob} must be in [0, 1]"
            )

    @staticmethod
    def _maybe_override_draft_max_model_len(
        speculative_max_model_len: int | None,
@ -582,6 +644,6 @@ class SpeculativeConfig:

    def __repr__(self) -> str:
        method = self.method
        model = None if method == "ngram" else self.draft_model_config.model
        model = None if method in ("ngram", "suffix") else self.draft_model_config.model
        num_spec_tokens = self.num_speculative_tokens
        return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"
@ -403,3 +403,9 @@ def has_triton_kernels() -> bool:
def has_tilelang() -> bool:
    """Whether the optional `tilelang` package is available."""
    return _has_module("tilelang")


def has_arctic_inference() -> bool:
    """Whether the optional `arctic_inference` package is available."""

    return _has_module("arctic_inference")
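
For reference, `_has_module` is a pre-existing helper in `vllm.utils.import_utils`; a minimal sketch (not part of this diff) of the kind of check such a helper performs, assuming an `importlib.util.find_spec`-based probe — the actual implementation may differ:

```python
import importlib.util


def _has_module_sketch(name: str) -> bool:
    # Report whether an optional package is importable without importing it.
    return importlib.util.find_spec(name) is not None


# Probe for suffix decoding support, mirroring has_arctic_inference() above.
print(_has_module_sketch("arctic_inference"))
```
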
vllm/v1/spec_decode/suffix_decoding.py (new file, 101 lines)
@ -0,0 +1,101 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.config import VllmConfig
from vllm.v1.worker.gpu_input_batch import InputBatch


class SuffixDecodingProposer:
    """
    Speculative decoding proposer for Suffix Decoding (https://arxiv.org/pdf/2411.04975).
    This class imports and uses the official implementation from Arctic Inference
    (https://github.com/snowflakedb/ArcticInference).
    """

    def __init__(self, vllm_config: VllmConfig):
        config = vllm_config.speculative_config
        self.num_speculative_tokens = config.num_speculative_tokens
        self.max_tree_depth = config.suffix_decoding_max_tree_depth
        self.max_spec_factor = config.suffix_decoding_max_spec_factor
        self.min_token_prob = config.suffix_decoding_min_token_prob
        self.max_model_len = vllm_config.model_config.max_model_len

        # Lazy import to avoid error when Suffix Decoding is not used.
        from arctic_inference.suffix_decoding import SuffixDecodingCache

        # Initialize an empty cache. This object takes care of caching request
        # outputs, evicting old requests, and managing the per-prompt suffix trees.
        self.suffix_cache = SuffixDecodingCache(
            max_tree_depth=config.suffix_decoding_max_tree_depth,
            max_cached_requests=config.suffix_decoding_max_cached_requests,
        )

    def propose(
        self,
        input_batch: InputBatch,
        sampled_token_ids: list[list[int]],
    ) -> list[list[int]]:
        """
        Propose speculative tokens for each request in the input batch. Suffix Decoding
        speculates a dynamic number of tokens for each request at every decoding step,
        so the entries in the returned list may have different lengths.
        """
        draft_token_ids: list[list[int]] = []
        for i, sampled_ids in enumerate(sampled_token_ids):
            if not sampled_ids:
                # Skip speculative decoding for partial prefills.
                draft_token_ids.append([])
                continue

            # Skip requests that require sampling parameters that are not
            # supported with speculative decoding.
            req_id = input_batch.req_ids[i]
            if req_id in input_batch.spec_decode_unsupported_reqs:
                draft_token_ids.append([])
                continue

            num_tokens = input_batch.num_tokens_no_spec[i]
            if num_tokens >= self.max_model_len:
                # Skip requests that have already reached the max model length.
                draft_token_ids.append([])
                continue

            index = input_batch.req_id_to_index[req_id]
            if req_id not in self.suffix_cache.active_requests:
                if req_id in self.suffix_cache.cached_requests:
                    # Reset the suffix cache for this request.
                    self.suffix_cache.evict_cached_response(req_id)
                num_prompt_tokens = input_batch.num_prompt_tokens[index]
                prompt_token_ids = input_batch.token_ids_cpu[index, :num_prompt_tokens]
                # Start a new request; this builds the suffix tree for that prompt.
                self.suffix_cache.start_request(req_id, prompt_token_ids)

            # Append the newly sampled ids to the suffix cache for this request.
            self.suffix_cache.add_active_response(req_id, sampled_ids)

            # Suffix decoding only uses the most recent tokens up to max_tree_depth, so
            # we extract the pattern from the end of the input.
            start = max(0, num_tokens - self.max_tree_depth)
            pattern = input_batch.token_ids_cpu[i, start:num_tokens]
            draft = self.suffix_cache.speculate(
                req_id,
                pattern,
                max_spec_tokens=min(
                    self.num_speculative_tokens, self.max_model_len - num_tokens - 1
                ),
                max_spec_factor=self.max_spec_factor,
                min_token_prob=self.min_token_prob,
            )

            draft_token_ids.append(draft.token_ids)

        # Stop requests that were not seen in the input batch.
        for req_id in (
            self.suffix_cache.active_requests - input_batch.req_id_to_index.keys()
        ):
            self.suffix_cache.stop_request(req_id)

        return draft_token_ids

    def load_model(self, *args, **kwargs):
        # No model to load.
        pass
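
To make the Arctic Inference calls above easier to follow, here is a hypothetical standalone walk-through (not part of this diff) of the `SuffixDecodingCache` lifecycle, using only the methods and keyword arguments that appear in `SuffixDecodingProposer`. The toy token ids are made up, and passing plain Python lists (rather than the numpy slices used above) is an assumption:

```python
# Hypothetical walk-through; requires `pip install arctic-inference`.
from arctic_inference.suffix_decoding import SuffixDecodingCache

cache = SuffixDecodingCache(max_tree_depth=24, max_cached_requests=10000)

req_id = "req-0"
prompt_token_ids = [1, 5, 7, 5, 7, 5]  # toy prompt with repetition (assumed list input)
cache.start_request(req_id, prompt_token_ids)  # builds the per-prompt suffix tree

sampled_ids = [7, 5, 7]  # tokens sampled at the current decoding step
cache.add_active_response(req_id, sampled_ids)  # extend this request's cached response

# Pattern-match on the most recent tokens (up to max_tree_depth), as in propose().
pattern = (prompt_token_ids + sampled_ids)[-24:]
draft = cache.speculate(
    req_id,
    pattern,
    max_spec_tokens=32,
    max_spec_factor=2.0,
    min_token_prob=0.1,
)
print(draft.token_ids)  # proposed continuation; may be empty

cache.stop_request(req_id)  # request finished; response may be kept in the global tree
```
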
@ -125,6 +125,7 @@ from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
@ -336,16 +337,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        # the last PP rank. This is not ideal if there are many
        # layers in the draft model.
        if self.speculative_config and get_pp_group().is_last_rank:
            self.drafter: (
                NgramProposer | SuffixDecodingProposer | EagleProposer | MedusaProposer
            )
            if self.speculative_config.method == "ngram":
                self.drafter = NgramProposer(self.vllm_config)
            elif self.speculative_config.method == "suffix":
                self.drafter = SuffixDecodingProposer(self.vllm_config)
            elif self.speculative_config.use_eagle():
                self.drafter = EagleProposer(self.vllm_config, self.device, self)  # type: ignore
                self.drafter = EagleProposer(self.vllm_config, self.device, self)
                if self.speculative_config.method == "eagle3":
                    self.use_aux_hidden_state_outputs = True
            elif self.speculative_config.method == "medusa":
                self.drafter = MedusaProposer(
                    vllm_config=self.vllm_config, device=self.device
                )  # type: ignore
                )
            else:
                raise ValueError(
                    "Unknown speculative decoding method: "
@ -2783,6 +2789,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                self.input_batch.token_ids_cpu,
                self.input_batch.spec_decode_unsupported_reqs,
            )
        elif self.speculative_config.method == "suffix":
            assert isinstance(sampled_token_ids, list)
            assert isinstance(self.drafter, SuffixDecodingProposer)
            draft_token_ids = self.drafter.propose(self.input_batch, sampled_token_ids)
        elif self.speculative_config.method == "medusa":
            assert isinstance(sampled_token_ids, list)
            assert isinstance(self.drafter, MedusaProposer)