Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[V1][Spec decode] Move drafter to model runner (#13363)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent 6ac485a953
commit cd4a72a28d
@@ -203,6 +203,7 @@ def test_schedule_partial_requests():
        req_ids=[request.request_id for request in requests],
        req_id_to_index=req_to_index,
        sampled_token_ids=[[0] for _ in range(len(requests))],
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
    )

@@ -259,6 +260,7 @@ def test_stop_via_update_from_output():
        sampled_token_ids=[[EOS_TOKEN_ID],
                           [10,
                            11]],  # First request hits EOS, second continues
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={})

@@ -307,6 +309,7 @@ def test_stop_via_update_from_output():
        },
        sampled_token_ids=[[10, 42, 12],
                           [13, 14]],  # First request hits stop token
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={})

@@ -354,6 +357,7 @@ def test_stop_via_update_from_output():
        },
        sampled_token_ids=[[10, 11, 12],
                           [13]],  # First request exceeds max_tokens
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={})

@@ -394,6 +398,7 @@ def test_stop_via_update_from_output():
        req_ids=[requests[0].request_id],
        req_id_to_index={requests[0].request_id: 0},
        sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={})

@@ -434,6 +439,7 @@ def test_schedule_concurrent_batches():
        req_ids=[requests[0].request_id],
        req_id_to_index={requests[0].request_id: 0},
        sampled_token_ids=[[0]],
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
    )

@@ -450,6 +456,7 @@ def test_schedule_concurrent_batches():
        req_ids=[requests[1].request_id],
        req_id_to_index={requests[1].request_id: 0},
        sampled_token_ids=[[0]],
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
    )
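The test updates above are all the same mechanical change: every hand-built ModelRunnerOutput in the scheduler tests now passes an explicit spec_token_ids=None. A minimal sketch of that pattern is below; the make_output helper and the request-id handling are illustrative, not part of the vllm test suite.

# Hedged sketch of the test-side pattern; assumes vllm is installed.
from vllm.v1.outputs import ModelRunnerOutput

def make_output(req_ids):
    # Hypothetical helper mirroring how the tests stub out model runner output.
    return ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index={rid: i for i, rid in enumerate(req_ids)},
        sampled_token_ids=[[0] for _ in req_ids],
        spec_token_ids=None,  # new required field; None disables spec decode
        logprobs=None,
        prompt_logprobs_dict={},
    )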
@@ -474,6 +474,7 @@ class Scheduler:
        model_runner_output: "ModelRunnerOutput",
    ) -> EngineCoreOutputs:
        sampled_token_ids = model_runner_output.sampled_token_ids
        spec_token_ids = model_runner_output.spec_token_ids
        logprobs = model_runner_output.logprobs
        prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
        num_scheduled_tokens = scheduler_output.num_scheduled_tokens

@@ -530,13 +531,9 @@ class Scheduler:
                    self.encoder_cache_manager.free_encoder_input(
                        request, input_id)

            if request.num_computed_tokens >= request.num_tokens:
                # Clear the spec tokens as the request has generated
                # a new token. Here, We assume all spec tokens are verified
                # if we perform speculative decoding for this request.
                # Therefore, we can clear all spec tokens after
                # the generation step.
                request.clear_spec_tokens()
            # Add newly generated spec token ids to the request.
            if spec_token_ids is not None:
                request.spec_token_ids = spec_token_ids[req_index]

            # Get prompt logprobs for this request.
            prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
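In other words, the scheduler no longer clears and re-fills request.spec_token_ids itself; it simply overwrites each request's draft with whatever the model runner proposed in this step. A standalone sketch of that assignment follows; SimpleRequest and the request list are stand-ins for vllm's Request objects, not the real classes.

# Hedged sketch of the new per-request update, with stand-in types.
from dataclasses import dataclass, field
from typing import Dict, List, Optional

@dataclass
class SimpleRequest:  # stand-in for vllm.v1.request.Request
    request_id: str
    spec_token_ids: List[int] = field(default_factory=list)

def apply_spec_tokens(
    running: List[SimpleRequest],
    req_id_to_index: Dict[str, int],
    spec_token_ids: Optional[List[List[int]]],
) -> None:
    """Overwrite each request's draft tokens with the runner's proposals."""
    if spec_token_ids is None:  # spec decode disabled this step
        return
    for request in running:
        req_index = req_id_to_index[request.request_id]
        request.spec_token_ids = spec_token_ids[req_index]

reqs = [SimpleRequest("a"), SimpleRequest("b")]
apply_spec_tokens(reqs, {"a": 0, "b": 1}, [[7, 8], []])
assert reqs[0].spec_token_ids == [7, 8] and reqs[1].spec_token_ids == []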
@@ -27,7 +27,6 @@ from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

@@ -86,15 +85,6 @@ class EngineCore:
                self.batch_queue_size)
            self.batch_queue = queue.Queue(self.batch_queue_size)

        # Setup speculative decode.
        # TODO: find a better way to check if we are using ngram.
        self.use_spec_decode = False
        if self.scheduler.speculative_config:
            assert self.scheduler.speculative_config.ngram_prompt_lookup_min \
                , "Only ngram spec decode is supported in V1."
            self.proposer = NgramProposer()
            self.use_spec_decode = True

    def _initialize_kv_caches(self,
                              vllm_config: VllmConfig) -> Tuple[int, int]:
        start = time.time()

@@ -158,9 +148,6 @@ class EngineCore:
            return EngineCoreOutputs(
                outputs=[], scheduler_stats=self.scheduler.make_stats())

        if self.use_spec_decode:
            self.propose_tokens()

        scheduler_output = self.scheduler.schedule()
        output = self.model_executor.execute_model(scheduler_output)
        engine_core_outputs = self.scheduler.update_from_output(

@@ -221,23 +208,6 @@ class EngineCore:
    def profile(self, is_start: bool = True):
        self.model_executor.profile(is_start)

    def propose_tokens(self):
        assert self.scheduler.speculative_config is not None
        for req in self.scheduler.running:
            # Ignore requests that are doing chunked prefill.
            if req.num_computed_tokens < req.num_tokens - 1:
                continue
            # Ignore requests that already have spec tokens.
            if req.spec_token_ids:
                continue
            spec_tokens = self.proposer.propose(
                req.all_token_ids,
                self.scheduler.speculative_config.ngram_prompt_lookup_min,
                self.scheduler.speculative_config.num_speculative_tokens,
            )
            if spec_tokens:
                req.append_spec_token_ids(spec_tokens)

    def reset_prefix_cache(self):
        self.scheduler.reset_prefix_cache()
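The net effect on the engine loop: EngineCore.step() used to run an explicit propose_tokens() pass over scheduler.running before scheduling, and after this change it only schedules and executes, with drafting happening inside the model runner. A schematic sketch of the two orderings follows; ToyEngine and its collaborators are toy stand-ins for illustration, not vllm code.

# Toy stand-ins; only the ordering of calls is the point.
class ToyEngine:
    def __init__(self, scheduler, executor, use_spec_decode=False, proposer=None):
        self.scheduler = scheduler
        self.executor = executor
        self.use_spec_decode = use_spec_decode
        self.proposer = proposer

    def step_old(self):
        # Before this commit: engine-level drafting against Request objects.
        if self.use_spec_decode:
            self.proposer.propose_for(self.scheduler.running)  # hypothetical helper
        scheduler_output = self.scheduler.schedule()
        output = self.executor.execute_model(scheduler_output)
        return self.scheduler.update_from_output(scheduler_output, output)

    def step_new(self):
        # After this commit: the model runner drafts and returns spec_token_ids
        # inside `output`; the engine just plumbs it through to the scheduler.
        scheduler_output = self.scheduler.schedule()
        output = self.executor.execute_model(scheduler_output)
        return self.scheduler.update_from_output(scheduler_output, output)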
@@ -67,6 +67,9 @@ class ModelRunnerOutput:
    # each request due to speculative/jump decoding.
    sampled_token_ids: List[List[int]]

    # num_reqs x num_spec_tokens
    spec_token_ids: Optional[List[List[int]]]

    # [num_reqs, max_num_logprobs + 1]
    # [num_reqs, max_num_logprobs + 1]
    # [num_reqs]
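For reference, a minimal dataclass mirroring the updated ModelRunnerOutput contract is sketched below. Field names and comments follow the diff; the logprobs-related fields are elided, so this is a sketch of the shape, not the real class.

from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class ModelRunnerOutputSketch:  # illustrative stand-in for vllm.v1.outputs.ModelRunnerOutput
    # [num_reqs]
    req_ids: List[str]
    # req_id -> index into the batch
    req_id_to_index: Dict[str, int]
    # num_reqs x num_generated_tokens
    sampled_token_ids: List[List[int]]
    # num_reqs x num_spec_tokens; None when spec decode is off
    spec_token_ids: Optional[List[List[int]]]

out = ModelRunnerOutputSketch(
    req_ids=["a"],
    req_id_to_index={"a": 0},
    sampled_token_ids=[[42]],
    spec_token_ids=[[7, 8, 9]],  # drafts produced by the runner's drafter
)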
@@ -104,18 +104,6 @@ class Request:
        self._output_token_ids.extend(token_ids)
        self._all_token_ids.extend(token_ids)

    def append_spec_token_ids(
        self,
        token_ids: Union[int, List[int]],
    ) -> None:
        if isinstance(token_ids, int):
            self.spec_token_ids.append(token_ids)
        else:
            self.spec_token_ids.extend(token_ids)

    def clear_spec_tokens(self) -> None:
        self.spec_token_ids.clear()

    @property
    def num_tokens(self) -> int:
        return len(self._all_token_ids)

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional

from vllm.v1.utils import ConstantList
import numpy as np


class NgramProposer:

@@ -9,8 +9,12 @@ class NgramProposer:
    def __init__(self):
        pass

    def propose(self, context_token_ids: ConstantList[int], n: int,
                k: int) -> Optional[List[int]]:
    def propose(
        self,
        context_token_ids: np.ndarray,
        n: int,
        k: int,
    ) -> Optional[np.ndarray]:
        """Proposes the next sequence of tokens based on n-gram pattern
        matching in the context. The function finds matches of the last n
        tokens in the previous context, and returns k tokens that followed

@@ -25,8 +29,8 @@ class NgramProposer:
                the maximum amount of tokens until the end.

        Returns:
            List[int]: The sequence of tokens that followed
                the matched n-gram in the context.
            np.ndarray: The sequence of tokens that followed
                the matched n-gram in the context.
            None: If no matching n-gram pattern is found.

        Example:

@@ -66,9 +70,12 @@ class NgramProposer:
        return lps

    @staticmethod
    def _find_subarray_kmp(context_token_ids: ConstantList[int], n: int,
                           k: int) -> Optional[List[int]]:
        context_len = len(context_token_ids)
    def _find_subarray_kmp(
        context_token_ids: np.ndarray,
        n: int,
        k: int,
    ) -> Optional[np.ndarray]:
        context_len = context_token_ids.shape[0]
        assert n > 0

        pattern = context_token_ids[-n:]
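The proposer's contract after this change: given a 1-D np.ndarray of context token ids, match the last n tokens somewhere earlier in the array and return up to k tokens that followed that match, or None if there is no match. A simplified sketch of that behavior is below; it uses a plain linear scan rather than the KMP search in _find_subarray_kmp, so it is not the vllm implementation, only the same interface.

import numpy as np
from typing import Optional

def propose_ngram(
    context_token_ids: np.ndarray, n: int, k: int
) -> Optional[np.ndarray]:
    """Return up to k tokens that followed an earlier match of the last n tokens."""
    context_len = context_token_ids.shape[0]
    if context_len < n + 1:
        return None
    pattern = context_token_ids[-n:]
    # Scan every earlier window of length n (vllm uses KMP; this is a naive scan).
    for start in range(context_len - n):
        if np.array_equal(context_token_ids[start:start + n], pattern):
            follow_start = start + n
            follow_end = min(follow_start + k, context_len)
            return context_token_ids[follow_start:follow_end]
    return None

tokens = np.array([1, 2, 3, 4, 1, 2, 3], dtype=np.int32)
print(propose_ngram(tokens, n=3, k=2))  # -> [4 1], the tokens after the earlier 1, 2, 3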
@@ -78,6 +78,7 @@ class InputBatch:
        )
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
        self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)

@@ -217,7 +218,11 @@ class InputBatch:
        end_idx = start_idx + len(request.output_token_ids)
        self.token_ids_cpu[req_index,
                           start_idx:end_idx] = request.output_token_ids
        # Number of token ids in token_ids_cpu.
        # NOTE(woosuk): This may include spec decode tokens.
        self.num_tokens[req_index] = request.num_tokens
        # Number of tokens without spec decode tokens.
        self.num_tokens_no_spec[req_index] = request.num_tokens

        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        self.block_table.add_row(req_index, request.block_ids)

@@ -356,6 +361,8 @@ class InputBatch:
            self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                last_req_index, :num_tokens]
            self.num_tokens[empty_index] = num_tokens
            self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
                last_req_index]
            self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
                last_req_index]
            self.num_computed_tokens_cpu[
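The new num_tokens_no_spec array gives the runner a stable write position: num_tokens may already count draft tokens copied into token_ids_cpu, while num_tokens_no_spec counts only prompt plus verified output tokens, so newly sampled tokens are written at that index. The numpy sketch below illustrates why the two counters are kept separately for one request slot; buffer sizes, token values, and the exact update points are illustrative and differ from the real InputBatch.

import numpy as np

max_model_len = 16
token_ids_cpu = np.zeros((1, max_model_len), dtype=np.int32)  # one request slot
num_tokens = np.zeros(1, dtype=np.int32)          # may include draft tokens
num_tokens_no_spec = np.zeros(1, dtype=np.int32)  # prompt + verified tokens only

# Prompt of 4 tokens for request slot 0.
token_ids_cpu[0, :4] = [11, 12, 13, 14]
num_tokens[0] = num_tokens_no_spec[0] = 4

# Two scheduled draft tokens are appended to the buffer; only num_tokens grows.
token_ids_cpu[0, 4:6] = [21, 22]
num_tokens[0] = 6

# A newly sampled (verified) token is written at num_tokens_no_spec, so it
# lands right after the verified prefix, overwriting the stale draft slot.
token_ids_cpu[0, num_tokens_no_spec[0]] = 21
print(token_ids_cpu[0, :6])  # [11 12 13 14 21 22]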
@@ -33,6 +33,7 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

@@ -117,6 +118,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}

        # Set up speculative decoding.
        self.use_spec_decode = False
        if self.speculative_config:
            # TODO: find a better way to check if we are using ngram.
            assert self.speculative_config.ngram_prompt_lookup_min, \
                "Currently, only ngram spec decode is supported in V1."
            self.drafter = NgramProposer()
            self.use_spec_decode = True

        # Request states.
        self.requests: Dict[str, CachedRequestState] = {}
        # Persistent batch.

@@ -367,6 +377,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            self.input_batch.token_ids_cpu[
                req_index,
                start_token_index:end_token_index] = req_data.new_token_ids
            self.input_batch.num_tokens_no_spec[req_index] = end_token_index
            # Add spec_token_ids to token_ids_cpu.
            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                req_id, [])

@@ -1009,15 +1020,51 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            for seq in sampled_token_ids[valid_mask].split(gen_lens)
        ]

        if not self.use_spec_decode:
            spec_token_ids = None
        else:
            spec_token_ids = self.generate_draft_token_ids(
                valid_sampled_token_ids)

        model_runner_output = ModelRunnerOutput(
            req_ids=req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=valid_sampled_token_ids,
            spec_token_ids=spec_token_ids,
            logprobs=logprobs_lists,
            prompt_logprobs_dict=prompt_logprobs_dict,
        )
        return model_runner_output

    def generate_draft_token_ids(
        self,
        sampled_token_ids: List[List[int]],
    ) -> List[List[int]]:
        # TODO(woosuk): Optimize.
        num_reqs = len(sampled_token_ids)
        draft_token_ids: List[List[int]] = []
        for i in range(num_reqs):
            if len(sampled_token_ids[i]) == 0:
                # Skip speculative decoding.
                draft_token_ids.append([])
                continue

            # Add sampled_token_ids to token_ids_cpu.
            start_idx = self.input_batch.num_tokens_no_spec[i]
            end_idx = start_idx + len(sampled_token_ids[i])
            self.input_batch.token_ids_cpu[
                i, start_idx:end_idx] = sampled_token_ids[i]
            drafter_output = self.drafter.propose(
                self.input_batch.token_ids_cpu[i, :end_idx],
                self.speculative_config.ngram_prompt_lookup_min,
                self.speculative_config.num_speculative_tokens,
            )
            if drafter_output is None or len(drafter_output) == 0:
                draft_token_ids.append([])
            else:
                draft_token_ids.append(drafter_output.tolist())
        return draft_token_ids

    def load_model(self) -> None:
        logger.info("Starting to load model %s...", self.model_config.model)
        with DeviceMemoryProfiler() as m:  # noqa: SIM117
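Taken together, the per-request drafting loop reads as: append this step's verified samples to the CPU token buffer, hand the full prefix to the n-gram drafter, and record its proposal (or an empty list) for each request. A decoupled sketch of that loop follows, using the propose_ngram helper sketched earlier in place of NgramProposer; the function name draft_for_batch, the buffer shapes, and the config values are illustrative assumptions.

import numpy as np
from typing import Callable, List, Optional

def draft_for_batch(
    token_ids_cpu: np.ndarray,            # [max_num_reqs, max_model_len]
    num_tokens_no_spec: np.ndarray,       # [max_num_reqs], verified-token counts
    sampled_token_ids: List[List[int]],   # this step's verified samples per request
    propose: Callable[[np.ndarray, int, int], Optional[np.ndarray]],
    n: int,                               # ngram_prompt_lookup_min
    k: int,                               # num_speculative_tokens
) -> List[List[int]]:
    draft_token_ids: List[List[int]] = []
    for i, sampled in enumerate(sampled_token_ids):
        if not sampled:
            # No verified token this step (e.g. partial prefill): skip drafting.
            draft_token_ids.append([])
            continue
        start = int(num_tokens_no_spec[i])
        end = start + len(sampled)
        token_ids_cpu[i, start:end] = sampled          # extend the verified prefix
        out = propose(token_ids_cpu[i, :end], n, k)    # draft from the full prefix
        draft_token_ids.append([] if out is None else out.tolist())
    return draft_token_ids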
@@ -696,6 +696,7 @@ class TPUModelRunner:
            req_ids=all_req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=[[token_id] for token_id in sampled_token_ids],
            spec_token_ids=None,
            logprobs=None,
            prompt_logprobs_dict=prompt_logprobs_dict,  # type: ignore[arg-type]
        )