[V1][Spec decode] Move drafter to model runner (#13363)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-02-17 15:40:12 -08:00 committed by GitHub
parent 6ac485a953
commit cd4a72a28d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 84 additions and 57 deletions

View File

@ -203,6 +203,7 @@ def test_schedule_partial_requests():
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
sampled_token_ids=[[0] for _ in range(len(requests))],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
@ -259,6 +260,7 @@ def test_stop_via_update_from_output():
sampled_token_ids=[[EOS_TOKEN_ID],
[10,
11]], # First request hits EOS, second continues
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
@ -307,6 +309,7 @@ def test_stop_via_update_from_output():
},
sampled_token_ids=[[10, 42, 12],
[13, 14]], # First request hits stop token
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
@ -354,6 +357,7 @@ def test_stop_via_update_from_output():
},
sampled_token_ids=[[10, 11, 12],
[13]], # First request exceeds max_tokens
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
@ -394,6 +398,7 @@ def test_stop_via_update_from_output():
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
@ -434,6 +439,7 @@ def test_schedule_concurrent_batches():
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[0]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
@ -450,6 +456,7 @@ def test_schedule_concurrent_batches():
req_ids=[requests[1].request_id],
req_id_to_index={requests[1].request_id: 0},
sampled_token_ids=[[0]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)

View File

@ -474,6 +474,7 @@ class Scheduler:
model_runner_output: "ModelRunnerOutput",
) -> EngineCoreOutputs:
sampled_token_ids = model_runner_output.sampled_token_ids
spec_token_ids = model_runner_output.spec_token_ids
logprobs = model_runner_output.logprobs
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
@ -530,13 +531,9 @@ class Scheduler:
self.encoder_cache_manager.free_encoder_input(
request, input_id)
if request.num_computed_tokens >= request.num_tokens:
# Clear the spec tokens as the request has generated
# a new token. Here, We assume all spec tokens are verified
# if we perform speculative decoding for this request.
# Therefore, we can clear all spec tokens after
# the generation step.
request.clear_spec_tokens()
# Add newly generated spec token ids to the request.
if spec_token_ids is not None:
request.spec_token_ids = spec_token_ids[req_index]
# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)

View File

@ -27,7 +27,6 @@ from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
@ -86,15 +85,6 @@ class EngineCore:
self.batch_queue_size)
self.batch_queue = queue.Queue(self.batch_queue_size)
# Setup speculative decode.
# TODO: find a better way to check if we are using ngram.
self.use_spec_decode = False
if self.scheduler.speculative_config:
assert self.scheduler.speculative_config.ngram_prompt_lookup_min \
, "Only ngram spec decode is supported in V1."
self.proposer = NgramProposer()
self.use_spec_decode = True
def _initialize_kv_caches(self,
vllm_config: VllmConfig) -> Tuple[int, int]:
start = time.time()
@ -158,9 +148,6 @@ class EngineCore:
return EngineCoreOutputs(
outputs=[], scheduler_stats=self.scheduler.make_stats())
if self.use_spec_decode:
self.propose_tokens()
scheduler_output = self.scheduler.schedule()
output = self.model_executor.execute_model(scheduler_output)
engine_core_outputs = self.scheduler.update_from_output(
@ -221,23 +208,6 @@ class EngineCore:
def profile(self, is_start: bool = True):
self.model_executor.profile(is_start)
def propose_tokens(self):
assert self.scheduler.speculative_config is not None
for req in self.scheduler.running:
# Ignore requests that are doing chunked prefill.
if req.num_computed_tokens < req.num_tokens - 1:
continue
# Ignore requests that already have spec tokens.
if req.spec_token_ids:
continue
spec_tokens = self.proposer.propose(
req.all_token_ids,
self.scheduler.speculative_config.ngram_prompt_lookup_min,
self.scheduler.speculative_config.num_speculative_tokens,
)
if spec_tokens:
req.append_spec_token_ids(spec_tokens)
def reset_prefix_cache(self):
self.scheduler.reset_prefix_cache()

View File

@ -67,6 +67,9 @@ class ModelRunnerOutput:
# each request due to speculative/jump decoding.
sampled_token_ids: List[List[int]]
# num_reqs x num_spec_tokens
spec_token_ids: Optional[List[List[int]]]
# [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1]
# [num_reqs]

View File

@ -104,18 +104,6 @@ class Request:
self._output_token_ids.extend(token_ids)
self._all_token_ids.extend(token_ids)
def append_spec_token_ids(
self,
token_ids: Union[int, List[int]],
) -> None:
if isinstance(token_ids, int):
self.spec_token_ids.append(token_ids)
else:
self.spec_token_ids.extend(token_ids)
def clear_spec_tokens(self) -> None:
self.spec_token_ids.clear()
@property
def num_tokens(self) -> int:
return len(self._all_token_ids)

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional
from vllm.v1.utils import ConstantList
import numpy as np
class NgramProposer:
@ -9,8 +9,12 @@ class NgramProposer:
def __init__(self):
pass
def propose(self, context_token_ids: ConstantList[int], n: int,
k: int) -> Optional[List[int]]:
def propose(
self,
context_token_ids: np.ndarray,
n: int,
k: int,
) -> Optional[np.ndarray]:
"""Proposes the next sequence of tokens based on n-gram pattern
matching in the context. The function finds matches of the last n
tokens in the previous context, and returns k tokens that followed
@ -25,8 +29,8 @@ class NgramProposer:
the maximum amount of tokens until the end.
Returns:
List[int]: The sequence of tokens that followed
the matched n-gram in the context.
np.ndarray: The sequence of tokens that followed
the matched n-gram in the context.
None: If no matching n-gram pattern is found.
Example:
@ -66,9 +70,12 @@ class NgramProposer:
return lps
@staticmethod
def _find_subarray_kmp(context_token_ids: ConstantList[int], n: int,
k: int) -> Optional[List[int]]:
context_len = len(context_token_ids)
def _find_subarray_kmp(
context_token_ids: np.ndarray,
n: int,
k: int,
) -> Optional[np.ndarray]:
context_len = context_token_ids.shape[0]
assert n > 0
pattern = context_token_ids[-n:]

View File

@ -78,6 +78,7 @@ class InputBatch:
)
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
@ -217,7 +218,11 @@ class InputBatch:
end_idx = start_idx + len(request.output_token_ids)
self.token_ids_cpu[req_index,
start_idx:end_idx] = request.output_token_ids
# Number of token ids in token_ids_cpu.
# NOTE(woosuk): This may include spec decode tokens.
self.num_tokens[req_index] = request.num_tokens
# Number of tokens without spec decode tokens.
self.num_tokens_no_spec[req_index] = request.num_tokens
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
self.block_table.add_row(req_index, request.block_ids)
@ -356,6 +361,8 @@ class InputBatch:
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
last_req_index, :num_tokens]
self.num_tokens[empty_index] = num_tokens
self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
last_req_index]
self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
last_req_index]
self.num_computed_tokens_cpu[

View File

@ -33,6 +33,7 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@ -117,6 +118,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# req_id -> (input_id -> encoder_output)
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
# Set up speculative decoding.
self.use_spec_decode = False
if self.speculative_config:
# TODO: find a better way to check if we are using ngram.
assert self.speculative_config.ngram_prompt_lookup_min, \
"Currently, only ngram spec decode is supported in V1."
self.drafter = NgramProposer()
self.use_spec_decode = True
# Request states.
self.requests: Dict[str, CachedRequestState] = {}
# Persistent batch.
@ -367,6 +377,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.token_ids_cpu[
req_index,
start_token_index:end_token_index] = req_data.new_token_ids
self.input_batch.num_tokens_no_spec[req_index] = end_token_index
# Add spec_token_ids to token_ids_cpu.
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
req_id, [])
@ -1009,15 +1020,51 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for seq in sampled_token_ids[valid_mask].split(gen_lens)
]
if not self.use_spec_decode:
spec_token_ids = None
else:
spec_token_ids = self.generate_draft_token_ids(
valid_sampled_token_ids)
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=valid_sampled_token_ids,
spec_token_ids=spec_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
)
return model_runner_output
def generate_draft_token_ids(
self,
sampled_token_ids: List[List[int]],
) -> List[List[int]]:
# TODO(woosuk): Optimize.
num_reqs = len(sampled_token_ids)
draft_token_ids: List[List[int]] = []
for i in range(num_reqs):
if len(sampled_token_ids[i]) == 0:
# Skip speculative decoding.
draft_token_ids.append([])
continue
# Add sampled_token_ids to token_ids_cpu.
start_idx = self.input_batch.num_tokens_no_spec[i]
end_idx = start_idx + len(sampled_token_ids[i])
self.input_batch.token_ids_cpu[
i, start_idx:end_idx] = sampled_token_ids[i]
drafter_output = self.drafter.propose(
self.input_batch.token_ids_cpu[i, :end_idx],
self.speculative_config.ngram_prompt_lookup_min,
self.speculative_config.num_speculative_tokens,
)
if drafter_output is None or len(drafter_output) == 0:
draft_token_ids.append([])
else:
draft_token_ids.append(drafter_output.tolist())
return draft_token_ids
def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m: # noqa: SIM117

View File

@ -696,6 +696,7 @@ class TPUModelRunner:
req_ids=all_req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=[[token_id] for token_id in sampled_token_ids],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict=prompt_logprobs_dict, # type: ignore[arg-type]
)