[Core] [N-gram SD Optimization][1/n] Propose tokens with a single KMP (#22437)
Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
parent 4e8614e88b
commit 31a500c86f

benchmarks/benchmark_block_pool.py (new file, 74 lines)
@@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc

from tabulate import tabulate

from benchmark_utils import TimeCollector
from vllm.utils import FlexibleArgumentParser
from vllm.v1.core.block_pool import BlockPool


def main(args):
    rows = []
    for allocate_block in args.allocate_blocks:
        # Force a GC collection up front to minimize interference between runs
        gc.collect()
        block_pool = BlockPool(num_gpu_blocks=args.num_gpu_blocks, enable_caching=True)

        get_blocks_times = TimeCollector(TimeCollector.US)
        free_blocks_times = TimeCollector(TimeCollector.US)
        for _ in range(args.num_iteration):
            with get_blocks_times:
                blocks = block_pool.get_new_blocks(allocate_block)
            with free_blocks_times:
                block_pool.free_blocks(blocks)

        rows.append(
            [get_blocks_times.cnt, args.num_gpu_blocks, allocate_block]
            + get_blocks_times.dump_avg_max()
            + free_blocks_times.dump_avg_max()
        )

    print(
        tabulate(
            rows,
            headers=[
                "Iterations",
                "Total\nBlocks",
                "Allocated\nBlocks",
                "Get Blocks\nAvg (us)",
                "Get Blocks\nMax (us)",
                "Free Blocks\nAvg (us)",
                "Free Blocks\nMax (us)",
            ],
            tablefmt="grid",
            floatfmt=".3f",
        )
    )


def invoke_main() -> None:
    parser = FlexibleArgumentParser(
        description="Benchmark the performance of BlockPool for KV Cache."
    )
    parser.add_argument("--num-gpu-blocks", type=int, default=100000)
    parser.add_argument(
        "--num-iteration",
        type=int,
        default=1000,
        help="Number of iterations to run to stabilize final data readings",
    )
    parser.add_argument(
        "--allocate-blocks",
        type=int,
        nargs="*",
        default=[10, 50, 100, 500, 1000],
        help="Number of blocks to allocate",
    )
    args = parser.parse_args()
    main(args)


if __name__ == "__main__":
    invoke_main()  # pragma: no cover
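For orientation, here is a minimal sketch of the operations each timed iteration above exercises; it assumes BlockPool can be constructed standalone with caching enabled, exactly as the benchmark does, and the block counts are arbitrary:

# Minimal sketch (assumption: BlockPool is usable outside a full engine).
from vllm.v1.core.block_pool import BlockPool

pool = BlockPool(num_gpu_blocks=1024, enable_caching=True)
blocks = pool.get_new_blocks(16)  # allocate 16 KV-cache blocks from the pool
pool.free_blocks(blocks)          # return them so they can be reused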
benchmarks/benchmark_ngram_proposer.py (new file, 112 lines)

@@ -0,0 +1,112 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc

import numpy as np
from tabulate import tabulate

from benchmark_utils import TimeCollector
from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
from vllm.utils import FlexibleArgumentParser
from vllm.v1.spec_decode.ngram_proposer import NgramProposer


def main(args):
    rows = []
    for max_ngram in args.max_ngram:
        collector = TimeCollector(TimeCollector.US)

        model_config = ModelConfig(
            model="facebook/opt-125m",
            task="generate",
            max_model_len=args.num_token + args.num_spec_token,
            tokenizer="facebook/opt-125m",
            tokenizer_mode="auto",
            dtype="auto",
            seed=None,
            trust_remote_code=False,
        )
        proposer = NgramProposer(
            vllm_config=VllmConfig(
                model_config=model_config,
                speculative_config=SpeculativeConfig(
                    prompt_lookup_min=args.min_ngram,
                    prompt_lookup_max=max_ngram,
                    num_speculative_tokens=args.num_spec_token,
                    method="ngram",
                ),
            )
        )

        # Warm up
        proposer.propose(np.random.randint(0, 20, (args.num_token,)))

        gc.collect()
        for _ in range(args.num_iteration):
            tokens = np.random.randint(0, 20, (args.num_req, args.num_token))
            with collector:
                for i in range(args.num_req):
                    proposer.propose(tokens[i, :])
        rows.append(
            [args.num_req, args.num_token, args.min_ngram, max_ngram]
            + collector.dump_avg_max()
        )

    print(
        tabulate(
            rows,
            headers=[
                "# Request",
                "# Token",
                "Min Ngram",
                "Max Ngram",
                "Avg (us)",
                "Max (us)",
            ],
            tablefmt="grid",
            floatfmt=".3f",
        )
    )


def invoke_main() -> None:
    parser = FlexibleArgumentParser(
        description="Benchmark the performance of N-gram speculative decode drafting"
    )
    parser.add_argument(
        "--num-iteration",
        type=int,
        default=100,
        help="Number of iterations to run to stabilize final data readings",
    )
    parser.add_argument(
        "--num-req", type=int, default=128, help="Number of requests in the batch"
    )
    parser.add_argument(
        "--num-token", type=int, default=1500, help="Number of tokens for each request"
    )
    parser.add_argument(
        "--min-ngram",
        type=int,
        default=3,
        help="Minimum n-gram to match",
    )
    parser.add_argument(
        "--max-ngram",
        type=int,
        nargs="*",
        default=[5, 7, 10, 15, 20],
        help="Maximum n-gram to match",
    )
    parser.add_argument(
        "--num-spec-token",
        type=int,
        default=3,
        help="Number of speculative tokens to generate",
    )
    args = parser.parse_args()
    main(args)


if __name__ == "__main__":
    invoke_main()  # pragma: no cover
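The benchmark can also be driven programmatically. A hedged sketch using an argparse.Namespace whose fields mirror the flags defined in invoke_main(); the concrete values here are illustrative, not from the commit:

import argparse

args = argparse.Namespace(
    num_iteration=10,    # fewer iterations for a quick smoke run
    num_req=8,
    num_token=256,
    min_ngram=3,
    max_ngram=[5, 7],    # one table row per max n-gram value
    num_spec_token=3,
)
main(args)  # prints the same tabulate grid as the CLI entry point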
benchmarks/benchmark_utils.py

@@ -1,11 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import argparse
 import json
 import math
 import os
-from typing import Any
+import time
+from types import TracebackType
+from typing import Any, Optional, Union


 def convert_to_pytorch_benchmark_format(
@@ -72,3 +73,53 @@ def write_to_json(filename: str, records: list) -> None:
             cls=InfEncoder,
             default=lambda o: f"<{type(o).__name__} object is not JSON serializable>",
         )
+
+
+# Collects timing samples and reports aggregate metrics.
+#
+# Example usage:
+#     collector = TimeCollector(TimeCollector.US)
+#     for _ in range(total_iteration):
+#         with collector:
+#             ...
+#     collector.dump_avg_max()
+class TimeCollector:
+    NS: int = 1
+    US: int = NS * 1000
+    MS: int = US * 1000
+    S: int = MS * 1000
+
+    def __init__(self, scale: int) -> None:
+        self.cnt: int = 0
+        self._sum: int = 0
+        self._max: Optional[int] = None
+        self.scale = scale
+        self.start_time: int = time.monotonic_ns()
+
+    def collect(self, v: int) -> None:
+        self.cnt += 1
+        self._sum += v
+        if self._max is None:
+            self._max = v
+        else:
+            self._max = max(self._max, v)
+
+    def avg(self) -> Union[float, str]:
+        return self._sum * 1.0 / self.cnt / self.scale if self.cnt > 0 else "N/A"
+
+    def max(self) -> Union[float, str]:
+        return self._max / self.scale if self._max else "N/A"
+
+    def dump_avg_max(self) -> list[Union[float, str]]:
+        return [self.avg(), self.max()]
+
+    def __enter__(self) -> None:
+        self.start_time = time.monotonic_ns()
+
+    def __exit__(
+        self,
+        exc_type: Optional[type[BaseException]],
+        exc_value: Optional[BaseException],
+        exc_traceback: Optional[TracebackType],
+    ) -> None:
+        self.collect(time.monotonic_ns() - self.start_time)
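A self-contained sketch of the TimeCollector pattern the benchmarks rely on; the timed body is an arbitrary placeholder workload:

import time
from benchmark_utils import TimeCollector

collector = TimeCollector(TimeCollector.US)
for _ in range(100):
    with collector:              # __enter__ records time.monotonic_ns()
        sum(range(10_000))       # placeholder workload being timed
print(collector.dump_avg_max())  # -> [avg_us, max_us] (or "N/A" if empty)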
(deleted file: the earlier BlockPool benchmark, 108 lines)

@@ -1,108 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import time
from typing import Optional

from tabulate import tabulate

from vllm.utils import FlexibleArgumentParser
from vllm.v1.core.block_pool import BlockPool


class Metric:
    def __init__(self) -> None:
        self.cnt: int = 0
        self.sum_v: int = 0
        self.max_v: Optional[int] = None

    def update(self, v: int) -> None:
        self.cnt += 1
        self.sum_v += v
        if self.max_v is None:
            self.max_v = v
        else:
            self.max_v = max(self.max_v, v)

    def avg_v(self) -> float:
        return self.sum_v * 1.0 / self.cnt


def main(args):
    rows = []
    for allocate_block in args.allocate_blocks:
        # Force a GC collection up front to minimize interference between runs
        gc.collect()
        block_pool = BlockPool(num_gpu_blocks=args.num_gpu_blocks, enable_caching=True)

        get_blocks_metric: Metric = Metric()
        free_blocks_metric: Metric = Metric()
        for _ in range(args.num_iteration):
            t1 = time.monotonic_ns()
            blocks = block_pool.get_new_blocks(allocate_block)
            t2 = time.monotonic_ns()
            block_pool.free_blocks(blocks)
            t3 = time.monotonic_ns()
            get_blocks_metric.update(t2 - t1)
            free_blocks_metric.update(t3 - t2)

        if get_blocks_metric.max_v is not None and free_blocks_metric.max_v is not None:
            rows.append(
                [
                    get_blocks_metric.cnt,
                    args.num_gpu_blocks,
                    allocate_block,
                    get_blocks_metric.avg_v() / 1000000,
                    get_blocks_metric.max_v / 1000000.0,
                    free_blocks_metric.avg_v() / 1000000,
                    free_blocks_metric.max_v / 1000000.0,
                ]
            )
        else:
            print(
                "No valid metrics found."
                f" {get_blocks_metric.max_v=} {free_blocks_metric.max_v=}"
            )

    print(
        tabulate(
            rows,
            headers=[
                "Iterations",
                "Total\nBlocks",
                "Allocated\nBlocks",
                "Get Blocks\nAvg (ms)",
                "Get Blocks\nMax (ms)",
                "Free Blocks\nAvg (ms)",
                "Free Blocks\nMax (ms)",
            ],
            tablefmt="grid",
            floatfmt=".6f",
        )
    )


def invoke_main() -> None:
    parser = FlexibleArgumentParser(
        description="Benchmark the performance of BlockPool for KV Cache."
    )
    parser.add_argument("--num-gpu-blocks", type=int, default=100000)
    parser.add_argument(
        "--num-iteration",
        type=int,
        default=1000,
        help="Number of iterations to run to stabilize final data readings",
    )
    parser.add_argument(
        "--allocate-blocks",
        type=int,
        nargs="*",
        default=[10, 50, 100, 500, 1000],
        help="Number of blocks to allocate",
    )
    args = parser.parse_args()
    main(args)


if __name__ == "__main__":
    invoke_main()  # pragma: no cover
(modified: n-gram proposer unit tests)

@@ -1,43 +1,63 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import numpy as np

 from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
-from vllm.v1.spec_decode.ngram_proposer import (NgramProposer,
-                                                _find_subarray_kmp,
-                                                _kmp_lps_array)
+from vllm.v1.spec_decode.ngram_proposer import (
+    NgramProposer, _find_longest_matched_ngram_and_propose_tokens)


-def test_kmp_lps_array():
-    np.testing.assert_array_equal(_kmp_lps_array(np.array([])), np.array([]))
-    np.testing.assert_array_equal(_kmp_lps_array(np.array([1])), np.array([0]))
-    np.testing.assert_array_equal(_kmp_lps_array(np.array([1, 1, 1])),
-                                  np.array([0, 1, 2]))
-    np.testing.assert_array_equal(_kmp_lps_array(np.array([1, 2, 3, 4])),
-                                  np.array([0, 0, 0, 0]))
-    np.testing.assert_array_equal(_kmp_lps_array(np.array([1, 2, 1, 2, 3])),
-                                  np.array([0, 0, 1, 2, 0]))
-
-
-def test_find_subarray_kmp():
-    X = np.array([1, 2, 3, 4, 1, 2, 3, 5, 6])
-    assert _find_subarray_kmp(X, 2, 2) is None
-    X = np.array([1, 2, 3, 4, 1, 2, 3])
-    np.testing.assert_array_equal(_find_subarray_kmp(X, 2, 3),
-                                  np.array([4, 1, 2]))
-    np.testing.assert_array_equal(_find_subarray_kmp(X, 2, 2), np.array([4,
-                                                                         1]))
-    np.testing.assert_array_equal(_find_subarray_kmp(X, 1, 3),
-                                  np.array([4, 1, 2]))
-    np.testing.assert_array_equal(_find_subarray_kmp(X, 1, 2), np.array([4,
-                                                                         1]))
-    X = np.array([1, 3, 6, 2, 3, 4, 1, 2, 3])
-    np.testing.assert_array_equal(_find_subarray_kmp(X, 2, 3),
-                                  np.array([4, 1, 2]))
-    # Return on the first match
-    np.testing.assert_array_equal(_find_subarray_kmp(X, 1, 3),
-                                  np.array([6, 2, 3]))
+def test_find_longest_matched_ngram_and_propose_tokens():
+    tokens = np.array([1, 2, 3, 4, 1, 2, 3, 5, 6])
+    assert _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
+                                                          min_ngram=2,
+                                                          max_ngram=2,
+                                                          max_model_len=1024,
+                                                          k=2) is None
+
+    tokens = np.array([1, 2, 3, 4, 1, 2, 3])
+    np.testing.assert_array_equal(
+        _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
+                                                       min_ngram=2,
+                                                       max_ngram=2,
+                                                       max_model_len=1024,
+                                                       k=3),
+        np.array([4, 1, 2]))
+    np.testing.assert_array_equal(
+        _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
+                                                       min_ngram=2,
+                                                       max_ngram=2,
+                                                       max_model_len=1024,
+                                                       k=2), np.array([4, 1]))
+    np.testing.assert_array_equal(
+        _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
+                                                       min_ngram=1,
+                                                       max_ngram=1,
+                                                       max_model_len=1024,
+                                                       k=3),
+        np.array([4, 1, 2]))
+    np.testing.assert_array_equal(
+        _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
+                                                       min_ngram=1,
+                                                       max_ngram=1,
+                                                       max_model_len=1024,
+                                                       k=2), np.array([4, 1]))
+
+    tokens = np.array([1, 3, 6, 2, 3, 4, 1, 2, 3])
+    np.testing.assert_array_equal(
+        _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
+                                                       min_ngram=2,
+                                                       max_ngram=2,
+                                                       max_model_len=1024,
+                                                       k=3),
+        np.array([4, 1, 2]))
+    # Return on the first match
+    np.testing.assert_array_equal(
+        _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
+                                                       min_ngram=1,
+                                                       max_ngram=1,
+                                                       max_model_len=1024,
+                                                       k=2), np.array([6, 2]))


 def test_ngram_proposer():
@@ -56,27 +76,35 @@ def test_ngram_proposer():

     # No match.
     result = ngram_proposer(
-        2, 2, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 5]))
+        min_n=2, max_n=2,
+        k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 5]))
     assert result is None

     # No match for 4-gram.
     result = ngram_proposer(
-        4, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
+        min_n=4, max_n=4,
+        k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
     assert result is None

     # No match for 4-gram but match for 3-gram.
     result = ngram_proposer(
-        3, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
+        min_n=3, max_n=4,
+        k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
     assert np.array_equal(result, np.array([4, 1]))

     # Match for both 4-gram and 3-gram.
     # In this case, the proposer should return the 4-gram match.
-    result = ngram_proposer(3, 4, 2).propose(
+    result = ngram_proposer(min_n=3, max_n=4, k=2).propose(
         context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]))
     assert np.array_equal(result, np.array([1, 2]))  # Not [5, 1]

     # Match for 2-gram and 3-gram, but not 4-gram.
-    result = ngram_proposer(
-        2, 4,
-        2).propose(context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]))
+    result = ngram_proposer(min_n=2, max_n=4, k=2).propose(
+        context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]))
     assert np.array_equal(result, np.array([1, 2]))  # Not [5, 2]
+
+    # Multiple 3-grams match; always pick the first one.
+    result = ngram_proposer(
+        min_n=3, max_n=3, k=2).propose(context_token_ids=np.array(
+            [1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]))
+    assert np.array_equal(result, np.array([100, 1]))
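For intuition, here is a naive O(n * max_ngram) sketch with the same matching semantics these tests encode: the longest n-gram wins, and ties resolve to the earliest occurrence in the token sequence. The name naive_propose is hypothetical and the max_model_len cap is omitted for brevity; this is not the vLLM implementation.

from typing import Optional

import numpy as np


def naive_propose(tokens: np.ndarray, min_ngram: int, max_ngram: int,
                  k: int) -> Optional[np.ndarray]:
    n = len(tokens)
    # Try the longest n-gram first, mirroring the proposer's preference.
    for size in range(min(max_ngram, n - 1), min_ngram - 1, -1):
        suffix = tokens[n - size:]
        # Scan left to right so ties resolve to the earliest occurrence.
        for start in range(n - size):
            if np.array_equal(tokens[start:start + size], suffix):
                end = start + size
                return tokens[end:end + k]
    return None


assert naive_propose(np.array([1, 2, 3, 4, 1, 2, 3]), 2, 2, 3).tolist() == [4, 1, 2]
assert naive_propose(np.array([1, 2, 3, 4, 5]), 2, 2, 2) is None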
vllm/v1/spec_decode/ngram_proposer.py

@@ -11,6 +11,10 @@ from vllm.config import VllmConfig
 class NgramProposer:

     def __init__(self, vllm_config: VllmConfig):
+        assert vllm_config.speculative_config is not None
+        assert vllm_config.speculative_config.prompt_lookup_min is not None
+        assert vllm_config.speculative_config.prompt_lookup_max is not None
+
         # Minimum length of the n-gram to match.
         self.min_n = vllm_config.speculative_config.prompt_lookup_min
         # Maximum length of the n-gram to match.
@@ -54,17 +58,13 @@ class NgramProposer:
         followed that pattern. Here we will return [4,2,3] because
         we only have three tokens after the match.
         """
-        # Do not generate draft tokens beyond the max model length.
-        k = min(self.k, self.max_model_len - context_token_ids.shape[0])
-        if k <= 0:
-            return None
-
         # TODO(woosuk): Optimize this.
-        for n in range(self.max_n, self.min_n - 1, -1):
-            result = _find_subarray_kmp(context_token_ids, n, k)
-            if result is not None:
-                return result
-        return None
+        return _find_longest_matched_ngram_and_propose_tokens(
+            origin_tokens=context_token_ids,
+            min_ngram=self.min_n,
+            max_ngram=self.max_n,
+            max_model_len=self.max_model_len,
+            k=self.k)

     def load_model(self, *args, **kwargs):
         # No model to load.
@@ -72,61 +72,86 @@ class NgramProposer:


 @jit(nopython=True)
-def _kmp_lps_array(pattern: np.ndarray) -> np.ndarray:
-    """
-    Build the lps (longest proper prefix which is also a suffix)
-    array for the pattern.
-    """
-    lps = np.zeros(len(pattern), dtype=np.int32)
-    prev_lps = 0  # length of the previous longest prefix suffix
-    i = 1
-    while i < len(pattern):
-        if pattern[i] == pattern[prev_lps]:
-            prev_lps += 1
-            lps[i] = prev_lps
-            i += 1
-        else:
-            if prev_lps != 0:
-                prev_lps = lps[prev_lps - 1]
-            else:
-                lps[i] = 0
-                i += 1
-    return lps
-
-
-@jit(nopython=True)
-def _find_subarray_kmp(
-    context_token_ids: np.ndarray,
-    n: int,
-    k: int,
-) -> Optional[np.ndarray]:
-    context_len = context_token_ids.shape[0]
-    assert n > 0
-
-    pattern = context_token_ids[-n:]
-    # Precompute the lps array for the pattern
-    lps = _kmp_lps_array(pattern)
-
-    i = 0
-    j = 0
-    # -n because the last n tokens are used as the pattern
-    while i < context_len - n:
-        if context_token_ids[i] == pattern[j]:
-            i += 1
-            j += 1
-
-            # If we have matched the entire pattern
-            if j == n:
-                # Found the pattern in the context; gather the next k elements
-                return context_token_ids[i:i + k]
-        else:
-            # Mismatch
-            if j != 0:
-                # Use the lps array to avoid re-checking elements
-                j = lps[j - 1]
-            else:
-                i += 1
-
-    # Pattern not found
-    return None
+def _find_longest_matched_ngram_and_propose_tokens(
+        origin_tokens: np.ndarray, min_ngram: int, max_ngram: int,
+        max_model_len: int, k: int) -> Optional[np.ndarray]:
+    """
+    Find the longest n-gram whose length is within [min_ngram, max_ngram]
+    (inclusive) and which matches the suffix of the given tokens.
+
+    If found, extract up to k tokens right after the matched n-gram.
+    """
+    # Do not generate draft tokens if the context is shorter than the
+    # minimum n-gram.
+    total_token = origin_tokens.shape[0]
+    if total_token < min_ngram:
+        return None
+
+    # Do not generate draft tokens beyond the max model length.
+    k = min(k, max_model_len - total_token)
+    if k <= 0:
+        return None
+
+    # Flip the tokens; the goal becomes finding the longest n-gram at the
+    # rightmost position that matches a prefix with length in
+    # [min_ngram, max_ngram] (inclusive).
+    tokens = origin_tokens[::-1]
+
+    # Longest prefix (not including itself) which is also a suffix ending
+    # at the current position:
+    # lps[i] = max{v : tokens[0:v] == tokens[i+1-v:i+1]}
+    #
+    # As the n-gram is capped by max_ngram, to save memory we only need to
+    # store lps for the first max_ngram prefixes.
+    lps = np.zeros(max_ngram, dtype=np.int32)
+
+    longest_ngram = 0
+    position = 0
+
+    # lps[0] is always 0, so we start at index 1.
+    prev_lps = 0
+    i = 1
+    while i < total_token:
+        # tokens[:prev_lps] is the longest prefix that is a suffix of tokens[:i]
+        if tokens[prev_lps] == tokens[i]:
+            # Token match: tokens[:prev_lps+1] is the longest prefix that is
+            # a suffix of tokens[:i+1]
+            prev_lps += 1
+            # Check if we found a longer valid n-gram.
+            #
+            # Also update position when longest_ngram equals prev_lps, as we
+            # want the target n-gram at the earliest position in the original
+            # tokens (i.e. the latest position in the reversed tokens).
+            if prev_lps >= longest_ngram:
+                longest_ngram = prev_lps
+                position = i
+            if i < max_ngram:
+                # Store lps for the first max_ngram prefixes.
+                lps[i] = prev_lps
+            if prev_lps == max_ngram:
+                # When prev_lps reaches max_ngram, reset prev_lps to
+                # lps[max_ngram - 1] to avoid matching an n-gram longer
+                # than max_ngram.
+                prev_lps = lps[max_ngram - 1]
+            i += 1
+        elif prev_lps != 0:
+            # Token mismatch: try the second-longest prefix among all
+            # suffixes of tokens[:i], which is the longest prefix of
+            # tokens[:prev_lps].
+            prev_lps = lps[prev_lps - 1]
+        else:
+            # Token mismatch, and no prefix (except the empty string) is a
+            # suffix of tokens[:i].
+            i += 1
+
+    if longest_ngram < min_ngram:
+        # No valid n-gram was found.
+        return None
+
+    # Flip the position back: in origin_tokens,
+    # origin_tokens[total_token-1-position:total_token-1-position+longest_ngram]
+    # is the matched n-gram, so we start drafting tokens from
+    # total_token-1-position+longest_ngram.
+    start_position = total_token - 1 - position + longest_ngram
+    k = min(k, total_token - start_position)
+    return origin_tokens[start_position:start_position + k]
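A short worked example of the single-pass idea, using values taken from the unit tests above (and assuming the private helper is importable, as those tests do):

import numpy as np

from vllm.v1.spec_decode.ngram_proposer import (
    _find_longest_matched_ngram_and_propose_tokens)

tokens = np.array([1, 2, 3, 4, 1, 2, 3])
# Reversed: [3, 2, 1, 4, 3, 2, 1]. The KMP prefix function over the reversed
# array finds the longest prefix (= the original suffix) recurring at the
# rightmost reversed position (= the earliest original position) in one pass.
out = _find_longest_matched_ngram_and_propose_tokens(
    origin_tokens=tokens, min_ngram=2, max_ngram=2, max_model_len=1024, k=3)
print(out)  # [4 1 2]: the tokens that followed the earlier [2, 3] occurrence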