[Spec Decode] Add Batch Parallel Ngram. Upto 8x lower overhead. (#24986)

Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Ekagra Ranjan 2025-09-25 18:22:03 -04:00 committed by yewentao256
parent b558c3a8b7
commit f3a478b55e
5 changed files with 383 additions and 109 deletions

View File

@ -1,17 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import time
from unittest import mock
import numpy as np
from tabulate import tabulate
from benchmark_utils import TimeCollector
from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
from vllm.config import (
CacheConfig,
DeviceConfig,
LoadConfig,
ModelConfig,
ParallelConfig,
SchedulerConfig,
SpeculativeConfig,
VllmConfig,
)
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
def main(args):
def benchmark_propose(args):
rows = []
for max_ngram in args.max_ngram:
collector = TimeCollector(TimeCollector.US)
@ -69,10 +83,88 @@ def main(args):
)
def benchmark_batched_propose(args):
NUM_SPECULATIVE_TOKENS_NGRAM = 10
PROMPT_LOOKUP_MIN = 5
PROMPT_LOOKUP_MAX = 15
MAX_MODEL_LEN = int(1e7)
DEVICE = current_platform.device_type
model_config = ModelConfig(model="facebook/opt-125m", runner="generate")
speculative_config = SpeculativeConfig(
target_model_config=model_config,
target_parallel_config=ParallelConfig(),
method="ngram",
num_speculative_tokens=NUM_SPECULATIVE_TOKENS_NGRAM,
prompt_lookup_max=PROMPT_LOOKUP_MAX,
prompt_lookup_min=PROMPT_LOOKUP_MIN,
)
vllm_config = VllmConfig(
model_config=model_config,
cache_config=CacheConfig(),
speculative_config=speculative_config,
device_config=DeviceConfig(device=current_platform.device_type),
parallel_config=ParallelConfig(),
load_config=LoadConfig(),
scheduler_config=SchedulerConfig(),
)
# monkey patch vllm.v1.worker.gpu_model_runner.get_pp_group
mock_pp_group = mock.MagicMock()
mock_pp_group.world_size = 1
with mock.patch(
"vllm.v1.worker.gpu_model_runner.get_pp_group", return_value=mock_pp_group
):
runner = GPUModelRunner(vllm_config, DEVICE)
# hack max model len
runner.max_model_len = MAX_MODEL_LEN
runner.drafter.max_model_len = MAX_MODEL_LEN
dummy_input_batch = InputBatch(
max_num_reqs=args.num_req,
max_model_len=MAX_MODEL_LEN,
max_num_batched_tokens=args.num_req * args.num_token,
device=DEVICE,
pin_memory=False,
vocab_size=256000,
block_sizes=[16],
)
dummy_input_batch._req_ids = list(str(id) for id in range(args.num_req))
dummy_input_batch.spec_decode_unsupported_reqs = ()
dummy_input_batch.num_tokens_no_spec = [args.num_token] * args.num_req
dummy_input_batch.token_ids_cpu = np.random.randint(
0, 20, (args.num_req, args.num_token)
)
runner.input_batch = dummy_input_batch
sampled_token_ids = [[0]] * args.num_req
print("Starting benchmark")
# first run is warmup so ignore it
for _ in range(args.num_iteration):
start = time.time()
runner.drafter.propose(
sampled_token_ids,
dummy_input_batch.req_ids,
dummy_input_batch.num_tokens_no_spec,
dummy_input_batch.token_ids_cpu,
dummy_input_batch.spec_decode_unsupported_reqs,
)
end = time.time()
print(f"Iteration time (s): {end - start}")
def invoke_main() -> None:
parser = FlexibleArgumentParser(
description="Benchmark the performance of N-gram speculative decode drafting"
)
parser.add_argument(
"--batched", action="store_true", help="consider time to prepare batch"
) # noqa: E501
parser.add_argument(
"--num-iteration",
type=int,
@ -105,8 +197,17 @@ def invoke_main() -> None:
help="Number of speculative tokens to generate",
)
args = parser.parse_args()
main(args)
if not args.batched:
benchmark_propose(args)
else:
benchmark_batched_propose(args)
"""
# Example command lines:
# time python3 benchmarks/benchmark_ngram_proposer.py
# time python3 benchmarks/benchmark_ngram_proposer.py --batched --num-iteration 4 --num-token 1000000 --num-req 128
""" # noqa: E501
if __name__ == "__main__":
invoke_main() # pragma: no cover

View File

@ -9,11 +9,13 @@ from vllm.v1.spec_decode.ngram_proposer import (
def test_find_longest_matched_ngram_and_propose_tokens():
tokens = np.array([1, 2, 3, 4, 1, 2, 3, 5, 6])
assert _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
min_ngram=2,
max_ngram=2,
max_model_len=1024,
k=2) is None
result = _find_longest_matched_ngram_and_propose_tokens(
origin_tokens=tokens,
min_ngram=2,
max_ngram=2,
max_model_len=1024,
k=2)
assert len(result) == 0
tokens = np.array([1, 2, 3, 4, 1, 2, 3])
np.testing.assert_array_equal(
@ -62,7 +64,7 @@ def test_find_longest_matched_ngram_and_propose_tokens():
def test_ngram_proposer():
def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
# Dummy model config. Just to set max_model_len.
model_config = ModelConfig(model="facebook/opt-125m")
return NgramProposer(
@ -75,36 +77,120 @@ def test_ngram_proposer():
)))
# No match.
result = ngram_proposer(
min_n=2, max_n=2,
k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 5]))
assert result is None
token_ids_cpu = np.array([[1, 2, 3, 4, 5]])
result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert len(result[0]) == 0
# No match for 4-gram.
result = ngram_proposer(
min_n=4, max_n=4,
k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
assert result is None
token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
result = get_ngram_proposer(min_n=4, max_n=4, k=2).propose(
sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert len(result[0]) == 0
# No match for 4-gram but match for 3-gram.
result = ngram_proposer(
min_n=3, max_n=4,
k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
assert np.array_equal(result, np.array([4, 1]))
token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert np.array_equal(result, np.array([[4, 1]]))
# Match for both 4-gram and 3-gram.
# In this case, the proposer should return the 4-gram match.
result = ngram_proposer(min_n=3, max_n=4, k=2).propose(
context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]))
assert np.array_equal(result, np.array([1, 2])) # Not [5, 1]
token_ids_cpu = np.array([[2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]])
result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert np.array_equal(result, np.array([[1, 2]])) # Not [5, 1]]
# Match for 2-gram and 3-gram, but not 4-gram.
result = ngram_proposer(min_n=2, max_n=4, k=2).propose(
context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]))
assert np.array_equal(result, np.array([1, 2])) # Not [5, 2]
token_ids_cpu = np.array([[3, 4, 5, 2, 3, 4, 1, 2, 3, 4]])
result = get_ngram_proposer(min_n=2, max_n=4, k=2).propose(
sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert np.array_equal(result, np.array([[1, 2]])) # Not [5, 2]]
# Multiple 3-gram matched, but always pick the first one.
result = ngram_proposer(
min_n=3, max_n=3, k=2).propose(context_token_ids=np.array(
[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]))
assert np.array_equal(result, np.array([100, 1]))
token_ids_cpu = np.array(
[[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]])
result = get_ngram_proposer(min_n=3, max_n=3, k=2).propose(
sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert np.array_equal(result, np.array([[100, 1]]))
# check empty input
token_ids_cpu = np.array([[]])
result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert len(result[0]) == 0
# check multibatch input
# first request has 5 tokens and a match
# second request has 3 tokens and no match. Padded with -1 for max len 5
token_ids_cpu = np.array([[1, 2, 3, 1, 2], [4, 5, 6, -1, -1]])
result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
sampled_token_ids=[[0], [1]],
req_ids=["0", "1"],
num_tokens_no_spec=np.array([5, 3]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert len(result[0]) == 2
assert np.array_equal(result[0], np.array([3, 1]))
assert np.array_equal(result[1], np.array([]))
# test if 0 threads available: can happen if TP size > CPU count
ngram_proposer = get_ngram_proposer(min_n=2, max_n=2, k=2)
ngram_proposer.num_numba_thread_available = 0
# set max_model_len to 2 * threshold to ensure multithread is used
num_tokens_threshold = ngram_proposer.num_tokens_threshold
ngram_proposer.max_model_len = 2 * num_tokens_threshold
# using multibatch test
middle_integer = num_tokens_threshold // 2
input_1 = [_ for _ in range(num_tokens_threshold)]
input_1 += [middle_integer, middle_integer + 1]
input_2 = [-1] * len(input_1)
input_2[:3] = [4, 5, 6]
token_ids_cpu = np.array([input_1, input_2])
result = ngram_proposer.propose(
sampled_token_ids=[[0], [1]],
req_ids=["0", "1"],
num_tokens_no_spec=np.array([len(input_1), 3]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert len(result[0]) == 2
assert np.array_equal(result[0],
np.array([middle_integer + 2, middle_integer + 3]))
assert np.array_equal(result[1], np.array([]))

View File

@ -17,7 +17,7 @@ PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
GREEDY_TEMPERATURE: tl.constexpr = -1
# Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 32
MAX_SPEC_LEN = 128
class RejectionSampler(nn.Module):

View File

@ -1,9 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import os
import numpy as np
from numba import jit
from numba import get_num_threads, jit, njit, prange, set_num_threads
from vllm.config import VllmConfig
@ -26,55 +26,174 @@ class NgramProposer:
# Maximum length of the model.
self.max_model_len = vllm_config.model_config.max_model_len
# Pre-allocate buffers for numba batch propose.
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
self.valid_ngram_draft = np.zeros((max_num_seqs, self.k),
dtype=np.int32)
self.valid_ngram_num_drafts = np.zeros((max_num_seqs), dtype=np.int32)
# Threshold of total number of tokens in the batch to enable
# multi-threading in numba batch propose.
self.num_tokens_threshold = 8192
tp_size = vllm_config.parallel_config.tensor_parallel_size
cpu_count = os.cpu_count()
# Max number of threads for numba parallel processing.
if cpu_count:
# Divide by 2 to use physical cores
# and not logical cores (hyper-threading).
# Cap the number of threads to 8 to avoid using too many threads
# since other components like frontend (incl tokenization)
# and Structured Outputs also use multiple threads.
# TODO(ekagra-ranjan): bump up the cap from 1 to 8
# when TP parallelization for ngram is implemented.
self.num_numba_thread_available = min(1, (cpu_count // 2))
# Divide by tp_size to ensure each tensor parallel rank
# has some threads since all ranks will run this.
self.num_numba_thread_available //= tp_size
else:
self.num_numba_thread_available = 1
# Trigger Numba JIT compilation for N-gram proposer.
# This usually takes less than 1 second.
self.propose(np.zeros(1024, dtype=np.int32))
self.propose([[]] * 1024, [""] * 1024, np.zeros(1024, dtype=np.int32),
np.zeros((1024, self.max_model_len), dtype=np.int32),
set())
def batch_propose(
self,
num_requests: int,
valid_ngram_requests: list,
num_tokens_no_spec: np.ndarray,
token_ids_cpu: np.ndarray,
) -> list[list[int]]:
"""Batch version of ngram proposer using numba for acceleration.
Args:
valid_ngram_requests:
Set of indices of requests that need ngram proposals.
num_tokens_no_spec:
Numpy array of shape (batch_size,) representing the number
of tokens without speculative tokens for each request.
token_ids_cpu:
Numpy array of shape (batch_size, max_model_len)
representing the token IDs for each request.
Returns:
list[list[int]]:
A list where each element is a list of proposed
token IDs for the corresponding request.
"""
draft_token_ids: list[list[int]] = []
# Only run batch propose if there are requests needing ngram proposals.
# avoid calling numba function with empty list which causes error
# ValueError: cannot compute fingerprint of empty list
if num_ngram_requests := len(valid_ngram_requests):
original_num_numba_threads = get_num_threads()
# Ensure we use at least one thread.
# If total tokens is small, using multiple threads
# may slow down due to overhead.
total_tokens = np.sum(num_tokens_no_spec)
if total_tokens >= self.num_tokens_threshold:
final_num_threads = max(
1, min(self.num_numba_thread_available,
num_ngram_requests))
set_num_threads(final_num_threads)
else:
set_num_threads(1)
batch_propose_numba(valid_ngram_requests, num_tokens_no_spec,
token_ids_cpu, self.min_n, self.max_n,
self.max_model_len, self.k,
self.valid_ngram_draft,
self.valid_ngram_num_drafts)
# Restore original number of threads.
set_num_threads(original_num_numba_threads)
for i in range(num_requests):
if i in valid_ngram_requests and \
self.valid_ngram_num_drafts[i] > 0:
draft_token_ids.append(self.valid_ngram_draft[
i, :self.valid_ngram_num_drafts[i]].tolist())
else:
draft_token_ids.append([])
return draft_token_ids
def propose(
self,
context_token_ids: np.ndarray,
) -> Optional[np.ndarray]:
"""Proposes the next sequence of tokens based on n-gram pattern
matching in the context. The function finds matches of the last n
tokens in the previous context, and returns k tokens that followed
that match.
Args:
context_token_ids: Numpy array of token IDs representing the
context sequence.
sampled_token_ids: list[list[int]],
req_ids: list[str],
num_tokens_no_spec: np.ndarray,
token_ids_cpu: np.ndarray,
spec_decode_unsupported_reqs: set,
) -> list[list[int]]:
Returns:
np.ndarray: The sequence of tokens that followed
the matched n-gram in the context.
None: If no matching n-gram pattern is found.
# find which requests need ngram proposals
valid_ngram_requests = []
for i, sampled_ids in enumerate(sampled_token_ids):
num_sampled_ids = len(sampled_ids)
if not num_sampled_ids:
# Skip speculative decoding.
continue
Example:
If context_token_ids = [1,2,3,4,2,3], min_n = 2, max_n = 3, and
k = 4:
- The last 3 (= max_n) tokens [4,2,3] cannot find a match.
- The last 2 tokens [2,3] will be matched against the previous
4 tokens [1,2,3,4].
- Finding a match of [2,3] would return the tokens that
followed that pattern. Here we will return [4,2,3] because
we only have three tokens after the match.
"""
# TODO(woosuk): Optimize this.
return _find_longest_matched_ngram_and_propose_tokens(
origin_tokens=context_token_ids,
min_ngram=self.min_n,
max_ngram=self.max_n,
max_model_len=self.max_model_len,
k=self.k)
# Skip requests that require sampling parameters that are not
# supported with speculative decoding.
req_id = req_ids[i]
if req_id in spec_decode_unsupported_reqs:
continue
num_tokens = num_tokens_no_spec[i]
if num_tokens >= self.max_model_len:
# Skip requests that have already reached the max model length.
continue
valid_ngram_requests.append(i)
draft_token_ids = self.batch_propose(
len(sampled_token_ids),
valid_ngram_requests,
num_tokens_no_spec,
token_ids_cpu,
)
return draft_token_ids
def load_model(self, *args, **kwargs):
# No model to load.
pass
@njit(parallel=True)
def batch_propose_numba(valid_ngram_requests: list,
num_tokens_no_spec: np.ndarray,
token_ids_cpu: np.ndarray, min_n: int, max_n: int,
max_model_len: int, k: int,
valid_ngram_draft: np.ndarray,
valid_ngram_num_drafts: np.ndarray):
for i in prange(len(valid_ngram_requests)):
idx = valid_ngram_requests[i]
num_tokens = num_tokens_no_spec[idx]
context_token_ids = token_ids_cpu[idx, :num_tokens]
drafter_output = _find_longest_matched_ngram_and_propose_tokens(
origin_tokens=context_token_ids,
min_ngram=min_n,
max_ngram=max_n,
max_model_len=max_model_len,
k=k)
valid_ngram_num_drafts[i] = drafter_output.shape[0]
if len(drafter_output):
valid_ngram_draft[i, :drafter_output.shape[0]] = drafter_output
@jit(nopython=True)
def _find_longest_matched_ngram_and_propose_tokens(
origin_tokens: np.ndarray, min_ngram: int, max_ngram: int,
max_model_len: int, k: int) -> Optional[np.ndarray]:
def _find_longest_matched_ngram_and_propose_tokens(origin_tokens: np.ndarray,
min_ngram: int,
max_ngram: int,
max_model_len: int,
k: int) -> np.ndarray:
"""
Find the longest n-gram which matches the suffix of the given tokens
whose length is within [min_ngram, max_ngram] (inclusive).
@ -84,12 +203,12 @@ def _find_longest_matched_ngram_and_propose_tokens(
# Do not generate draft tokens is context is shorter than minimum n-gram
total_token = origin_tokens.shape[0]
if total_token < min_ngram:
return None
return np.empty((0, ), dtype=origin_tokens.dtype)
# Do not generate draft tokens beyond the max model length.
k = min(k, max_model_len - total_token)
if k <= 0:
return None
return np.empty((0, ), dtype=origin_tokens.dtype)
# Flip tokens, and the goal become to find longest ngram
# on the rightmost position which matches the prefix with
@ -146,7 +265,7 @@ def _find_longest_matched_ngram_and_propose_tokens(
if longest_ngram < min_ngram:
# No valid ngram is found
return None
return np.empty((0, ), dtype=origin_tokens.dtype)
# Flip the position back, so in origin_tokens,
# origin_tokens[total_token-1-position:total_token-1-position+longest_ngram]

View File

@ -2404,8 +2404,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if self.speculative_config.method == "ngram":
assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, NgramProposer)
draft_token_ids = self.propose_ngram_draft_token_ids(
sampled_token_ids)
draft_token_ids = self.drafter.propose(
sampled_token_ids, self.input_batch.req_ids,
self.input_batch.num_tokens_no_spec,
self.input_batch.token_ids_cpu,
self.input_batch.spec_decode_unsupported_reqs)
elif self.speculative_config.method == "medusa":
assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, MedusaProposer)
@ -2515,41 +2518,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
return draft_token_ids
def propose_ngram_draft_token_ids(
self,
sampled_token_ids: list[list[int]],
) -> list[list[int]]:
# TODO(woosuk): Optimize.
req_ids = self.input_batch.req_ids
draft_token_ids: list[list[int]] = []
for i, sampled_ids in enumerate(sampled_token_ids):
num_sampled_ids = len(sampled_ids)
if not num_sampled_ids:
# Skip speculative decoding.
draft_token_ids.append([])
continue
# Skip requests that require sampling parameters that are not
# supported with speculative decoding.
req_id = req_ids[i]
if req_id in self.input_batch.spec_decode_unsupported_reqs:
draft_token_ids.append([])
continue
num_tokens = self.input_batch.num_tokens_no_spec[i]
if num_tokens >= self.max_model_len:
# Skip requests that have already reached the max model length.
draft_token_ids.append([])
continue
drafter_output = self.drafter.propose(
self.input_batch.token_ids_cpu[i, :num_tokens])
if drafter_output is None or len(drafter_output) == 0:
draft_token_ids.append([])
else:
draft_token_ids.append(drafter_output.tolist())
return draft_token_ids
def update_config(self, overrides: dict[str, Any]) -> None:
allowed_config_names = {"load_config", "model_config"}
for config_name, config_overrides in overrides.items():