[V1][Spec Decode] Handle draft tokens beyond max_model_len (#16087)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon authored 2025-04-21 12:38:50 -07:00, committed by GitHub
parent 299ebb62b2
commit 3a0fba5cf4
7 changed files with 137 additions and 15 deletions

View File

@@ -30,6 +30,7 @@ def create_scheduler(
use_kv_connector: bool = False,
num_blocks: int = 10000,
block_size: int = 16,
max_model_len: Optional[int] = None,
) -> Scheduler:
'''Create scheduler under test.
@@ -44,12 +45,15 @@ def create_scheduler(
Returns:
:class:`Scheduler` instance
'''
if max_model_len is None:
max_model_len = max_num_batched_tokens
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
- max_model_len=max_num_batched_tokens,
+ max_model_len=max_model_len,
long_prefill_token_threshold=long_prefill_token_threshold,
disable_chunked_mm_input=disable_chunked_mm_input,
enable_chunked_prefill=True,
)
model_config = ModelConfig(
model=model,
@@ -296,6 +300,7 @@ def test_no_mm_input_chunking():
model="llava-hf/llava-1.5-7b-hf",
max_num_batched_tokens=1024,
disable_chunked_mm_input=True,
max_model_len=2048,
)
mm_positions = [[PlaceholderRange(offset=400, length=800)]]
requests = create_requests(num_requests=1,

View File

@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
"""Test whether spec decoding handles the max model length properly."""
import pytest
from vllm import LLM, SamplingParams
_PROMPTS = [
"1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1",
"Repeat the following sentence 10 times: Consistency is key to mastering any skill.", # noqa: E501
"Who won the Turing Award in 2018, and for what contribution? Describe in detail.", # noqa: E501
]
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
def test_ngram_max_len(
monkeypatch: pytest.MonkeyPatch,
num_speculative_tokens: int,
):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
llm = LLM(
model="facebook/opt-125m",
max_model_len=100,
enforce_eager=True, # For faster initialization.
speculative_config={
"method": "ngram",
"prompt_lookup_max": 5,
"prompt_lookup_min": 3,
"num_speculative_tokens": num_speculative_tokens,
},
)
sampling_params = SamplingParams(max_tokens=100, ignore_eos=True)
llm.generate(_PROMPTS, sampling_params)
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
def test_eagle_max_len(
monkeypatch: pytest.MonkeyPatch,
num_speculative_tokens: int,
):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
llm = LLM(
model="meta-llama/Meta-Llama-3-8B-Instruct",
enforce_eager=True, # For faster initialization.
speculative_config={
"method": "eagle",
"model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
"num_speculative_tokens": num_speculative_tokens,
},
max_model_len=100,
)
sampling_params = SamplingParams(max_tokens=100, ignore_eos=True)
llm.generate(_PROMPTS, sampling_params)

View File

@@ -2,7 +2,7 @@
import numpy as np
- from vllm.config import SpeculativeConfig, VllmConfig
+ from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
from vllm.v1.spec_decode.ngram_proposer import (NgramProposer,
_find_subarray_kmp,
_kmp_lps_array)
@@ -42,14 +42,24 @@ def test_find_subarray_kmp():
def test_ngram_proposer():
def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
- return NgramProposer(vllm_config=VllmConfig(
- speculative_config=SpeculativeConfig.from_dict(
- {
- "prompt_lookup_min": min_n,
- "prompt_lookup_max": max_n,
- "num_speculative_tokens": k,
- "method": "ngram",
- })))
+ # Dummy model config. Just to set max_model_len.
+ model_config = ModelConfig(model="facebook/opt-125m",
+ task="generate",
+ max_model_len=100,
+ tokenizer="facebook/opt-125m",
+ tokenizer_mode="auto",
+ dtype="auto",
+ seed=None,
+ trust_remote_code=False)
+ return NgramProposer(
+ vllm_config=VllmConfig(model_config=model_config,
+ speculative_config=SpeculativeConfig.
+ from_dict({
+ "prompt_lookup_min": min_n,
+ "prompt_lookup_max": max_n,
+ "num_speculative_tokens": k,
+ "method": "ngram",
+ })))
# No match.
result = ngram_proposer(

View File

@@ -185,6 +185,13 @@ class Scheduler(SchedulerInterface):
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
# Make sure the input position does not exceed the max model len.
# This is necessary when using spec decoding.
num_new_tokens = min(
num_new_tokens,
self.max_model_len - request.num_computed_tokens)
assert num_new_tokens > 0
# Schedule encoder inputs.
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
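
A minimal sketch of what this clamp prevents, with illustrative numbers rather than the real Scheduler fields: speculative decoding can hand the scheduler more tokens to verify than the remaining model length allows, so the new tokens are capped at max_model_len - num_computed_tokens.

# Illustrative numbers only; a request two tokens short of the limit.
max_model_len = 100
num_computed_tokens = 98       # tokens already in the KV cache
num_new_tokens = 1 + 3         # 1 sampled token plus 3 draft tokens to verify

# Without the clamp, positions 98..101 would be scheduled, but 100 and 101
# lie beyond max_model_len.
num_new_tokens = min(num_new_tokens, max_model_len - num_computed_tokens)
assert num_new_tokens > 0      # here: 2, i.e. positions 98 and 99 only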

View File

@@ -12,6 +12,8 @@ from vllm.model_executor.models.llama_eagle import EagleLlamaForCausalLM
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata
PADDING_SLOT_ID = -1
class EagleProposer:
@@ -23,6 +25,7 @@ class EagleProposer:
self.vllm_config = vllm_config
self.num_speculative_tokens = (
vllm_config.speculative_config.num_speculative_tokens)
self.max_model_len = vllm_config.model_config.max_model_len
self.block_size = vllm_config.cache_config.block_size
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
@@ -112,22 +115,48 @@ class EagleProposer:
# Update the inputs.
input_ids = draft_token_ids_list[-1]
positions += 1
# NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length. Since it is complex
# to remove such requests from the batch, we keep them in the batch
# but adjust the position ids and slot mappings to avoid the
# out-of-range access during the model execution. The draft tokens
# generated with this adjustment should be ignored.
exceeds_max_model_len = positions >= self.max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_positions = torch.where(exceeds_max_model_len, 0,
positions)
# Increment the sequence lengths.
attn_metadata.max_seq_len += 1
attn_metadata.seq_lens += 1
# Consider max model length.
attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
self.max_model_len)
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
# Compute the slot mapping.
- block_numbers = positions // self.block_size
+ block_numbers = clamped_positions // self.block_size
block_ids = block_table.gather(dim=1,
index=block_numbers.view(-1, 1))
block_ids = block_ids.view(-1)
attn_metadata.slot_mapping = (block_ids * self.block_size +
- positions % self.block_size)
+ clamped_positions % self.block_size)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
PADDING_SLOT_ID)
# Run the model.
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model(
input_ids=input_ids,
hidden_states=hidden_states,
- positions=positions,
+ positions=clamped_positions,
)
logits = self.model.compute_logits(hidden_states, None)
draft_token_ids, probs = compute_probs_and_sample_next_token(
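
The heart of the EAGLE change is the masking trick described in the NOTE above: requests that run past max_model_len stay in the batch, but their positions are clamped to 0 so RoPE indexing stays in range, their sequence lengths are shrunk to 1, and their slot mappings are redirected to a padding slot so no real KV-cache entry is overwritten. A self-contained sketch of that arithmetic; the tensor shapes and values below are made up for illustration and are not taken from the commit.

import torch

PADDING_SLOT_ID = -1                     # writes to this slot are discarded
max_model_len, block_size = 8, 4

positions = torch.tensor([3, 5, 8])      # last request is already out of range
seq_lens = positions + 1
block_table = torch.tensor([[0, 1], [2, 3], [4, 5]])   # one row per request

exceeds_max_model_len = positions >= max_model_len
# Clamp out-of-range positions to 0 so the position ids stay valid.
clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
# Shrink the attention window of the out-of-range request to one token.
seq_lens = seq_lens.masked_fill(exceeds_max_model_len, 1)
# Slot mapping computed from the clamped positions, then masked so the
# padded request cannot overwrite real KV-cache entries.
block_numbers = clamped_positions // block_size
block_ids = block_table.gather(1, block_numbers.view(-1, 1)).view(-1)
slot_mapping = block_ids * block_size + clamped_positions % block_size
slot_mapping = slot_mapping.masked_fill(exceeds_max_model_len, PADDING_SLOT_ID)

print(clamped_positions.tolist())  # [3, 5, 0]
print(seq_lens.tolist())           # [4, 6, 1]
print(slot_mapping.tolist())       # [3, 13, -1]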

View File

@@ -18,6 +18,9 @@ class NgramProposer:
# tokens follow the match, we will return the maximum amount of
# tokens until the end.
self.k = vllm_config.speculative_config.num_speculative_tokens
# Maximum length of the model.
self.max_model_len = vllm_config.model_config.max_model_len
# Trigger Numba JIT compilation for N-gram proposer.
# This usually takes less than 1 second.
self.propose(np.zeros(1024, dtype=np.int32))
@@ -50,9 +53,14 @@ class NgramProposer:
followed that pattern. Here we will return [4,2,3] because
we only have three tokens after the match.
"""
# Do not generate draft tokens beyond the max model length.
k = min(self.k, self.max_model_len - context_token_ids.shape[0])
if k <= 0:
return None
# TODO(woosuk): Optimize this.
for n in range(self.max_n, self.min_n - 1, -1):
- result = _find_subarray_kmp(context_token_ids, n, self.k)
+ result = _find_subarray_kmp(context_token_ids, n, k)
if result is not None:
return result
return None
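
The proposer-side change is the same guard in scalar form: never ask for more draft tokens than max_model_len leaves room for, and skip the proposal entirely when there is no room left. A hypothetical standalone helper mirroring that check (the real logic lives inline in NgramProposer.propose):

import numpy as np

def clamp_num_draft_tokens(context_token_ids: np.ndarray, k: int,
                           max_model_len: int) -> int:
    # Number of draft tokens that still fit under the model length.
    return min(k, max_model_len - context_token_ids.shape[0])

print(clamp_num_draft_tokens(np.zeros(95, dtype=np.int32), k=10,
                             max_model_len=100))   # 5: room for only 5 drafts
print(clamp_num_draft_tokens(np.zeros(100, dtype=np.int32), k=10,
                             max_model_len=100))   # 0: propose() returns None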

View File

@@ -1271,7 +1271,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
draft_token_ids.append([])
continue
- # Skip requests that require top-p, top-k, etc.
+ # Skip requests that require sampling parameters that are not
+ # supported with speculative decoding.
req_id = self.input_batch.req_ids[i]
if not is_spec_decode_supported(req_id, self.input_batch):
draft_token_ids.append([])
@@ -1280,6 +1281,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Add sampled_token_ids to token_ids_cpu.
start_idx = self.input_batch.num_tokens_no_spec[i]
end_idx = start_idx + num_sampled_ids
if end_idx >= self.max_model_len:
# Skip requests that have already reached the max model length.
draft_token_ids.append([])
continue
self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
drafter_output = self.drafter.propose(
self.input_batch.token_ids_cpu[i, :end_idx])
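
Finally, the runner-side guard in plain numbers (the values are illustrative): once appending the newly sampled ids would reach max_model_len, the drafter is not invoked for that request at all and an empty draft list is appended instead.

max_model_len = 100
num_tokens_no_spec = 99          # tokens already stored for this request
sampled_ids = [42]               # token sampled at this step

start_idx = num_tokens_no_spec
end_idx = start_idx + len(sampled_ids)
skip_drafting = end_idx >= max_model_len   # True: 100 >= 100, so no draft tokens
print(skip_drafting)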