[V1][Spec Decode] Handle draft tokens beyond max_model_len (#16087)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent 299ebb62b2
commit 3a0fba5cf4
@@ -30,6 +30,7 @@ def create_scheduler(
     use_kv_connector: bool = False,
     num_blocks: int = 10000,
     block_size: int = 16,
+    max_model_len: Optional[int] = None,
 ) -> Scheduler:
     '''Create scheduler under test.
 
@@ -44,12 +45,15 @@ def create_scheduler(
     Returns:
       :class:`Scheduler` instance
     '''
+    if max_model_len is None:
+        max_model_len = max_num_batched_tokens
     scheduler_config = SchedulerConfig(
         max_num_seqs=max_num_seqs,
         max_num_batched_tokens=max_num_batched_tokens,
-        max_model_len=max_num_batched_tokens,
+        max_model_len=max_model_len,
         long_prefill_token_threshold=long_prefill_token_threshold,
         disable_chunked_mm_input=disable_chunked_mm_input,
         enable_chunked_prefill=True,
     )
     model_config = ModelConfig(
         model=model,
@@ -296,6 +300,7 @@ def test_no_mm_input_chunking():
         model="llava-hf/llava-1.5-7b-hf",
         max_num_batched_tokens=1024,
         disable_chunked_mm_input=True,
+        max_model_len=2048,
     )
     mm_positions = [[PlaceholderRange(offset=400, length=800)]]
     requests = create_requests(num_requests=1,
tests/v1/spec_decode/test_max_len.py (new file, 57 lines)
@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
"""Test whether spec decoding handles the max model length properly."""

import pytest

from vllm import LLM, SamplingParams

_PROMPTS = [
    "1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1",
    "Repeat the following sentence 10 times: Consistency is key to mastering any skill.",  # noqa: E501
    "Who won the Turing Award in 2018, and for what contribution? Describe in detail.",  # noqa: E501
]


@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
def test_ngram_max_len(
    monkeypatch: pytest.MonkeyPatch,
    num_speculative_tokens: int,
):
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        llm = LLM(
            model="facebook/opt-125m",
            max_model_len=100,
            enforce_eager=True,  # For faster initialization.
            speculative_config={
                "method": "ngram",
                "prompt_lookup_max": 5,
                "prompt_lookup_min": 3,
                "num_speculative_tokens": num_speculative_tokens,
            },
        )
        sampling_params = SamplingParams(max_tokens=100, ignore_eos=True)
        llm.generate(_PROMPTS, sampling_params)


@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
def test_eagle_max_len(
    monkeypatch: pytest.MonkeyPatch,
    num_speculative_tokens: int,
):
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        llm = LLM(
            model="meta-llama/Meta-Llama-3-8B-Instruct",
            enforce_eager=True,  # For faster initialization.
            speculative_config={
                "method": "eagle",
                "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
                "num_speculative_tokens": num_speculative_tokens,
            },
            max_model_len=100,
        )
        sampling_params = SamplingParams(max_tokens=100, ignore_eos=True)
        llm.generate(_PROMPTS, sampling_params)
@@ -2,7 +2,7 @@
 
 import numpy as np
 
-from vllm.config import SpeculativeConfig, VllmConfig
+from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
 from vllm.v1.spec_decode.ngram_proposer import (NgramProposer,
                                                 _find_subarray_kmp,
                                                 _kmp_lps_array)
@@ -42,14 +42,24 @@ def test_find_subarray_kmp():
 def test_ngram_proposer():
 
     def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
-        return NgramProposer(vllm_config=VllmConfig(
-            speculative_config=SpeculativeConfig.from_dict(
-                {
-                    "prompt_lookup_min": min_n,
-                    "prompt_lookup_max": max_n,
-                    "num_speculative_tokens": k,
-                    "method": "ngram",
-                })))
+        # Dummy model config. Just to set max_model_len.
+        model_config = ModelConfig(model="facebook/opt-125m",
+                                   task="generate",
+                                   max_model_len=100,
+                                   tokenizer="facebook/opt-125m",
+                                   tokenizer_mode="auto",
+                                   dtype="auto",
+                                   seed=None,
+                                   trust_remote_code=False)
+        return NgramProposer(
+            vllm_config=VllmConfig(model_config=model_config,
+                                   speculative_config=SpeculativeConfig.
+                                   from_dict({
+                                       "prompt_lookup_min": min_n,
+                                       "prompt_lookup_max": max_n,
+                                       "num_speculative_tokens": k,
+                                       "method": "ngram",
+                                   })))
 
     # No match.
     result = ngram_proposer(
@@ -185,6 +185,13 @@ class Scheduler(SchedulerInterface):
             num_new_tokens = min(num_new_tokens, token_budget)
             assert num_new_tokens > 0
+
+            # Make sure the input position does not exceed the max model len.
+            # This is necessary when using spec decoding.
+            num_new_tokens = min(
+                num_new_tokens,
+                self.max_model_len - request.num_computed_tokens)
+            assert num_new_tokens > 0
 
             # Schedule encoder inputs.
             if request.has_encoder_inputs:
                 (encoder_inputs_to_schedule, num_new_tokens,
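To see the new clamp in isolation: below is a minimal, hypothetical sketch (the helper name and the toy numbers are invented, it is not vLLM code) of how capping the scheduled token count at max_model_len - num_computed_tokens keeps input positions inside the context window when speculative tokens are appended for verification.

# Hypothetical standalone sketch of the clamp in the hunk above; not vLLM code.
def clamp_new_tokens(num_new_tokens: int, num_computed_tokens: int,
                     token_budget: int, max_model_len: int) -> int:
    """Cap the tokens scheduled this step so positions never pass max_model_len."""
    num_new_tokens = min(num_new_tokens, token_budget)
    # Without this second clamp, draft tokens appended for verification could
    # push the input positions beyond the model's context window.
    num_new_tokens = min(num_new_tokens, max_model_len - num_computed_tokens)
    assert num_new_tokens > 0
    return num_new_tokens


if __name__ == "__main__":
    # A request at position 98 of a 100-token window with 4 tokens to schedule
    # is trimmed to the 2 positions that still fit.
    print(clamp_new_tokens(num_new_tokens=4, num_computed_tokens=98,
                           token_budget=16, max_model_len=100))  # -> 2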
@@ -12,6 +12,8 @@ from vllm.model_executor.models.llama_eagle import EagleLlamaForCausalLM
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
 
+PADDING_SLOT_ID = -1
+
 
 class EagleProposer:
 
@@ -23,6 +25,7 @@ class EagleProposer:
         self.vllm_config = vllm_config
         self.num_speculative_tokens = (
             vllm_config.speculative_config.num_speculative_tokens)
+        self.max_model_len = vllm_config.model_config.max_model_len
         self.block_size = vllm_config.cache_config.block_size
         # We need +1 here because the arange is used to set query_start_loc,
         # which has one more element than batch_size.
@@ -112,22 +115,48 @@ class EagleProposer:
             # Update the inputs.
             input_ids = draft_token_ids_list[-1]
             positions += 1
+
+            # NOTE(woosuk): We should handle the case where the draft model
+            # generates tokens beyond the max model length. Since it is complex
+            # to remove such requests from the batch, we keep them in the batch
+            # but adjust the position ids and slot mappings to avoid the
+            # out-of-range access during the model execution. The draft tokens
+            # generated with this adjustment should be ignored.
+            exceeds_max_model_len = positions >= self.max_model_len
+            # Mask out the position ids that exceed the max model length.
+            # Otherwise, we may get out-of-range error in RoPE.
+            clamped_positions = torch.where(exceeds_max_model_len, 0,
+                                            positions)
+
+            # Increment the sequence lengths.
             attn_metadata.max_seq_len += 1
             attn_metadata.seq_lens += 1
+            # Consider max model length.
+            attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
+                                            self.max_model_len)
+            # For the requests that exceed the max model length, we set the
+            # sequence length to 1 to minimize their overheads in attention.
+            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
+
             # Compute the slot mapping.
-            block_numbers = positions // self.block_size
+            block_numbers = clamped_positions // self.block_size
             block_ids = block_table.gather(dim=1,
                                            index=block_numbers.view(-1, 1))
             block_ids = block_ids.view(-1)
             attn_metadata.slot_mapping = (block_ids * self.block_size +
-                                          positions % self.block_size)
+                                          clamped_positions % self.block_size)
+            # Mask out the slot mappings that exceed the max model length.
+            # Otherwise, the KV cache will be inadvertently updated with the
+            # padding tokens.
+            attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
+                                                    PADDING_SLOT_ID)
 
             # Run the model.
             with set_forward_context(attn_metadata, self.vllm_config):
                 hidden_states = self.model(
                     input_ids=input_ids,
                     hidden_states=hidden_states,
-                    positions=positions,
+                    positions=clamped_positions,
                 )
             logits = self.model.compute_logits(hidden_states, None)
             draft_token_ids, probs = compute_probs_and_sample_next_token(
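The masking trick described in the NOTE above can be reproduced standalone. The sketch below uses toy shapes and a made-up block table (everything except the PADDING_SLOT_ID sentinel is an assumption, not vLLM internals): positions that have reached max_model_len are clamped to 0 so RoPE and embedding lookups stay in range, and their slot mapping is redirected to the padding slot so the KV cache is never written for those requests.

# Toy reproduction of the masking trick from the hunk above (assumed shapes,
# not vLLM's actual propose() loop).
import torch

PADDING_SLOT_ID = -1  # same sentinel the diff introduces
max_model_len = 8
block_size = 4

# Per-request next positions; the last one has already hit the context limit.
positions = torch.tensor([3, 6, 8])
# One row of physical block ids per request (toy block table).
block_table = torch.tensor([[0, 1], [2, 3], [4, 5]])

# Requests whose next position falls outside the model's context window.
exceeds_max_model_len = positions >= max_model_len
# Clamp those positions to 0 so RoPE / embedding lookups stay in range.
clamped_positions = torch.where(exceeds_max_model_len, 0, positions)

# Slot mapping computed from the clamped positions.
block_numbers = clamped_positions // block_size
block_ids = block_table.gather(dim=1, index=block_numbers.view(-1, 1)).view(-1)
slot_mapping = block_ids * block_size + clamped_positions % block_size
# Redirect out-of-range requests to the padding slot so their (garbage)
# KV entries never overwrite real cache lines.
slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)

print(clamped_positions.tolist())  # [3, 6, 0]
print(slot_mapping.tolist())       # [3, 14, -1]

The drafts produced for the masked rows are garbage by construction, which is why the runner-side change later in this commit drops them.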
@@ -18,6 +18,9 @@ class NgramProposer:
         # tokens follow the match, we will return the maximum amount of
         # tokens until the end.
         self.k = vllm_config.speculative_config.num_speculative_tokens
+        # Maximum length of the model.
+        self.max_model_len = vllm_config.model_config.max_model_len
+
         # Trigger Numba JIT compilation for N-gram proposer.
         # This usually takes less than 1 second.
         self.propose(np.zeros(1024, dtype=np.int32))
@@ -50,9 +53,14 @@ class NgramProposer:
         followed that pattern. Here we will return [4,2,3] because
         we only have three tokens after the match.
         """
+        # Do not generate draft tokens beyond the max model length.
+        k = min(self.k, self.max_model_len - context_token_ids.shape[0])
+        if k <= 0:
+            return None
+
         # TODO(woosuk): Optimize this.
         for n in range(self.max_n, self.min_n - 1, -1):
-            result = _find_subarray_kmp(context_token_ids, n, self.k)
+            result = _find_subarray_kmp(context_token_ids, n, k)
             if result is not None:
                 return result
         return None
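As a standalone illustration of the cap added to propose() above, here is a hypothetical helper (not the real NgramProposer; the KMP lookup is replaced by a trivial stand-in) showing how k shrinks to the room left before max_model_len and how drafting stops once the context is full.

# Hypothetical sketch of the cap applied in propose() above; not the real
# NgramProposer, just the k-clamping logic with a stand-in lookup.
from typing import Optional

import numpy as np


def draft_with_cap(context_token_ids: np.ndarray, k: int,
                   max_model_len: int) -> Optional[np.ndarray]:
    # Never propose tokens whose positions would land past max_model_len.
    k = min(k, max_model_len - context_token_ids.shape[0])
    if k <= 0:
        # The request is already at (or past) the context limit.
        return None
    # Stand-in for the KMP n-gram lookup: just echo the last k tokens.
    return context_token_ids[-k:]


if __name__ == "__main__":
    ctx = np.arange(98, dtype=np.int32)  # 98 tokens already in the context
    print(draft_with_cap(ctx, k=5, max_model_len=100))  # at most 2 drafts
    print(draft_with_cap(np.arange(100, dtype=np.int32), k=5,
                         max_model_len=100))  # None: no room left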
@@ -1271,7 +1271,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 draft_token_ids.append([])
                 continue
 
-            # Skip requests that require top-p, top-k, etc.
+            # Skip requests that require sampling parameters that are not
+            # supported with speculative decoding.
             req_id = self.input_batch.req_ids[i]
             if not is_spec_decode_supported(req_id, self.input_batch):
                 draft_token_ids.append([])
@@ -1280,6 +1281,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # Add sampled_token_ids to token_ids_cpu.
             start_idx = self.input_batch.num_tokens_no_spec[i]
             end_idx = start_idx + num_sampled_ids
+            if end_idx >= self.max_model_len:
+                # Skip requests that have already reached the max model length.
+                draft_token_ids.append([])
+                continue
+
             self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
             drafter_output = self.drafter.propose(
                 self.input_batch.token_ids_cpu[i, :end_idx])
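Finally, the model-runner change above can be summarized with a simplified, hypothetical loop (toy lists instead of GPUModelRunner's input batch; collect_drafts and the fake draft ids are invented): any request whose sampled tokens would reach max_model_len contributes an empty draft list, matching the NOTE in the EAGLE hunk that drafts generated past the limit should be ignored.

# Simplified, hypothetical version of the per-request loop above: requests
# that have already reached max_model_len get an empty draft list.
from typing import List


def collect_drafts(num_tokens_no_spec: List[int], sampled_counts: List[int],
                   max_model_len: int) -> List[List[int]]:
    draft_token_ids: List[List[int]] = []
    for start_idx, num_sampled_ids in zip(num_tokens_no_spec, sampled_counts):
        end_idx = start_idx + num_sampled_ids
        if end_idx >= max_model_len:
            # No room left in the context window: skip drafting entirely.
            draft_token_ids.append([])
            continue
        # Placeholder for drafter.propose(); here we just fake two draft ids.
        draft_token_ids.append([0, 0])
    return draft_token_ids


if __name__ == "__main__":
    # Request 0 has plenty of room; request 1 is about to hit the limit.
    print(collect_drafts([10, 99], [1, 1], max_model_len=100))  # [[0, 0], []]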