[Core] [N-gram SD Optimization][1/n] Propose tokens with a single KMP (#22437)

Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
Jialin Ouyang 2025-08-13 14:44:06 -07:00 committed by GitHub
parent 4e8614e88b
commit 31a500c86f
6 changed files with 388 additions and 206 deletions

View File

@@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
from tabulate import tabulate
from benchmark_utils import TimeCollector
from vllm.utils import FlexibleArgumentParser
from vllm.v1.core.block_pool import BlockPool
def main(args):
rows = []
for allocate_block in args.allocate_blocks:
# Force a GC collection beforehand to minimize interference between runs
gc.collect()
block_pool = BlockPool(num_gpu_blocks=args.num_gpu_blocks, enable_caching=True)
get_blocks_times = TimeCollector(TimeCollector.US)
free_blocks_times = TimeCollector(TimeCollector.US)
for _ in range(args.num_iteration):
with get_blocks_times:
blocks = block_pool.get_new_blocks(allocate_block)
with free_blocks_times:
block_pool.free_blocks(blocks)
rows.append(
[get_blocks_times.cnt, args.num_gpu_blocks, allocate_block]
+ get_blocks_times.dump_avg_max()
+ free_blocks_times.dump_avg_max()
)
print(
tabulate(
rows,
headers=[
"Iterations",
"Total\nBlocks",
"Allocated\nBlocks",
"Get Blocks\nAvg (us)",
"Get Blocks\nMax (us)",
"Free Blocks\nAvg (us)",
"Free Blocks\nMax (us)",
],
tablefmt="grid",
floatfmt=".3f",
)
)
def invoke_main() -> None:
parser = FlexibleArgumentParser(
description="Benchmark the performance of BlockPool for KV Cache."
)
parser.add_argument("--num-gpu-blocks", type=int, default=100000)
parser.add_argument(
"--num-iteration",
type=int,
default=1000,
help="Number of iterations to run to stablize final data readings",
)
parser.add_argument(
"--allocate-blocks",
type=int,
nargs="*",
default=[10, 50, 100, 500, 1000],
help="Number of blocks to allocate",
)
args = parser.parse_args()
main(args)
if __name__ == "__main__":
invoke_main() # pragma: no cover

View File

@@ -0,0 +1,112 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import numpy as np
from tabulate import tabulate
from benchmark_utils import TimeCollector
from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
from vllm.utils import FlexibleArgumentParser
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
def main(args):
rows = []
for max_ngram in args.max_ngram:
collector = TimeCollector(TimeCollector.US)
model_config = ModelConfig(
model="facebook/opt-125m",
task="generate",
max_model_len=args.num_token + args.num_spec_token,
tokenizer="facebook/opt-125m",
tokenizer_mode="auto",
dtype="auto",
seed=None,
trust_remote_code=False,
)
proposer = NgramProposer(
vllm_config=VllmConfig(
model_config=model_config,
speculative_config=SpeculativeConfig(
prompt_lookup_min=args.min_ngram,
prompt_lookup_max=max_ngram,
num_speculative_tokens=args.num_spec_token,
method="ngram",
),
)
)
# Warm up
proposer.propose(np.random.randint(0, 20, (args.num_token,)))
gc.collect()
for _ in range(args.num_iteration):
tokens = np.random.randint(0, 20, (args.num_req, args.num_token))
with collector:
for i in range(args.num_req):
proposer.propose(tokens[i, :])
rows.append(
[args.num_req, args.num_token, args.min_ngram, max_ngram]
+ collector.dump_avg_max()
)
print(
tabulate(
rows,
headers=[
"# Request",
"# Token",
"Min Ngram",
"Max Ngram",
"Avg (us)",
"Max (us)",
],
tablefmt="grid",
floatfmt=".3f",
)
)
def invoke_main() -> None:
parser = FlexibleArgumentParser(
description="Benchmark the performance of N-gram speculative decode drafting"
)
parser.add_argument(
"--num-iteration",
type=int,
default=100,
help="Number of iterations to run to stablize final data readings",
)
parser.add_argument(
"--num-req", type=int, default=128, help="Number of requests in the batch"
)
parser.add_argument(
"--num-token", type=int, default=1500, help="Number of tokens for each request"
)
parser.add_argument(
"--min-ngram",
type=int,
default=3,
help="Minimum n-gram to match",
)
parser.add_argument(
"--max-ngram",
type=int,
nargs="*",
default=[5, 7, 10, 15, 20],
help="Maximum n-gram to match",
)
parser.add_argument(
"--num-spec-token",
type=int,
default=3,
help="Number of speculative tokens to generate",
)
args = parser.parse_args()
main(args)
if __name__ == "__main__":
invoke_main() # pragma: no cover

View File

@@ -1,11 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import json
import math
import os
from typing import Any
import time
from types import TracebackType
from typing import Any, Optional, Union
def convert_to_pytorch_benchmark_format(
@@ -72,3 +73,53 @@ def write_to_json(filename: str, records: list) -> None:
cls=InfEncoder,
default=lambda o: f"<{type(o).__name__} object is not JSON serializable>",
)
# Collect elapsed times and generate aggregate time metrics
#
# Example Usage:
# collector = TimeCollector(TimeCollector.US)
# for _ in range(total_iteration):
# with collector:
# ...
# collector.dump_avg_max()
class TimeCollector:
NS: int = 1
US: int = NS * 1000
MS: int = US * 1000
S: int = MS * 1000
def __init__(self, scale: int) -> None:
self.cnt: int = 0
self._sum: int = 0
self._max: Optional[int] = None
self.scale = scale
self.start_time: int = time.monotonic_ns()
def collect(self, v: int) -> None:
self.cnt += 1
self._sum += v
if self._max is None:
self._max = v
else:
self._max = max(self._max, v)
def avg(self) -> Union[float, str]:
return self._sum * 1.0 / self.cnt / self.scale if self.cnt > 0 else "N/A"
def max(self) -> Union[float, str]:
return self._max / self.scale if self._max is not None else "N/A"
def dump_avg_max(self) -> list[Union[float, str]]:
return [self.avg(), self.max()]
def __enter__(self) -> None:
self.start_time = time.monotonic_ns()
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_value: Optional[BaseException],
exc_traceback: Optional[TracebackType],
) -> None:
self.collect(time.monotonic_ns() - self.start_time)
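For reference, a minimal usage sketch of the TimeCollector above (assuming it is imported from benchmark_utils, as the benchmark scripts above do; time.sleep stands in for the code under measurement):

import time
from benchmark_utils import TimeCollector

collector = TimeCollector(TimeCollector.US)
for _ in range(3):
    with collector:
        time.sleep(0.001)  # stand-in for the operation being timed
avg_us, max_us = collector.dump_avg_max()  # [avg, max] in microseconds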

View File

@@ -1,108 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import time
from typing import Optional
from tabulate import tabulate
from vllm.utils import FlexibleArgumentParser
from vllm.v1.core.block_pool import BlockPool
class Metric:
def __init__(self) -> None:
self.cnt: int = 0
self.sum_v: int = 0
self.max_v: Optional[int] = None
def update(self, v: int) -> None:
self.cnt += 1
self.sum_v += v
if self.max_v is None:
self.max_v = v
else:
self.max_v = max(self.max_v, v)
def avg_v(self) -> float:
return self.sum_v * 1.0 / self.cnt
def main(args):
rows = []
for allocate_block in args.allocate_blocks:
# Enforce a GC collect ahead to minimize the impact among runs
gc.collect()
block_pool = BlockPool(num_gpu_blocks=args.num_gpu_blocks, enable_caching=True)
get_blocks_metric: Metric = Metric()
free_blocks_metric: Metric = Metric()
for _ in range(args.num_iteration):
t1 = time.monotonic_ns()
blocks = block_pool.get_new_blocks(allocate_block)
t2 = time.monotonic_ns()
block_pool.free_blocks(blocks)
t3 = time.monotonic_ns()
get_blocks_metric.update(t2 - t1)
free_blocks_metric.update(t3 - t2)
if get_blocks_metric.max_v is not None and free_blocks_metric.max_v is not None:
rows.append(
[
get_blocks_metric.cnt,
args.num_gpu_blocks,
allocate_block,
get_blocks_metric.avg_v() / 1000000,
get_blocks_metric.max_v / 1000000.0,
free_blocks_metric.avg_v() / 1000000,
free_blocks_metric.max_v / 1000000.0,
]
)
else:
print(
"No valid metrics found."
f" {get_blocks_metric.max_v=} {free_blocks_metric.max_v=}"
)
print(
tabulate(
rows,
headers=[
"Iterations",
"Total\nBlocks",
"Allocated\nBlocks",
"Get Blocks\nAvg (ms)",
"Get Blocks\nMax (ms)",
"Free Blocks\nAvg (ms)",
"Free Blocks\nMax (ms)",
],
tablefmt="grid",
floatfmt=".6f",
)
)
def invoke_main() -> None:
parser = FlexibleArgumentParser(
description="Benchmark the performance of BlockPool for KV Cache."
)
parser.add_argument("--num-gpu-blocks", type=int, default=100000)
parser.add_argument(
"--num-iteration",
type=int,
default=1000,
help="Number of iterations to run to stablize final data readings",
)
parser.add_argument(
"--allocate-blocks",
type=int,
nargs="*",
default=[10, 50, 100, 500, 1000],
help="Number of blocks to allocate",
)
args = parser.parse_args()
main(args)
if __name__ == "__main__":
invoke_main() # pragma: no cover

View File

@@ -1,43 +1,63 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
from vllm.v1.spec_decode.ngram_proposer import (NgramProposer,
_find_subarray_kmp,
_kmp_lps_array)
from vllm.v1.spec_decode.ngram_proposer import (
NgramProposer, _find_longest_matched_ngram_and_propose_tokens)
def test_kmp_lps_array():
np.testing.assert_array_equal(_kmp_lps_array(np.array([])), np.array([]))
np.testing.assert_array_equal(_kmp_lps_array(np.array([1])), np.array([0]))
np.testing.assert_array_equal(_kmp_lps_array(np.array([1, 1, 1])),
np.array([0, 1, 2]))
np.testing.assert_array_equal(_kmp_lps_array(np.array([1, 2, 3, 4])),
np.array([0, 0, 0, 0]))
np.testing.assert_array_equal(_kmp_lps_array(np.array([1, 2, 1, 2, 3])),
np.array([0, 0, 1, 2, 0]))
def test_find_longest_matched_ngram_and_propose_tokens():
tokens = np.array([1, 2, 3, 4, 1, 2, 3, 5, 6])
assert _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
min_ngram=2,
max_ngram=2,
max_model_len=1024,
k=2) is None
tokens = np.array([1, 2, 3, 4, 1, 2, 3])
np.testing.assert_array_equal(
_find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
min_ngram=2,
max_ngram=2,
max_model_len=1024,
k=3),
np.array([4, 1, 2]))
np.testing.assert_array_equal(
_find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
min_ngram=2,
max_ngram=2,
max_model_len=1024,
k=2), np.array([4, 1]))
np.testing.assert_array_equal(
_find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
min_ngram=1,
max_ngram=1,
max_model_len=1024,
k=3),
np.array([4, 1, 2]))
np.testing.assert_array_equal(
_find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
min_ngram=1,
max_ngram=1,
max_model_len=1024,
k=2), np.array([4, 1]))
def test_find_subarray_kmp():
X = np.array([1, 2, 3, 4, 1, 2, 3, 5, 6])
assert _find_subarray_kmp(X, 2, 2) is None
X = np.array([1, 2, 3, 4, 1, 2, 3])
np.testing.assert_array_equal(_find_subarray_kmp(X, 2, 3),
np.array([4, 1, 2]))
np.testing.assert_array_equal(_find_subarray_kmp(X, 2, 2), np.array([4,
1]))
np.testing.assert_array_equal(_find_subarray_kmp(X, 1, 3),
np.array([4, 1, 2]))
np.testing.assert_array_equal(_find_subarray_kmp(X, 1, 2), np.array([4,
1]))
X = np.array([1, 3, 6, 2, 3, 4, 1, 2, 3])
np.testing.assert_array_equal(_find_subarray_kmp(X, 2, 3),
np.array([4, 1, 2]))
tokens = np.array([1, 3, 6, 2, 3, 4, 1, 2, 3])
np.testing.assert_array_equal(
_find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
min_ngram=2,
max_ngram=2,
max_model_len=1024,
k=3),
np.array([4, 1, 2]))
# Return on the first match
np.testing.assert_array_equal(_find_subarray_kmp(X, 1, 3),
np.array([6, 2, 3]))
np.testing.assert_array_equal(
_find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
min_ngram=1,
max_ngram=1,
max_model_len=1024,
k=2), np.array([6, 2]))
def test_ngram_proposer():
@@ -56,27 +76,35 @@ def test_ngram_proposer():
# No match.
result = ngram_proposer(
2, 2, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 5]))
min_n=2, max_n=2,
k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 5]))
assert result is None
# No match for 4-gram.
result = ngram_proposer(
4, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
min_n=4, max_n=4,
k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
assert result is None
# No match for 4-gram but match for 3-gram.
result = ngram_proposer(
3, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
min_n=3, max_n=4,
k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
assert np.array_equal(result, np.array([4, 1]))
# Match for both 4-gram and 3-gram.
# In this case, the proposer should return the 4-gram match.
result = ngram_proposer(3, 4, 2).propose(
result = ngram_proposer(min_n=3, max_n=4, k=2).propose(
context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]))
assert np.array_equal(result, np.array([1, 2])) # Not [5, 1]
# Match for 2-gram and 3-gram, but not 4-gram.
result = ngram_proposer(
2, 4,
2).propose(context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]))
result = ngram_proposer(min_n=2, max_n=4, k=2).propose(
context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]))
assert np.array_equal(result, np.array([1, 2])) # Not [5, 2]
# Multiple 3-grams matched; always pick the first one.
result = ngram_proposer(
min_n=3, max_n=3, k=2).propose(context_token_ids=np.array(
[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]))
assert np.array_equal(result, np.array([100, 1]))

View File

@@ -11,6 +11,10 @@ from vllm.config import VllmConfig
class NgramProposer:
def __init__(self, vllm_config: VllmConfig):
assert vllm_config.speculative_config is not None
assert vllm_config.speculative_config.prompt_lookup_min is not None
assert vllm_config.speculative_config.prompt_lookup_max is not None
# Minimum length of the n-gram to match.
self.min_n = vllm_config.speculative_config.prompt_lookup_min
# Maximum length of the n-gram to match.
@@ -54,17 +58,13 @@ class NgramProposer:
followed that pattern. Here we will return [4,2,3] because
we only have three tokens after the match.
"""
# Do not generate draft tokens beyond the max model length.
k = min(self.k, self.max_model_len - context_token_ids.shape[0])
if k <= 0:
return None
# TODO(woosuk): Optimize this.
for n in range(self.max_n, self.min_n - 1, -1):
result = _find_subarray_kmp(context_token_ids, n, k)
if result is not None:
return result
return None
return _find_longest_matched_ngram_and_propose_tokens(
origin_tokens=context_token_ids,
min_ngram=self.min_n,
max_ngram=self.max_n,
max_model_len=self.max_model_len,
k=self.k)
def load_model(self, *args, **kwargs):
# No model to load.
@@ -72,61 +72,86 @@ class NgramProposer:
@jit(nopython=True)
def _kmp_lps_array(pattern: np.ndarray) -> np.ndarray:
def _find_longest_matched_ngram_and_propose_tokens(
origin_tokens: np.ndarray, min_ngram: int, max_ngram: int,
max_model_len: int, k: int) -> Optional[np.ndarray]:
"""
Build the lps (longest proper prefix which is also suffix)
array for the pattern.
Find the longest n-gram which matches the suffix of the given tokens
whose length is within [min_ngram, max_ngram] (inclusive).
If found, we extract the k tokens that immediately follow the matched ngram.
"""
lps = np.zeros(len(pattern), dtype=np.int32)
prev_lps = 0 # length of the previous longest prefix suffix
# Do not generate draft tokens if the context is shorter than the minimum n-gram
total_token = origin_tokens.shape[0]
if total_token < min_ngram:
return None
# Do not generate draft tokens beyond the max model length.
k = min(k, max_model_len - total_token)
if k <= 0:
return None
# Flip the tokens: the goal becomes finding the longest ngram at the
# rightmost position that matches a prefix of length within
# [min_ngram, max_ngram] (inclusive).
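# For illustration: origin_tokens [1, 2, 3, 4, 1, 2, 3] (from the test
# file above) reverses to [3, 2, 1, 4, 3, 2, 1]; the original suffix
# [1, 2, 3] becomes the prefix [3, 2, 1], so a suffix match in the
# original is a prefix match at a later index here.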
tokens = origin_tokens[::-1]
# lps[i] = length of the longest proper prefix of tokens[:i+1] that is
# also a suffix of tokens[:i+1]:
# lps[i] = max{v : tokens[0:v] == tokens[i+1-v:i+1]}
#
# As the ngram length is capped by max_ngram to save memory, we only
# need to store lps for the first max_ngram positions.
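# For illustration (values from the removed test_kmp_lps_array case):
# for tokens [1, 2, 1, 2, 3] the full lps array is [0, 0, 1, 2, 0];
# with max_ngram = 2, only the first two entries [0, 0] are kept.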
lps = np.zeros(max_ngram, dtype=np.int32)
longest_ngram = 0
position = 0
# lps[0] is always 0, so we start at index 1
prev_lps = 0
i = 1
while i < len(pattern):
if pattern[i] == pattern[prev_lps]:
while i < total_token:
# tokens[:prev_lps] is the longest prefix that is also a suffix of tokens[:i]
if tokens[prev_lps] == tokens[i]:
# Token match: tokens[:prev_lps+1] is the longest prefix that is
# also a suffix of tokens[:i+1]
prev_lps += 1
lps[i] = prev_lps
# Check if we found a longer (or equally long) valid ngram.
#
# Update position even when prev_lps only ties longest_ngram, as we
# want the matched n-gram at the earliest position in the original
# tokens (i.e. the latest position in the reversed tokens).
if prev_lps >= longest_ngram:
longest_ngram = prev_lps
position = i
if i < max_ngram:
# Store lps only for the first max_ngram positions
lps[i] = prev_lps
if prev_lps == max_ngram:
# Once prev_lps reaches max_ngram, reset it to lps[max_ngram - 1]
# so that we never match an ngram longer than max_ngram
prev_lps = lps[max_ngram - 1]
i += 1
elif prev_lps != 0:
# Token mismatch: fall back to the next longest prefix that is
# also a suffix of tokens[:i], which is the longest such prefix
# of tokens[:prev_lps]
prev_lps = lps[prev_lps - 1]
else:
if prev_lps != 0:
prev_lps = lps[prev_lps - 1]
else:
lps[i] = 0
i += 1
return lps
@jit(nopython=True)
def _find_subarray_kmp(
context_token_ids: np.ndarray,
n: int,
k: int,
) -> Optional[np.ndarray]:
context_len = context_token_ids.shape[0]
assert n > 0
pattern = context_token_ids[-n:]
# Precompute lps array for Y
lps = _kmp_lps_array(pattern)
i = 0
j = 0
# -n because the last n tokens are used as pattern
while i < context_len - n:
if context_token_ids[i] == pattern[j]:
# Token mismatch, and no prefix (other than the empty string)
# remains as a suffix of tokens[:i]
i += 1
j += 1
# If we have matched the entire Y
if j == n:
# Found pattern in context, gather the next K elements
return context_token_ids[i:i + k]
else:
# Mismatch
if j != 0:
# Use the lps array to avoid re-checking elements
j = lps[j - 1]
else:
i += 1
if longest_ngram < min_ngram:
# No valid ngram is found
return None
# Y not found
return None
# Flip the position back: in origin_tokens,
# origin_tokens[total_token-1-position : total_token-1-position+longest_ngram]
# is the matched ngram, so we start drafting tokens from
# total_token-1-position+longest_ngram
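# Worked example (matching the test file above): for origin_tokens
# [1, 2, 3, 4, 1, 2, 3] with min_ngram = max_ngram = 2, the scan ends
# with longest_ngram = 2 at position = 5, so
# start_position = 7 - 1 - 5 + 2 = 3, and with k = 3 the proposal is
# origin_tokens[3:6] == [4, 1, 2].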
start_position = total_token - 1 - position + longest_ngram
k = min(k, total_token - start_position)
return origin_tokens[start_position:start_position + k]
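To make the single-pass approach easy to experiment with outside vLLM, here is a hedged pure-Python re-statement of the function above (propose_tokens is an illustrative name; numpy, numba, and the max_model_len clamp are omitted):

def propose_tokens(tokens, min_ngram, max_ngram, k):
    """Pure-Python sketch of the single-KMP proposal above (illustrative)."""
    total = len(tokens)
    if total < min_ngram or k <= 0:
        return None
    rev = tokens[::-1]  # suffix matching becomes prefix matching
    lps = [0] * max_ngram
    longest, position, prev, i = 0, 0, 0, 1
    while i < total:
        if rev[prev] == rev[i]:
            prev += 1
            if prev >= longest:  # ties prefer later reversed positions
                longest, position = prev, i
            if i < max_ngram:
                lps[i] = prev
            if prev == max_ngram:  # cap the match length at max_ngram
                prev = lps[max_ngram - 1]
            i += 1
        elif prev != 0:
            prev = lps[prev - 1]  # fall back to the next shorter prefix
        else:
            i += 1
    if longest < min_ngram:
        return None
    start = total - 1 - position + longest
    return tokens[start:start + min(k, total - start)]

# These mirror the unit tests above:
assert propose_tokens([1, 2, 3, 4, 1, 2, 3], 2, 2, 3) == [4, 1, 2]
assert propose_tokens([1, 2, 3, 4, 1, 2, 3, 5, 6], 2, 2, 2) is None
assert propose_tokens([1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3],
                      3, 3, 2) == [100, 1]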