[Spec Decode] Add Batch Parallel Ngram. Upto 8x lower overhead. (#24986)

Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Ekagra Ranjan 2025-09-25 18:22:03 -04:00 committed by GitHub
parent 89fa54e6f7
commit e71b8e210d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 383 additions and 109 deletions

View File

@ -1,17 +1,31 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc import gc
import time
from unittest import mock
import numpy as np import numpy as np
from tabulate import tabulate from tabulate import tabulate
from benchmark_utils import TimeCollector from benchmark_utils import TimeCollector
from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig from vllm.config import (
CacheConfig,
DeviceConfig,
LoadConfig,
ModelConfig,
ParallelConfig,
SchedulerConfig,
SpeculativeConfig,
VllmConfig,
)
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
def main(args): def benchmark_propose(args):
rows = [] rows = []
for max_ngram in args.max_ngram: for max_ngram in args.max_ngram:
collector = TimeCollector(TimeCollector.US) collector = TimeCollector(TimeCollector.US)
@ -69,10 +83,88 @@ def main(args):
) )
def benchmark_batched_propose(args):
NUM_SPECULATIVE_TOKENS_NGRAM = 10
PROMPT_LOOKUP_MIN = 5
PROMPT_LOOKUP_MAX = 15
MAX_MODEL_LEN = int(1e7)
DEVICE = current_platform.device_type
model_config = ModelConfig(model="facebook/opt-125m", runner="generate")
speculative_config = SpeculativeConfig(
target_model_config=model_config,
target_parallel_config=ParallelConfig(),
method="ngram",
num_speculative_tokens=NUM_SPECULATIVE_TOKENS_NGRAM,
prompt_lookup_max=PROMPT_LOOKUP_MAX,
prompt_lookup_min=PROMPT_LOOKUP_MIN,
)
vllm_config = VllmConfig(
model_config=model_config,
cache_config=CacheConfig(),
speculative_config=speculative_config,
device_config=DeviceConfig(device=current_platform.device_type),
parallel_config=ParallelConfig(),
load_config=LoadConfig(),
scheduler_config=SchedulerConfig(),
)
# monkey patch vllm.v1.worker.gpu_model_runner.get_pp_group
mock_pp_group = mock.MagicMock()
mock_pp_group.world_size = 1
with mock.patch(
"vllm.v1.worker.gpu_model_runner.get_pp_group", return_value=mock_pp_group
):
runner = GPUModelRunner(vllm_config, DEVICE)
# hack max model len
runner.max_model_len = MAX_MODEL_LEN
runner.drafter.max_model_len = MAX_MODEL_LEN
dummy_input_batch = InputBatch(
max_num_reqs=args.num_req,
max_model_len=MAX_MODEL_LEN,
max_num_batched_tokens=args.num_req * args.num_token,
device=DEVICE,
pin_memory=False,
vocab_size=256000,
block_sizes=[16],
)
dummy_input_batch._req_ids = list(str(id) for id in range(args.num_req))
dummy_input_batch.spec_decode_unsupported_reqs = ()
dummy_input_batch.num_tokens_no_spec = [args.num_token] * args.num_req
dummy_input_batch.token_ids_cpu = np.random.randint(
0, 20, (args.num_req, args.num_token)
)
runner.input_batch = dummy_input_batch
sampled_token_ids = [[0]] * args.num_req
print("Starting benchmark")
# first run is warmup so ignore it
for _ in range(args.num_iteration):
start = time.time()
runner.drafter.propose(
sampled_token_ids,
dummy_input_batch.req_ids,
dummy_input_batch.num_tokens_no_spec,
dummy_input_batch.token_ids_cpu,
dummy_input_batch.spec_decode_unsupported_reqs,
)
end = time.time()
print(f"Iteration time (s): {end - start}")
def invoke_main() -> None: def invoke_main() -> None:
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description="Benchmark the performance of N-gram speculative decode drafting" description="Benchmark the performance of N-gram speculative decode drafting"
) )
parser.add_argument(
"--batched", action="store_true", help="consider time to prepare batch"
) # noqa: E501
parser.add_argument( parser.add_argument(
"--num-iteration", "--num-iteration",
type=int, type=int,
@ -105,8 +197,17 @@ def invoke_main() -> None:
help="Number of speculative tokens to generate", help="Number of speculative tokens to generate",
) )
args = parser.parse_args() args = parser.parse_args()
main(args)
if not args.batched:
benchmark_propose(args)
else:
benchmark_batched_propose(args)
"""
# Example command lines:
# time python3 benchmarks/benchmark_ngram_proposer.py
# time python3 benchmarks/benchmark_ngram_proposer.py --batched --num-iteration 4 --num-token 1000000 --num-req 128
""" # noqa: E501
if __name__ == "__main__": if __name__ == "__main__":
invoke_main() # pragma: no cover invoke_main() # pragma: no cover

View File

@ -9,11 +9,13 @@ from vllm.v1.spec_decode.ngram_proposer import (
def test_find_longest_matched_ngram_and_propose_tokens(): def test_find_longest_matched_ngram_and_propose_tokens():
tokens = np.array([1, 2, 3, 4, 1, 2, 3, 5, 6]) tokens = np.array([1, 2, 3, 4, 1, 2, 3, 5, 6])
assert _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens, result = _find_longest_matched_ngram_and_propose_tokens(
min_ngram=2, origin_tokens=tokens,
max_ngram=2, min_ngram=2,
max_model_len=1024, max_ngram=2,
k=2) is None max_model_len=1024,
k=2)
assert len(result) == 0
tokens = np.array([1, 2, 3, 4, 1, 2, 3]) tokens = np.array([1, 2, 3, 4, 1, 2, 3])
np.testing.assert_array_equal( np.testing.assert_array_equal(
@ -62,7 +64,7 @@ def test_find_longest_matched_ngram_and_propose_tokens():
def test_ngram_proposer(): def test_ngram_proposer():
def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
# Dummy model config. Just to set max_model_len. # Dummy model config. Just to set max_model_len.
model_config = ModelConfig(model="facebook/opt-125m") model_config = ModelConfig(model="facebook/opt-125m")
return NgramProposer( return NgramProposer(
@ -75,36 +77,120 @@ def test_ngram_proposer():
))) )))
# No match. # No match.
result = ngram_proposer( token_ids_cpu = np.array([[1, 2, 3, 4, 5]])
min_n=2, max_n=2, result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 5])) sampled_token_ids=[[0]],
assert result is None req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert len(result[0]) == 0
# No match for 4-gram. # No match for 4-gram.
result = ngram_proposer( token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
min_n=4, max_n=4, result = get_ngram_proposer(min_n=4, max_n=4, k=2).propose(
k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3])) sampled_token_ids=[[0]],
assert result is None req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert len(result[0]) == 0
# No match for 4-gram but match for 3-gram. # No match for 4-gram but match for 3-gram.
result = ngram_proposer( token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
min_n=3, max_n=4, result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3])) sampled_token_ids=[[0]],
assert np.array_equal(result, np.array([4, 1])) req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert np.array_equal(result, np.array([[4, 1]]))
# Match for both 4-gram and 3-gram. # Match for both 4-gram and 3-gram.
# In this case, the proposer should return the 4-gram match. # In this case, the proposer should return the 4-gram match.
result = ngram_proposer(min_n=3, max_n=4, k=2).propose( token_ids_cpu = np.array([[2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]])
context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4])) result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
assert np.array_equal(result, np.array([1, 2])) # Not [5, 1] sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert np.array_equal(result, np.array([[1, 2]])) # Not [5, 1]]
# Match for 2-gram and 3-gram, but not 4-gram. # Match for 2-gram and 3-gram, but not 4-gram.
result = ngram_proposer(min_n=2, max_n=4, k=2).propose( token_ids_cpu = np.array([[3, 4, 5, 2, 3, 4, 1, 2, 3, 4]])
context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4])) result = get_ngram_proposer(min_n=2, max_n=4, k=2).propose(
assert np.array_equal(result, np.array([1, 2])) # Not [5, 2] sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert np.array_equal(result, np.array([[1, 2]])) # Not [5, 2]]
# Multiple 3-gram matched, but always pick the first one. # Multiple 3-gram matched, but always pick the first one.
result = ngram_proposer( token_ids_cpu = np.array(
min_n=3, max_n=3, k=2).propose(context_token_ids=np.array( [[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]])
[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3])) result = get_ngram_proposer(min_n=3, max_n=3, k=2).propose(
assert np.array_equal(result, np.array([100, 1])) sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert np.array_equal(result, np.array([[100, 1]]))
# check empty input
token_ids_cpu = np.array([[]])
result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert len(result[0]) == 0
# check multibatch input
# first request has 5 tokens and a match
# second request has 3 tokens and no match. Padded with -1 for max len 5
token_ids_cpu = np.array([[1, 2, 3, 1, 2], [4, 5, 6, -1, -1]])
result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
sampled_token_ids=[[0], [1]],
req_ids=["0", "1"],
num_tokens_no_spec=np.array([5, 3]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert len(result[0]) == 2
assert np.array_equal(result[0], np.array([3, 1]))
assert np.array_equal(result[1], np.array([]))
# test if 0 threads available: can happen if TP size > CPU count
ngram_proposer = get_ngram_proposer(min_n=2, max_n=2, k=2)
ngram_proposer.num_numba_thread_available = 0
# set max_model_len to 2 * threshold to ensure multithread is used
num_tokens_threshold = ngram_proposer.num_tokens_threshold
ngram_proposer.max_model_len = 2 * num_tokens_threshold
# using multibatch test
middle_integer = num_tokens_threshold // 2
input_1 = [_ for _ in range(num_tokens_threshold)]
input_1 += [middle_integer, middle_integer + 1]
input_2 = [-1] * len(input_1)
input_2[:3] = [4, 5, 6]
token_ids_cpu = np.array([input_1, input_2])
result = ngram_proposer.propose(
sampled_token_ids=[[0], [1]],
req_ids=["0", "1"],
num_tokens_no_spec=np.array([len(input_1), 3]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert len(result[0]) == 2
assert np.array_equal(result[0],
np.array([middle_integer + 2, middle_integer + 3]))
assert np.array_equal(result[1], np.array([]))

View File

@ -17,7 +17,7 @@ PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
GREEDY_TEMPERATURE: tl.constexpr = -1 GREEDY_TEMPERATURE: tl.constexpr = -1
# Maximum number of speculative draft tokens allowed per request in a single # Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases. # step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 32 MAX_SPEC_LEN = 128
class RejectionSampler(nn.Module): class RejectionSampler(nn.Module):

View File

@ -1,9 +1,9 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional import os
import numpy as np import numpy as np
from numba import jit from numba import get_num_threads, jit, njit, prange, set_num_threads
from vllm.config import VllmConfig from vllm.config import VllmConfig
@ -26,55 +26,174 @@ class NgramProposer:
# Maximum length of the model. # Maximum length of the model.
self.max_model_len = vllm_config.model_config.max_model_len self.max_model_len = vllm_config.model_config.max_model_len
# Pre-allocate buffers for numba batch propose.
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
self.valid_ngram_draft = np.zeros((max_num_seqs, self.k),
dtype=np.int32)
self.valid_ngram_num_drafts = np.zeros((max_num_seqs), dtype=np.int32)
# Threshold of total number of tokens in the batch to enable
# multi-threading in numba batch propose.
self.num_tokens_threshold = 8192
tp_size = vllm_config.parallel_config.tensor_parallel_size
cpu_count = os.cpu_count()
# Max number of threads for numba parallel processing.
if cpu_count:
# Divide by 2 to use physical cores
# and not logical cores (hyper-threading).
# Cap the number of threads to 8 to avoid using too many threads
# since other components like frontend (incl tokenization)
# and Structured Outputs also use multiple threads.
# TODO(ekagra-ranjan): bump up the cap from 1 to 8
# when TP parallelization for ngram is implemented.
self.num_numba_thread_available = min(1, (cpu_count // 2))
# Divide by tp_size to ensure each tensor parallel rank
# has some threads since all ranks will run this.
self.num_numba_thread_available //= tp_size
else:
self.num_numba_thread_available = 1
# Trigger Numba JIT compilation for N-gram proposer. # Trigger Numba JIT compilation for N-gram proposer.
# This usually takes less than 1 second. # This usually takes less than 1 second.
self.propose(np.zeros(1024, dtype=np.int32)) self.propose([[]] * 1024, [""] * 1024, np.zeros(1024, dtype=np.int32),
np.zeros((1024, self.max_model_len), dtype=np.int32),
set())
def batch_propose(
self,
num_requests: int,
valid_ngram_requests: list,
num_tokens_no_spec: np.ndarray,
token_ids_cpu: np.ndarray,
) -> list[list[int]]:
"""Batch version of ngram proposer using numba for acceleration.
Args:
valid_ngram_requests:
Set of indices of requests that need ngram proposals.
num_tokens_no_spec:
Numpy array of shape (batch_size,) representing the number
of tokens without speculative tokens for each request.
token_ids_cpu:
Numpy array of shape (batch_size, max_model_len)
representing the token IDs for each request.
Returns:
list[list[int]]:
A list where each element is a list of proposed
token IDs for the corresponding request.
"""
draft_token_ids: list[list[int]] = []
# Only run batch propose if there are requests needing ngram proposals.
# avoid calling numba function with empty list which causes error
# ValueError: cannot compute fingerprint of empty list
if num_ngram_requests := len(valid_ngram_requests):
original_num_numba_threads = get_num_threads()
# Ensure we use at least one thread.
# If total tokens is small, using multiple threads
# may slow down due to overhead.
total_tokens = np.sum(num_tokens_no_spec)
if total_tokens >= self.num_tokens_threshold:
final_num_threads = max(
1, min(self.num_numba_thread_available,
num_ngram_requests))
set_num_threads(final_num_threads)
else:
set_num_threads(1)
batch_propose_numba(valid_ngram_requests, num_tokens_no_spec,
token_ids_cpu, self.min_n, self.max_n,
self.max_model_len, self.k,
self.valid_ngram_draft,
self.valid_ngram_num_drafts)
# Restore original number of threads.
set_num_threads(original_num_numba_threads)
for i in range(num_requests):
if i in valid_ngram_requests and \
self.valid_ngram_num_drafts[i] > 0:
draft_token_ids.append(self.valid_ngram_draft[
i, :self.valid_ngram_num_drafts[i]].tolist())
else:
draft_token_ids.append([])
return draft_token_ids
def propose( def propose(
self, self,
context_token_ids: np.ndarray, sampled_token_ids: list[list[int]],
) -> Optional[np.ndarray]: req_ids: list[str],
"""Proposes the next sequence of tokens based on n-gram pattern num_tokens_no_spec: np.ndarray,
matching in the context. The function finds matches of the last n token_ids_cpu: np.ndarray,
tokens in the previous context, and returns k tokens that followed spec_decode_unsupported_reqs: set,
that match. ) -> list[list[int]]:
Args:
context_token_ids: Numpy array of token IDs representing the
context sequence.
Returns: # find which requests need ngram proposals
np.ndarray: The sequence of tokens that followed valid_ngram_requests = []
the matched n-gram in the context. for i, sampled_ids in enumerate(sampled_token_ids):
None: If no matching n-gram pattern is found. num_sampled_ids = len(sampled_ids)
if not num_sampled_ids:
# Skip speculative decoding.
continue
Example: # Skip requests that require sampling parameters that are not
If context_token_ids = [1,2,3,4,2,3], min_n = 2, max_n = 3, and # supported with speculative decoding.
k = 4: req_id = req_ids[i]
- The last 3 (= max_n) tokens [4,2,3] cannot find a match. if req_id in spec_decode_unsupported_reqs:
- The last 2 tokens [2,3] will be matched against the previous continue
4 tokens [1,2,3,4].
- Finding a match of [2,3] would return the tokens that num_tokens = num_tokens_no_spec[i]
followed that pattern. Here we will return [4,2,3] because if num_tokens >= self.max_model_len:
we only have three tokens after the match. # Skip requests that have already reached the max model length.
""" continue
# TODO(woosuk): Optimize this.
return _find_longest_matched_ngram_and_propose_tokens( valid_ngram_requests.append(i)
origin_tokens=context_token_ids,
min_ngram=self.min_n, draft_token_ids = self.batch_propose(
max_ngram=self.max_n, len(sampled_token_ids),
max_model_len=self.max_model_len, valid_ngram_requests,
k=self.k) num_tokens_no_spec,
token_ids_cpu,
)
return draft_token_ids
def load_model(self, *args, **kwargs): def load_model(self, *args, **kwargs):
# No model to load. # No model to load.
pass pass
@njit(parallel=True)
def batch_propose_numba(valid_ngram_requests: list,
num_tokens_no_spec: np.ndarray,
token_ids_cpu: np.ndarray, min_n: int, max_n: int,
max_model_len: int, k: int,
valid_ngram_draft: np.ndarray,
valid_ngram_num_drafts: np.ndarray):
for i in prange(len(valid_ngram_requests)):
idx = valid_ngram_requests[i]
num_tokens = num_tokens_no_spec[idx]
context_token_ids = token_ids_cpu[idx, :num_tokens]
drafter_output = _find_longest_matched_ngram_and_propose_tokens(
origin_tokens=context_token_ids,
min_ngram=min_n,
max_ngram=max_n,
max_model_len=max_model_len,
k=k)
valid_ngram_num_drafts[i] = drafter_output.shape[0]
if len(drafter_output):
valid_ngram_draft[i, :drafter_output.shape[0]] = drafter_output
@jit(nopython=True) @jit(nopython=True)
def _find_longest_matched_ngram_and_propose_tokens( def _find_longest_matched_ngram_and_propose_tokens(origin_tokens: np.ndarray,
origin_tokens: np.ndarray, min_ngram: int, max_ngram: int, min_ngram: int,
max_model_len: int, k: int) -> Optional[np.ndarray]: max_ngram: int,
max_model_len: int,
k: int) -> np.ndarray:
""" """
Find the longest n-gram which matches the suffix of the given tokens Find the longest n-gram which matches the suffix of the given tokens
whose length is within [min_ngram, max_ngram] (inclusive). whose length is within [min_ngram, max_ngram] (inclusive).
@ -84,12 +203,12 @@ def _find_longest_matched_ngram_and_propose_tokens(
# Do not generate draft tokens is context is shorter than minimum n-gram # Do not generate draft tokens is context is shorter than minimum n-gram
total_token = origin_tokens.shape[0] total_token = origin_tokens.shape[0]
if total_token < min_ngram: if total_token < min_ngram:
return None return np.empty((0, ), dtype=origin_tokens.dtype)
# Do not generate draft tokens beyond the max model length. # Do not generate draft tokens beyond the max model length.
k = min(k, max_model_len - total_token) k = min(k, max_model_len - total_token)
if k <= 0: if k <= 0:
return None return np.empty((0, ), dtype=origin_tokens.dtype)
# Flip tokens, and the goal become to find longest ngram # Flip tokens, and the goal become to find longest ngram
# on the rightmost position which matches the prefix with # on the rightmost position which matches the prefix with
@ -146,7 +265,7 @@ def _find_longest_matched_ngram_and_propose_tokens(
if longest_ngram < min_ngram: if longest_ngram < min_ngram:
# No valid ngram is found # No valid ngram is found
return None return np.empty((0, ), dtype=origin_tokens.dtype)
# Flip the position back, so in origin_tokens, # Flip the position back, so in origin_tokens,
# origin_tokens[total_token-1-position:total_token-1-position+longest_ngram] # origin_tokens[total_token-1-position:total_token-1-position+longest_ngram]

View File

@ -2404,8 +2404,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if self.speculative_config.method == "ngram": if self.speculative_config.method == "ngram":
assert isinstance(sampled_token_ids, list) assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, NgramProposer) assert isinstance(self.drafter, NgramProposer)
draft_token_ids = self.propose_ngram_draft_token_ids( draft_token_ids = self.drafter.propose(
sampled_token_ids) sampled_token_ids, self.input_batch.req_ids,
self.input_batch.num_tokens_no_spec,
self.input_batch.token_ids_cpu,
self.input_batch.spec_decode_unsupported_reqs)
elif self.speculative_config.method == "medusa": elif self.speculative_config.method == "medusa":
assert isinstance(sampled_token_ids, list) assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, MedusaProposer) assert isinstance(self.drafter, MedusaProposer)
@ -2515,41 +2518,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
) )
return draft_token_ids return draft_token_ids
def propose_ngram_draft_token_ids(
self,
sampled_token_ids: list[list[int]],
) -> list[list[int]]:
# TODO(woosuk): Optimize.
req_ids = self.input_batch.req_ids
draft_token_ids: list[list[int]] = []
for i, sampled_ids in enumerate(sampled_token_ids):
num_sampled_ids = len(sampled_ids)
if not num_sampled_ids:
# Skip speculative decoding.
draft_token_ids.append([])
continue
# Skip requests that require sampling parameters that are not
# supported with speculative decoding.
req_id = req_ids[i]
if req_id in self.input_batch.spec_decode_unsupported_reqs:
draft_token_ids.append([])
continue
num_tokens = self.input_batch.num_tokens_no_spec[i]
if num_tokens >= self.max_model_len:
# Skip requests that have already reached the max model length.
draft_token_ids.append([])
continue
drafter_output = self.drafter.propose(
self.input_batch.token_ids_cpu[i, :num_tokens])
if drafter_output is None or len(drafter_output) == 0:
draft_token_ids.append([])
else:
draft_token_ids.append(drafter_output.tolist())
return draft_token_ids
def update_config(self, overrides: dict[str, Any]) -> None: def update_config(self, overrides: dict[str, Any]) -> None:
allowed_config_names = {"load_config", "model_config"} allowed_config_names = {"load_config", "model_config"}
for config_name, config_overrides in overrides.items(): for config_name, config_overrides in overrides.items():