[Core] [N-gram SD Optimization][1/n] Propose tokens with a single KMP (#22437)
Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
This commit is contained in:
parent 4e8614e88b
commit 31a500c86f

benchmarks/benchmark_block_pool.py (new file, 74 lines)
@@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc

from tabulate import tabulate

from benchmark_utils import TimeCollector
from vllm.utils import FlexibleArgumentParser
from vllm.v1.core.block_pool import BlockPool


def main(args):
    rows = []
    for allocate_block in args.allocate_blocks:
        # Force a GC collection before each run to minimize cross-run
        # interference
        gc.collect()
        block_pool = BlockPool(num_gpu_blocks=args.num_gpu_blocks, enable_caching=True)

        get_blocks_times = TimeCollector(TimeCollector.US)
        free_blocks_times = TimeCollector(TimeCollector.US)
        for _ in range(args.num_iteration):
            with get_blocks_times:
                blocks = block_pool.get_new_blocks(allocate_block)
            with free_blocks_times:
                block_pool.free_blocks(blocks)

        rows.append(
            [get_blocks_times.cnt, args.num_gpu_blocks, allocate_block]
            + get_blocks_times.dump_avg_max()
            + free_blocks_times.dump_avg_max()
        )

    print(
        tabulate(
            rows,
            headers=[
                "Iterations",
                "Total\nBlocks",
                "Allocated\nBlocks",
                "Get Blocks\nAvg (us)",
                "Get Blocks\nMax (us)",
                "Free Blocks\nAvg (us)",
                "Free Blocks\nMax (us)",
            ],
            tablefmt="grid",
            floatfmt=".3f",
        )
    )


def invoke_main() -> None:
    parser = FlexibleArgumentParser(
        description="Benchmark the performance of BlockPool for KV Cache."
    )
    parser.add_argument("--num-gpu-blocks", type=int, default=100000)
    parser.add_argument(
        "--num-iteration",
        type=int,
        default=1000,
        help="Number of iterations to run to stabilize final data readings",
    )
    parser.add_argument(
        "--allocate-blocks",
        type=int,
        nargs="*",
        default=[10, 50, 100, 500, 1000],
        help="Number of blocks to allocate",
    )
    args = parser.parse_args()
    main(args)


if __name__ == "__main__":
    invoke_main()  # pragma: no cover
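Each timed iteration above reduces to a single allocate/free pair against the pool. A minimal standalone sketch of that pair, using the same calls as the benchmark (the pool size and block count here are illustrative, not the benchmark defaults):

from vllm.v1.core.block_pool import BlockPool

# Build a small caching pool; get_new_blocks/free_blocks are the two
# operations measured by get_blocks_times and free_blocks_times above.
pool = BlockPool(num_gpu_blocks=1024, enable_caching=True)
blocks = pool.get_new_blocks(16)
pool.free_blocks(blocks)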
benchmarks/benchmark_ngram_proposer.py (new file, 112 lines)
@@ -0,0 +1,112 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc

import numpy as np
from tabulate import tabulate

from benchmark_utils import TimeCollector
from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
from vllm.utils import FlexibleArgumentParser
from vllm.v1.spec_decode.ngram_proposer import NgramProposer


def main(args):
    rows = []
    for max_ngram in args.max_ngram:
        collector = TimeCollector(TimeCollector.US)

        model_config = ModelConfig(
            model="facebook/opt-125m",
            task="generate",
            max_model_len=args.num_token + args.num_spec_token,
            tokenizer="facebook/opt-125m",
            tokenizer_mode="auto",
            dtype="auto",
            seed=None,
            trust_remote_code=False,
        )
        proposer = NgramProposer(
            vllm_config=VllmConfig(
                model_config=model_config,
                speculative_config=SpeculativeConfig(
                    prompt_lookup_min=args.min_ngram,
                    prompt_lookup_max=max_ngram,
                    num_speculative_tokens=args.num_spec_token,
                    method="ngram",
                ),
            )
        )

        # Warm up
        proposer.propose(np.random.randint(0, 20, (args.num_token,)))

        gc.collect()
        for _ in range(args.num_iteration):
            tokens = np.random.randint(0, 20, (args.num_req, args.num_token))
            with collector:
                for i in range(args.num_req):
                    proposer.propose(tokens[i, :])
        rows.append(
            [args.num_req, args.num_token, args.min_ngram, max_ngram]
            + collector.dump_avg_max()
        )

    print(
        tabulate(
            rows,
            headers=[
                "# Request",
                "# Token",
                "Min Ngram",
                "Max Ngram",
                "Avg (us)",
                "Max (us)",
            ],
            tablefmt="grid",
            floatfmt=".3f",
        )
    )


def invoke_main() -> None:
    parser = FlexibleArgumentParser(
        description="Benchmark the performance of N-gram speculative decode drafting"
    )
    parser.add_argument(
        "--num-iteration",
        type=int,
        default=100,
        help="Number of iterations to run to stabilize final data readings",
    )
    parser.add_argument(
        "--num-req", type=int, default=128, help="Number of requests in the batch"
    )
    parser.add_argument(
        "--num-token", type=int, default=1500, help="Number of tokens for each request"
    )
    parser.add_argument(
        "--min-ngram",
        type=int,
        default=3,
        help="Minimum n-gram to match",
    )
    parser.add_argument(
        "--max-ngram",
        type=int,
        nargs="*",
        default=[5, 7, 10, 15, 20],
        help="Maximum n-gram to match",
    )
    parser.add_argument(
        "--num-spec-token",
        type=int,
        default=3,
        help="Number of speculative tokens to generate",
    )
    args = parser.parse_args()
    main(args)


if __name__ == "__main__":
    invoke_main()  # pragma: no cover
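For a concrete sense of what each timed propose() call computes, the drafting helper this PR introduces can also be called directly; the input and expected output below are taken from this PR's own tests further down:

import numpy as np

from vllm.v1.spec_decode.ngram_proposer import (
    _find_longest_matched_ngram_and_propose_tokens)

# The 2-token suffix [2, 3] also occurs at positions 1-2, so the k=3
# tokens that followed that earlier occurrence are proposed as the draft.
tokens = np.array([1, 2, 3, 4, 1, 2, 3])
draft = _find_longest_matched_ngram_and_propose_tokens(
    origin_tokens=tokens, min_ngram=2, max_ngram=2, max_model_len=1024, k=3)
print(draft)  # [4 1 2]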
benchmarks/benchmark_utils.py
@@ -1,11 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import argparse
 import json
 import math
 import os
-from typing import Any
+import time
+from types import TracebackType
+from typing import Any, Optional, Union
 
 
 def convert_to_pytorch_benchmark_format(
@@ -72,3 +73,53 @@ def write_to_json(filename: str, records: list) -> None:
             cls=InfEncoder,
             default=lambda o: f"<{type(o).__name__} object is not JSON serializable>",
         )
+
+
+# Collect time and generate time metrics
+#
+# Example Usage:
+# collector = TimeCollector(TimeCollector.US)
+# for _ in range(total_iteration):
+#     with collector:
+#         ...
+# collector.dump_avg_max()
+class TimeCollector:
+    NS: int = 1
+    US: int = NS * 1000
+    MS: int = US * 1000
+    S: int = MS * 1000
+
+    def __init__(self, scale: int) -> None:
+        self.cnt: int = 0
+        self._sum: int = 0
+        self._max: Optional[int] = None
+        self.scale = scale
+        self.start_time: int = time.monotonic_ns()
+
+    def collect(self, v: int) -> None:
+        self.cnt += 1
+        self._sum += v
+        if self._max is None:
+            self._max = v
+        else:
+            self._max = max(self._max, v)
+
+    def avg(self) -> Union[float, str]:
+        return self._sum * 1.0 / self.cnt / self.scale if self.cnt > 0 else "N/A"
+
+    def max(self) -> Union[float, str]:
+        return self._max / self.scale if self._max else "N/A"
+
+    def dump_avg_max(self) -> list[Union[float, str]]:
+        return [self.avg(), self.max()]
+
+    def __enter__(self) -> None:
+        self.start_time = time.monotonic_ns()
+
+    def __exit__(
+        self,
+        exc_type: Optional[type[BaseException]],
+        exc_value: Optional[BaseException],
+        exc_traceback: Optional[TracebackType],
+    ) -> None:
+        self.collect(time.monotonic_ns() - self.start_time)
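A short self-contained check of the collector added above (run from the benchmarks/ directory so benchmark_utils is importable; the timed body is an arbitrary placeholder):

from benchmark_utils import TimeCollector

collector = TimeCollector(TimeCollector.US)
for _ in range(100):
    with collector:
        sum(i * i for i in range(1_000))  # placeholder workload to measure
print(collector.dump_avg_max())  # [avg_us, max_us], or "N/A" when empty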
benchmarks/benchmark_block_pool.py (earlier version removed; rewritten above)
@@ -1,108 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import time
from typing import Optional

from tabulate import tabulate

from vllm.utils import FlexibleArgumentParser
from vllm.v1.core.block_pool import BlockPool


class Metric:

    def __init__(self) -> None:
        self.cnt: int = 0
        self.sum_v: int = 0
        self.max_v: Optional[int] = None

    def update(self, v: int) -> None:
        self.cnt += 1
        self.sum_v += v
        if self.max_v is None:
            self.max_v = v
        else:
            self.max_v = max(self.max_v, v)

    def avg_v(self) -> float:
        return self.sum_v * 1.0 / self.cnt


def main(args):
    rows = []
    for allocate_block in args.allocate_blocks:
        # Enforce a GC collect ahead to minimize the impact among runs
        gc.collect()
        block_pool = BlockPool(num_gpu_blocks=args.num_gpu_blocks, enable_caching=True)

        get_blocks_metric: Metric = Metric()
        free_blocks_metric: Metric = Metric()
        for _ in range(args.num_iteration):
            t1 = time.monotonic_ns()
            blocks = block_pool.get_new_blocks(allocate_block)
            t2 = time.monotonic_ns()
            block_pool.free_blocks(blocks)
            t3 = time.monotonic_ns()
            get_blocks_metric.update(t2 - t1)
            free_blocks_metric.update(t3 - t2)

        if get_blocks_metric.max_v is not None and free_blocks_metric.max_v is not None:
            rows.append(
                [
                    get_blocks_metric.cnt,
                    args.num_gpu_blocks,
                    allocate_block,
                    get_blocks_metric.avg_v() / 1000000,
                    get_blocks_metric.max_v / 1000000.0,
                    free_blocks_metric.avg_v() / 1000000,
                    free_blocks_metric.max_v / 1000000.0,
                ]
            )
        else:
            print(
                "No valid metrics found."
                f" {get_blocks_metric.max_v=} {free_blocks_metric.max_v=}"
            )

    print(
        tabulate(
            rows,
            headers=[
                "Iterations",
                "Total\nBlocks",
                "Allocated\nBlocks",
                "Get Blocks\nAvg (ms)",
                "Get Blocks\nMax (ms)",
                "Free Blocks\nAvg (ms)",
                "Free Blocks\nMax (ms)",
            ],
            tablefmt="grid",
            floatfmt=".6f",
        )
    )


def invoke_main() -> None:
    parser = FlexibleArgumentParser(
        description="Benchmark the performance of BlockPool for KV Cache."
    )
    parser.add_argument("--num-gpu-blocks", type=int, default=100000)
    parser.add_argument(
        "--num-iteration",
        type=int,
        default=1000,
        help="Number of iterations to run to stablize final data readings",
    )
    parser.add_argument(
        "--allocate-blocks",
        type=int,
        nargs="*",
        default=[10, 50, 100, 500, 1000],
        help="Number of blocks to allocate",
    )
    args = parser.parse_args()
    main(args)


if __name__ == "__main__":
    invoke_main()  # pragma: no cover
tests/v1/spec_decode/test_ngram.py
@@ -1,43 +1,63 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import numpy as np
 
 from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
-from vllm.v1.spec_decode.ngram_proposer import (NgramProposer,
-                                                _find_subarray_kmp,
-                                                _kmp_lps_array)
+from vllm.v1.spec_decode.ngram_proposer import (
+    NgramProposer, _find_longest_matched_ngram_and_propose_tokens)
 
 
-def test_kmp_lps_array():
-    np.testing.assert_array_equal(_kmp_lps_array(np.array([])), np.array([]))
-    np.testing.assert_array_equal(_kmp_lps_array(np.array([1])), np.array([0]))
-    np.testing.assert_array_equal(_kmp_lps_array(np.array([1, 1, 1])),
-                                  np.array([0, 1, 2]))
-    np.testing.assert_array_equal(_kmp_lps_array(np.array([1, 2, 3, 4])),
-                                  np.array([0, 0, 0, 0]))
-    np.testing.assert_array_equal(_kmp_lps_array(np.array([1, 2, 1, 2, 3])),
-                                  np.array([0, 0, 1, 2, 0]))
-
-
-def test_find_subarray_kmp():
-    X = np.array([1, 2, 3, 4, 1, 2, 3, 5, 6])
-    assert _find_subarray_kmp(X, 2, 2) is None
-    X = np.array([1, 2, 3, 4, 1, 2, 3])
-    np.testing.assert_array_equal(_find_subarray_kmp(X, 2, 3),
-                                  np.array([4, 1, 2]))
-    np.testing.assert_array_equal(_find_subarray_kmp(X, 2, 2), np.array([4,
-                                                                         1]))
-    np.testing.assert_array_equal(_find_subarray_kmp(X, 1, 3),
-                                  np.array([4, 1, 2]))
-    np.testing.assert_array_equal(_find_subarray_kmp(X, 1, 2), np.array([4,
-                                                                         1]))
-    X = np.array([1, 3, 6, 2, 3, 4, 1, 2, 3])
-    np.testing.assert_array_equal(_find_subarray_kmp(X, 2, 3),
-                                  np.array([4, 1, 2]))
-    # Return on the first match
-    np.testing.assert_array_equal(_find_subarray_kmp(X, 1, 3),
-                                  np.array([6, 2, 3]))
+def test_find_longest_matched_ngram_and_propose_tokens():
+    tokens = np.array([1, 2, 3, 4, 1, 2, 3, 5, 6])
+    assert _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
+                                                          min_ngram=2,
+                                                          max_ngram=2,
+                                                          max_model_len=1024,
+                                                          k=2) is None
+
+    tokens = np.array([1, 2, 3, 4, 1, 2, 3])
+    np.testing.assert_array_equal(
+        _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
+                                                       min_ngram=2,
+                                                       max_ngram=2,
+                                                       max_model_len=1024,
+                                                       k=3),
+        np.array([4, 1, 2]))
+    np.testing.assert_array_equal(
+        _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
+                                                       min_ngram=2,
+                                                       max_ngram=2,
+                                                       max_model_len=1024,
+                                                       k=2), np.array([4, 1]))
+    np.testing.assert_array_equal(
+        _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
+                                                       min_ngram=1,
+                                                       max_ngram=1,
+                                                       max_model_len=1024,
+                                                       k=3),
+        np.array([4, 1, 2]))
+    np.testing.assert_array_equal(
+        _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
+                                                       min_ngram=1,
+                                                       max_ngram=1,
+                                                       max_model_len=1024,
+                                                       k=2), np.array([4, 1]))
+
+    tokens = np.array([1, 3, 6, 2, 3, 4, 1, 2, 3])
+    np.testing.assert_array_equal(
+        _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
+                                                       min_ngram=2,
+                                                       max_ngram=2,
+                                                       max_model_len=1024,
+                                                       k=3),
+        np.array([4, 1, 2]))
+    # Return on the first match
+    np.testing.assert_array_equal(
+        _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
+                                                       min_ngram=1,
+                                                       max_ngram=1,
+                                                       max_model_len=1024,
+                                                       k=2), np.array([6, 2]))
 
 
 def test_ngram_proposer():
@@ -56,27 +76,35 @@ def test_ngram_proposer():
 
     # No match.
     result = ngram_proposer(
-        2, 2, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 5]))
+        min_n=2, max_n=2,
+        k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 5]))
     assert result is None
 
     # No match for 4-gram.
     result = ngram_proposer(
-        4, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
+        min_n=4, max_n=4,
+        k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
     assert result is None
 
     # No match for 4-gram but match for 3-gram.
     result = ngram_proposer(
-        3, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
+        min_n=3, max_n=4,
+        k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
     assert np.array_equal(result, np.array([4, 1]))
 
     # Match for both 4-gram and 3-gram.
     # In this case, the proposer should return the 4-gram match.
-    result = ngram_proposer(3, 4, 2).propose(
+    result = ngram_proposer(min_n=3, max_n=4, k=2).propose(
         context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]))
     assert np.array_equal(result, np.array([1, 2]))  # Not [5, 1]
 
     # Match for 2-gram and 3-gram, but not 4-gram.
-    result = ngram_proposer(
-        2, 4,
-        2).propose(context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]))
+    result = ngram_proposer(min_n=2, max_n=4, k=2).propose(
+        context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]))
     assert np.array_equal(result, np.array([1, 2]))  # Not [5, 2]
+
+    # Multiple 3-gram matched, but always pick the first one.
+    result = ngram_proposer(
+        min_n=3, max_n=3, k=2).propose(context_token_ids=np.array(
+            [1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]))
+    assert np.array_equal(result, np.array([100, 1]))
vllm/v1/spec_decode/ngram_proposer.py
@@ -11,6 +11,10 @@ from vllm.config import VllmConfig
 class NgramProposer:
 
     def __init__(self, vllm_config: VllmConfig):
+        assert vllm_config.speculative_config is not None
+        assert vllm_config.speculative_config.prompt_lookup_min is not None
+        assert vllm_config.speculative_config.prompt_lookup_max is not None
+
         # Minimum length of the n-gram to match.
         self.min_n = vllm_config.speculative_config.prompt_lookup_min
         # Maximum length of the n-gram to match.
@@ -54,17 +58,13 @@ class NgramProposer:
         followed that pattern. Here we will return [4,2,3] because
         we only have three tokens after the match.
         """
-        # Do not generate draft tokens beyond the max model length.
-        k = min(self.k, self.max_model_len - context_token_ids.shape[0])
-        if k <= 0:
-            return None
-
-        # TODO(woosuk): Optimize this.
-        for n in range(self.max_n, self.min_n - 1, -1):
-            result = _find_subarray_kmp(context_token_ids, n, k)
-            if result is not None:
-                return result
-        return None
+        return _find_longest_matched_ngram_and_propose_tokens(
+            origin_tokens=context_token_ids,
+            min_ngram=self.min_n,
+            max_ngram=self.max_n,
+            max_model_len=self.max_model_len,
+            k=self.k)
 
     def load_model(self, *args, **kwargs):
         # No model to load.
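The old body tried each n from max_n down to min_n with a separate KMP search; the new helper finds the longest matching n-gram in a single pass. A spot check (not part of the diff) on an input from this PR's tests, where the 4-gram match must win over the 3-gram one:

import numpy as np

from vllm.v1.spec_decode.ngram_proposer import (
    _find_longest_matched_ngram_and_propose_tokens)

ctx = np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4])
# The 4-gram suffix [1, 2, 3, 4] matches at positions 4-7, so the draft
# is the two tokens after it, [1, 2], rather than [5, 1] from the 3-gram.
out = _find_longest_matched_ngram_and_propose_tokens(
    origin_tokens=ctx, min_ngram=3, max_ngram=4, max_model_len=1024, k=2)
print(out)  # [1 2]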
@@ -72,61 +72,86 @@ class NgramProposer:
 
 
 @jit(nopython=True)
-def _kmp_lps_array(pattern: np.ndarray) -> np.ndarray:
-    """
-    Build the lps (longest proper prefix which is also suffix)
-    array for the pattern.
-    """
-    lps = np.zeros(len(pattern), dtype=np.int32)
-    prev_lps = 0  # length of the previous longest prefix suffix
-    i = 1
-
-    while i < len(pattern):
-        if pattern[i] == pattern[prev_lps]:
-            prev_lps += 1
-            lps[i] = prev_lps
-            i += 1
-        else:
-            if prev_lps != 0:
-                prev_lps = lps[prev_lps - 1]
-            else:
-                lps[i] = 0
-                i += 1
-    return lps
-
-
-@jit(nopython=True)
-def _find_subarray_kmp(
-    context_token_ids: np.ndarray,
-    n: int,
-    k: int,
-) -> Optional[np.ndarray]:
-    context_len = context_token_ids.shape[0]
-    assert n > 0
-
-    pattern = context_token_ids[-n:]
-    # Precompute lps array for Y
-    lps = _kmp_lps_array(pattern)
-
-    i = 0
-    j = 0
-    # -n because the last n tokens are used as pattern
-    while i < context_len - n:
-        if context_token_ids[i] == pattern[j]:
-            i += 1
-            j += 1
-
-            # If we have matched the entire Y
-            if j == n:
-                # Found pattern in context, gather the next K elements
-                return context_token_ids[i:i + k]
-        else:
-            # Mismatch
-            if j != 0:
-                # Use the lps array to avoid re-checking elements
-                j = lps[j - 1]
-            else:
-                i += 1
-
-    # Y not found
-    return None
+def _find_longest_matched_ngram_and_propose_tokens(
+        origin_tokens: np.ndarray, min_ngram: int, max_ngram: int,
+        max_model_len: int, k: int) -> Optional[np.ndarray]:
+    """
+    Find the longest n-gram which matches the suffix of the given tokens,
+    with length within [min_ngram, max_ngram] (inclusive).
+
+    If found, extract the k tokens right after the matched n-gram.
+    """
+    # Do not generate draft tokens if the context is shorter than the
+    # minimum n-gram.
+    total_token = origin_tokens.shape[0]
+    if total_token < min_ngram:
+        return None
+
+    # Do not generate draft tokens beyond the max model length.
+    k = min(k, max_model_len - total_token)
+    if k <= 0:
+        return None
+
+    # Flip the tokens; the goal then becomes finding the longest n-gram
+    # at the rightmost position that matches a prefix of length
+    # [min_ngram, max_ngram] (inclusive).
+    tokens = origin_tokens[::-1]
+
+    # Longest prefix (not including itself) which is also a suffix at
+    # the current position:
+    # lps[i] = max{v : tokens[0:v] == tokens[i+1-v:i+1]}
+    #
+    # As the n-gram is capped at max_ngram to save memory, we only need
+    # to store lps for the first max_ngram prefixes.
+    lps = np.zeros(max_ngram, dtype=np.int32)
+
+    longest_ngram = 0
+    position = 0
+
+    # lps[0] always equals 0, so we start at index 1.
+    prev_lps = 0
+    i = 1
+    while i < total_token:
+        # tokens[:prev_lps] is the longest prefix that is a suffix of
+        # tokens[:i]
+        if tokens[prev_lps] == tokens[i]:
+            # Token match: tokens[:prev_lps+1] is the longest prefix that
+            # is a suffix of tokens[:i+1]
+            prev_lps += 1
+            # Check if we found a longer valid n-gram.
+            #
+            # Update position even when prev_lps merely ties longest_ngram,
+            # as we want the target n-gram at the earliest position in the
+            # original tokens (i.e. the latest position in the reversed
+            # tokens).
+            if prev_lps >= longest_ngram:
+                longest_ngram = prev_lps
+                position = i
+            if i < max_ngram:
+                # Store LPS for the first max_ngram prefixes
+                lps[i] = prev_lps
+            if prev_lps == max_ngram:
+                # When prev_lps reaches max_ngram, reset prev_lps to
+                # lps[max_ngram - 1] to avoid matching an n-gram longer
+                # than max_ngram.
+                prev_lps = lps[max_ngram - 1]
+            i += 1
+        elif prev_lps != 0:
+            # Token mismatch: try the second longest prefix among all
+            # suffixes of tokens[:i], which is the longest prefix of
+            # tokens[:prev_lps]
+            prev_lps = lps[prev_lps - 1]
+        else:
+            # Token mismatch, and no more prefix (except the empty string)
+            # is a suffix of tokens[:i]
+            i += 1
+
+    if longest_ngram < min_ngram:
+        # No valid n-gram was found
+        return None
+
+    # Flip the position back: in origin_tokens,
+    # origin_tokens[total_token-1-position : total_token-1-position+longest_ngram]
+    # is the matched n-gram, so we should start drafting tokens from
+    # total_token-1-position+longest_ngram.
+    start_position = total_token - 1 - position + longest_ngram
+    k = min(k, total_token - start_position)
+    return origin_tokens[start_position:start_position + k]
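For intuition about the lps table the loop above maintains incrementally, here is the classic KMP prefix-function construction as a plain-Python sketch (no numba), evaluated on an array whose expected table appears in the removed test_kmp_lps_array:

def lps_table(pattern):
    # lps[i] = length of the longest proper prefix of pattern[:i+1]
    # that is also a suffix of pattern[:i+1]
    lps = [0] * len(pattern)
    prev_lps = 0
    i = 1
    while i < len(pattern):
        if pattern[i] == pattern[prev_lps]:
            prev_lps += 1
            lps[i] = prev_lps
            i += 1
        elif prev_lps != 0:
            prev_lps = lps[prev_lps - 1]
        else:
            lps[i] = 0
            i += 1
    return lps

print(lps_table([1, 2, 1, 2, 3]))  # [0, 0, 1, 2, 0]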