diff --git a/benchmarks/benchmark_block_pool.py b/benchmarks/benchmark_block_pool.py
new file mode 100644
index 0000000000000..fd363c2ad0514
--- /dev/null
+++ b/benchmarks/benchmark_block_pool.py
@@ -0,0 +1,74 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import gc
+
+from tabulate import tabulate
+
+from benchmark_utils import TimeCollector
+from vllm.utils import FlexibleArgumentParser
+from vllm.v1.core.block_pool import BlockPool
+
+
+def main(args):
+    rows = []
+    for allocate_block in args.allocate_blocks:
+        # Force a GC pass up front to minimize interference between runs
+        gc.collect()
+        block_pool = BlockPool(num_gpu_blocks=args.num_gpu_blocks, enable_caching=True)
+
+        get_blocks_times = TimeCollector(TimeCollector.US)
+        free_blocks_times = TimeCollector(TimeCollector.US)
+        for _ in range(args.num_iteration):
+            with get_blocks_times:
+                blocks = block_pool.get_new_blocks(allocate_block)
+            with free_blocks_times:
+                block_pool.free_blocks(blocks)
+
+        rows.append(
+            [get_blocks_times.cnt, args.num_gpu_blocks, allocate_block]
+            + get_blocks_times.dump_avg_max()
+            + free_blocks_times.dump_avg_max()
+        )
+
+    print(
+        tabulate(
+            rows,
+            headers=[
+                "Iterations",
+                "Total\nBlocks",
+                "Allocated\nBlocks",
+                "Get Blocks\nAvg (us)",
+                "Get Blocks\nMax (us)",
+                "Free Blocks\nAvg (us)",
+                "Free Blocks\nMax (us)",
+            ],
+            tablefmt="grid",
+            floatfmt=".3f",
+        )
+    )
+
+
+def invoke_main() -> None:
+    parser = FlexibleArgumentParser(
+        description="Benchmark the performance of BlockPool for KV Cache."
+    )
+    parser.add_argument("--num-gpu-blocks", type=int, default=100000)
+    parser.add_argument(
+        "--num-iteration",
+        type=int,
+        default=1000,
+        help="Number of iterations to run to stabilize final data readings",
+    )
+    parser.add_argument(
+        "--allocate-blocks",
+        type=int,
+        nargs="*",
+        default=[10, 50, 100, 500, 1000],
+        help="Number of blocks to allocate",
+    )
+    args = parser.parse_args()
+    main(args)
+
+
+if __name__ == "__main__":
+    invoke_main()  # pragma: no cover
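For context, the hot path this benchmark times is just the allocate/free pair on `BlockPool`. A minimal sketch of that pattern, assuming a vLLM checkout where `vllm.v1.core.block_pool` is importable (both calls appear verbatim in the benchmark above):

```python
from vllm.v1.core.block_pool import BlockPool

# Same constructor arguments as the benchmark above.
pool = BlockPool(num_gpu_blocks=1024, enable_caching=True)

blocks = pool.get_new_blocks(16)  # allocate 16 KV-cache blocks
pool.free_blocks(blocks)          # return them to the pool
```
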
diff --git a/benchmarks/benchmark_ngram_proposer.py b/benchmarks/benchmark_ngram_proposer.py
new file mode 100644
index 0000000000000..c60040d05ab7a
--- /dev/null
+++ b/benchmarks/benchmark_ngram_proposer.py
@@ -0,0 +1,112 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import gc
+
+import numpy as np
+from tabulate import tabulate
+
+from benchmark_utils import TimeCollector
+from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
+from vllm.utils import FlexibleArgumentParser
+from vllm.v1.spec_decode.ngram_proposer import NgramProposer
+
+
+def main(args):
+    rows = []
+    for max_ngram in args.max_ngram:
+        collector = TimeCollector(TimeCollector.US)
+
+        model_config = ModelConfig(
+            model="facebook/opt-125m",
+            task="generate",
+            max_model_len=args.num_token + args.num_spec_token,
+            tokenizer="facebook/opt-125m",
+            tokenizer_mode="auto",
+            dtype="auto",
+            seed=None,
+            trust_remote_code=False,
+        )
+        proposer = NgramProposer(
+            vllm_config=VllmConfig(
+                model_config=model_config,
+                speculative_config=SpeculativeConfig(
+                    prompt_lookup_min=args.min_ngram,
+                    prompt_lookup_max=max_ngram,
+                    num_speculative_tokens=args.num_spec_token,
+                    method="ngram",
+                ),
+            )
+        )
+
+        # Warm up to trigger Numba JIT compilation before timing
+        proposer.propose(np.random.randint(0, 20, (args.num_token,)))
+
+        gc.collect()
+        for _ in range(args.num_iteration):
+            tokens = np.random.randint(0, 20, (args.num_req, args.num_token))
+            with collector:
+                for i in range(args.num_req):
+                    proposer.propose(tokens[i, :])
+        rows.append(
+            [args.num_req, args.num_token, args.min_ngram, max_ngram]
+            + collector.dump_avg_max()
+        )
+
+    print(
+        tabulate(
+            rows,
+            headers=[
+                "# Request",
+                "# Token",
+                "Min Ngram",
+                "Max Ngram",
+                "Avg (us)",
+                "Max (us)",
+            ],
+            tablefmt="grid",
+            floatfmt=".3f",
+        )
+    )
+
+
+def invoke_main() -> None:
+    parser = FlexibleArgumentParser(
+        description="Benchmark the performance of N-gram speculative decode drafting"
+    )
+    parser.add_argument(
+        "--num-iteration",
+        type=int,
+        default=100,
+        help="Number of iterations to run to stabilize final data readings",
+    )
+    parser.add_argument(
+        "--num-req", type=int, default=128, help="Number of requests in the batch"
+    )
+    parser.add_argument(
+        "--num-token", type=int, default=1500, help="Number of tokens for each request"
+    )
+    parser.add_argument(
+        "--min-ngram",
+        type=int,
+        default=3,
+        help="Minimum n-gram to match",
+    )
+    parser.add_argument(
+        "--max-ngram",
+        type=int,
+        nargs="*",
+        default=[5, 7, 10, 15, 20],
+        help="Maximum n-gram to match",
+    )
+    parser.add_argument(
+        "--num-spec-token",
+        type=int,
+        default=3,
+        help="Number of speculative tokens to generate",
+    )
+    args = parser.parse_args()
+    main(args)
+
+
+if __name__ == "__main__":
+    invoke_main()  # pragma: no cover
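For readers unfamiliar with the proposer API, below is a minimal sketch of the call the benchmark times. The config values mirror the benchmark above (it may fetch the opt-125m tokenizer config on first run; NgramProposer itself loads no weights), and the expected draft follows from the test cases later in this diff:

```python
import numpy as np

from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
from vllm.v1.spec_decode.ngram_proposer import NgramProposer

model_config = ModelConfig(
    model="facebook/opt-125m",  # small config, used only for metadata
    task="generate",
    max_model_len=1024,
    tokenizer="facebook/opt-125m",
    tokenizer_mode="auto",
    dtype="auto",
    seed=None,
    trust_remote_code=False,
)
proposer = NgramProposer(vllm_config=VllmConfig(
    model_config=model_config,
    speculative_config=SpeculativeConfig(
        prompt_lookup_min=2,
        prompt_lookup_max=2,
        num_speculative_tokens=3,
        method="ngram",
    ),
))

# The suffix [2, 3] also occurs at positions 1-2, so the three tokens
# that followed it there ([4, 1, 2]) are proposed as the draft.
print(proposer.propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3])))
```
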
diff --git a/benchmarks/benchmark_utils.py b/benchmarks/benchmark_utils.py
index 283f938df50af..98624abdf49fb 100644
--- a/benchmarks/benchmark_utils.py
+++ b/benchmarks/benchmark_utils.py
@@ -1,11 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
 import argparse
 import json
 import math
 import os
-from typing import Any
+import time
+from types import TracebackType
+from typing import Any, Optional, Union
 
 
 def convert_to_pytorch_benchmark_format(
@@ -72,3 +73,53 @@ def write_to_json(filename: str, records: list) -> None:
         cls=InfEncoder,
         default=lambda o: f"<{type(o).__name__} object is not JSON serializable>",
     )
+
+
+# Collect per-iteration timings and report aggregate metrics
+#
+# Example Usage:
+#   collector = TimeCollector(TimeCollector.US)
+#   for _ in range(total_iteration):
+#       with collector:
+#           ...  # code under measurement
+#   collector.dump_avg_max()
+class TimeCollector:
+    NS: int = 1
+    US: int = NS * 1000
+    MS: int = US * 1000
+    S: int = MS * 1000
+
+    def __init__(self, scale: int) -> None:
+        self.cnt: int = 0
+        self._sum: int = 0
+        self._max: Optional[int] = None
+        self.scale = scale
+        self.start_time: int = time.monotonic_ns()
+
+    def collect(self, v: int) -> None:
+        self.cnt += 1
+        self._sum += v
+        if self._max is None:
+            self._max = v
+        else:
+            self._max = max(self._max, v)
+
+    def avg(self) -> Union[float, str]:
+        return self._sum * 1.0 / self.cnt / self.scale if self.cnt > 0 else "N/A"
+
+    def max(self) -> Union[float, str]:
+        return self._max / self.scale if self._max is not None else "N/A"
+
+    def dump_avg_max(self) -> list[Union[float, str]]:
+        return [self.avg(), self.max()]
+
+    def __enter__(self) -> None:
+        self.start_time = time.monotonic_ns()
+
+    def __exit__(
+        self,
+        exc_type: Optional[type[BaseException]],
+        exc_value: Optional[BaseException],
+        exc_traceback: Optional[TracebackType],
+    ) -> None:
+        self.collect(time.monotonic_ns() - self.start_time)
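A quick way to sanity-check the new collector, as a sketch that assumes it is run from the benchmarks/ directory so that `benchmark_utils` is importable:

```python
import time

from benchmark_utils import TimeCollector

collector = TimeCollector(TimeCollector.MS)
for _ in range(5):
    with collector:
        time.sleep(0.01)  # stand-in for the code under measurement

avg_ms, max_ms = collector.dump_avg_max()
print(f"avg={avg_ms:.3f} ms, max={max_ms:.3f} ms")  # both should be ~10 ms
```

Note that `max()` compares `self._max` against `None` explicitly, so a legitimate 0 ns maximum is reported as 0.0 rather than "N/A".
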
diff --git a/benchmarks/kv_cache/benchmark_block_pool.py b/benchmarks/kv_cache/benchmark_block_pool.py
deleted file mode 100644
index 134551bb61285..0000000000000
--- a/benchmarks/kv_cache/benchmark_block_pool.py
+++ /dev/null
@@ -1,108 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import gc
-import time
-from typing import Optional
-
-from tabulate import tabulate
-
-from vllm.utils import FlexibleArgumentParser
-from vllm.v1.core.block_pool import BlockPool
-
-
-class Metric:
-    def __init__(self) -> None:
-        self.cnt: int = 0
-        self.sum_v: int = 0
-        self.max_v: Optional[int] = None
-
-    def update(self, v: int) -> None:
-        self.cnt += 1
-        self.sum_v += v
-        if self.max_v is None:
-            self.max_v = v
-        else:
-            self.max_v = max(self.max_v, v)
-
-    def avg_v(self) -> float:
-        return self.sum_v * 1.0 / self.cnt
-
-
-def main(args):
-    rows = []
-    for allocate_block in args.allocate_blocks:
-        # Enforce a GC collect ahead to minimize the impact among runs
-        gc.collect()
-        block_pool = BlockPool(num_gpu_blocks=args.num_gpu_blocks, enable_caching=True)
-
-        get_blocks_metric: Metric = Metric()
-        free_blocks_metric: Metric = Metric()
-        for _ in range(args.num_iteration):
-            t1 = time.monotonic_ns()
-            blocks = block_pool.get_new_blocks(allocate_block)
-            t2 = time.monotonic_ns()
-            block_pool.free_blocks(blocks)
-            t3 = time.monotonic_ns()
-            get_blocks_metric.update(t2 - t1)
-            free_blocks_metric.update(t3 - t2)
-
-        if get_blocks_metric.max_v is not None and free_blocks_metric.max_v is not None:
-            rows.append(
-                [
-                    get_blocks_metric.cnt,
-                    args.num_gpu_blocks,
-                    allocate_block,
-                    get_blocks_metric.avg_v() / 1000000,
-                    get_blocks_metric.max_v / 1000000.0,
-                    free_blocks_metric.avg_v() / 1000000,
-                    free_blocks_metric.max_v / 1000000.0,
-                ]
-            )
-        else:
-            print(
-                "No valid metrics found."
-                f" {get_blocks_metric.max_v=} {free_blocks_metric.max_v=}"
-            )
-
-    print(
-        tabulate(
-            rows,
-            headers=[
-                "Iterations",
-                "Total\nBlocks",
-                "Allocated\nBlocks",
-                "Get Blocks\nAvg (ms)",
-                "Get Blocks\nMax (ms)",
-                "Free Blocks\nAvg (ms)",
-                "Free Blocks\nMax (ms)",
-            ],
-            tablefmt="grid",
-            floatfmt=".6f",
-        )
-    )
-
-
-def invoke_main() -> None:
-    parser = FlexibleArgumentParser(
-        description="Benchmark the performance of BlockPool for KV Cache."
-    )
-    parser.add_argument("--num-gpu-blocks", type=int, default=100000)
-    parser.add_argument(
-        "--num-iteration",
-        type=int,
-        default=1000,
-        help="Number of iterations to run to stablize final data readings",
-    )
-    parser.add_argument(
-        "--allocate-blocks",
-        type=int,
-        nargs="*",
-        default=[10, 50, 100, 500, 1000],
-        help="Number of blocks to allocate",
-    )
-    args = parser.parse_args()
-    main(args)
-
-
-if __name__ == "__main__":
-    invoke_main()  # pragma: no cover
diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py
index b7303e0443d32..4193f4041b32b 100644
--- a/tests/v1/spec_decode/test_ngram.py
+++ b/tests/v1/spec_decode/test_ngram.py
@@ -1,43 +1,63 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
 import numpy as np
 
 from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
-from vllm.v1.spec_decode.ngram_proposer import (NgramProposer,
-                                                _find_subarray_kmp,
-                                                _kmp_lps_array)
+from vllm.v1.spec_decode.ngram_proposer import (
+    NgramProposer, _find_longest_matched_ngram_and_propose_tokens)
 
 
-def test_kmp_lps_array():
-    np.testing.assert_array_equal(_kmp_lps_array(np.array([])), np.array([]))
-    np.testing.assert_array_equal(_kmp_lps_array(np.array([1])), np.array([0]))
-    np.testing.assert_array_equal(_kmp_lps_array(np.array([1, 1, 1])),
-                                  np.array([0, 1, 2]))
-    np.testing.assert_array_equal(_kmp_lps_array(np.array([1, 2, 3, 4])),
-                                  np.array([0, 0, 0, 0]))
-    np.testing.assert_array_equal(_kmp_lps_array(np.array([1, 2, 1, 2, 3])),
-                                  np.array([0, 0, 1, 2, 0]))
+def test_find_longest_matched_ngram_and_propose_tokens():
+    tokens = np.array([1, 2, 3, 4, 1, 2, 3, 5, 6])
+    assert _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
+                                                          min_ngram=2,
+                                                          max_ngram=2,
+                                                          max_model_len=1024,
+                                                          k=2) is None
 
+    tokens = np.array([1, 2, 3, 4, 1, 2, 3])
+    np.testing.assert_array_equal(
+        _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
+                                                       min_ngram=2,
+                                                       max_ngram=2,
+                                                       max_model_len=1024,
+                                                       k=3),
+        np.array([4, 1, 2]))
+    np.testing.assert_array_equal(
+        _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
+                                                       min_ngram=2,
+                                                       max_ngram=2,
+                                                       max_model_len=1024,
+                                                       k=2), np.array([4, 1]))
+    np.testing.assert_array_equal(
+        _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
+                                                       min_ngram=1,
+                                                       max_ngram=1,
+                                                       max_model_len=1024,
+                                                       k=3),
+        np.array([4, 1, 2]))
+    np.testing.assert_array_equal(
+        _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
+                                                       min_ngram=1,
+                                                       max_ngram=1,
+                                                       max_model_len=1024,
+                                                       k=2), np.array([4, 1]))
 
-def test_find_subarray_kmp():
-    X = np.array([1, 2, 3, 4, 1, 2, 3, 5, 6])
-    assert _find_subarray_kmp(X, 2, 2) is None
-    X = np.array([1, 2, 3, 4, 1, 2, 3])
-    np.testing.assert_array_equal(_find_subarray_kmp(X, 2, 3),
-                                  np.array([4, 1, 2]))
-    np.testing.assert_array_equal(_find_subarray_kmp(X, 2, 2), np.array([4,
-                                                                         1]))
-    np.testing.assert_array_equal(_find_subarray_kmp(X, 1, 3),
-                                  np.array([4, 1, 2]))
-    np.testing.assert_array_equal(_find_subarray_kmp(X, 1, 2), np.array([4,
-                                                                         1]))
-    X = np.array([1, 3, 6, 2, 3, 4, 1, 2, 3])
-    np.testing.assert_array_equal(_find_subarray_kmp(X, 2, 3),
-                                  np.array([4, 1, 2]))
+    tokens = np.array([1, 3, 6, 2, 3, 4, 1, 2, 3])
+    np.testing.assert_array_equal(
+        _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
+                                                       min_ngram=2,
+                                                       max_ngram=2,
+                                                       max_model_len=1024,
+                                                       k=3),
+        np.array([4, 1, 2]))
     # Return on the first match
-    np.testing.assert_array_equal(_find_subarray_kmp(X, 1, 3),
-                                  np.array([6, 2, 3]))
+    np.testing.assert_array_equal(
+        _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
+                                                       min_ngram=1,
+                                                       max_ngram=1,
+                                                       max_model_len=1024,
+                                                       k=2), np.array([6, 2]))
 
 
 def test_ngram_proposer():
@@ -56,27 +76,35 @@ def test_ngram_proposer():
 
     # No match.
     result = ngram_proposer(
-        2, 2, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 5]))
+        min_n=2, max_n=2,
+        k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 5]))
     assert result is None
 
     # No match for 4-gram.
     result = ngram_proposer(
-        4, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
+        min_n=4, max_n=4,
+        k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
    assert result is None
 
     # No match for 4-gram but match for 3-gram.
     result = ngram_proposer(
-        3, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
+        min_n=3, max_n=4,
+        k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
     assert np.array_equal(result, np.array([4, 1]))
 
     # Match for both 4-gram and 3-gram.
     # In this case, the proposer should return the 4-gram match.
-    result = ngram_proposer(3, 4, 2).propose(
+    result = ngram_proposer(min_n=3, max_n=4, k=2).propose(
         context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]))
     assert np.array_equal(result, np.array([1, 2]))  # Not [5, 1]
 
     # Match for 2-gram and 3-gram, but not 4-gram.
-    result = ngram_proposer(
-        2, 4,
-        2).propose(context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]))
+    result = ngram_proposer(min_n=2, max_n=4, k=2).propose(
+        context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]))
     assert np.array_equal(result, np.array([1, 2]))  # Not [5, 2]
+
+    # Multiple 3-grams match; always pick the first (earliest) one.
+    result = ngram_proposer(
+        min_n=3, max_n=3, k=2).propose(context_token_ids=np.array(
+            [1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]))
+    assert np.array_equal(result, np.array([100, 1]))
diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py
index 6b90d0970bd77..fbcf2cb50d371 100644
--- a/vllm/v1/spec_decode/ngram_proposer.py
+++ b/vllm/v1/spec_decode/ngram_proposer.py
@@ -11,6 +11,10 @@ from vllm.config import VllmConfig
 class NgramProposer:
 
     def __init__(self, vllm_config: VllmConfig):
+        assert vllm_config.speculative_config is not None
+        assert vllm_config.speculative_config.prompt_lookup_min is not None
+        assert vllm_config.speculative_config.prompt_lookup_max is not None
+
         # Minimum length of the n-gram to match.
         self.min_n = vllm_config.speculative_config.prompt_lookup_min
         # Maximum length of the n-gram to match.
         self.max_n = vllm_config.speculative_config.prompt_lookup_max
@@ -54,17 +58,13 @@ class NgramProposer:
             followed that pattern. Here we will return [4,2,3] because
             we only have three tokens after the match.
         """
-        # Do not generate draft tokens beyond the max model length.
-        k = min(self.k, self.max_model_len - context_token_ids.shape[0])
-        if k <= 0:
-            return None
-
-        # TODO(woosuk): Optimize this.
-        for n in range(self.max_n, self.min_n - 1, -1):
-            result = _find_subarray_kmp(context_token_ids, n, k)
-            if result is not None:
-                return result
-        return None
+        return _find_longest_matched_ngram_and_propose_tokens(
+            origin_tokens=context_token_ids,
+            min_ngram=self.min_n,
+            max_ngram=self.max_n,
+            max_model_len=self.max_model_len,
+            k=self.k)
 
     def load_model(self, *args, **kwargs):
         # No model to load.
@@ -72,61 +72,86 @@ class NgramProposer:
 
 
 @jit(nopython=True)
-def _kmp_lps_array(pattern: np.ndarray) -> np.ndarray:
+def _find_longest_matched_ngram_and_propose_tokens(
+        origin_tokens: np.ndarray, min_ngram: int, max_ngram: int,
+        max_model_len: int, k: int) -> Optional[np.ndarray]:
     """
-    Build the lps (longest proper prefix which is also suffix)
-    array for the pattern.
+    Find the longest n-gram, with length within [min_ngram, max_ngram]
+    (inclusive), which matches the suffix of the given tokens.
+
+    If found, we extract up to k tokens right after the matched n-gram.
     """
-    lps = np.zeros(len(pattern), dtype=np.int32)
-    prev_lps = 0  # length of the previous longest prefix suffix
+    # Do not generate draft tokens if context is shorter than minimum n-gram
+    total_token = origin_tokens.shape[0]
+    if total_token < min_ngram:
+        return None
+
+    # Do not generate draft tokens beyond the max model length.
+    k = min(k, max_model_len - total_token)
+    if k <= 0:
+        return None
+
+    # Flip the tokens: the goal becomes finding the longest n-gram
+    # at the rightmost position which matches a prefix of length
+    # within [min_ngram, max_ngram] (inclusive).
+    tokens = origin_tokens[::-1]
+
+    # Longest proper prefix which is also a suffix ending at
+    # the current position:
+    # lps[i] = max{v : tokens[0:v] == tokens[i+1-v:i+1]}
+    #
+    # As the n-gram length is capped at max_ngram to save memory, we
+    # only need to store lps for the first max_ngram prefixes.
+    lps = np.zeros(max_ngram, dtype=np.int32)
+
+    longest_ngram = 0
+    position = 0
+
+    # lps[0] is always 0, so we start at index 1
+    prev_lps = 0
     i = 1
-
-    while i < len(pattern):
-        if pattern[i] == pattern[prev_lps]:
+    while i < total_token:
+        # tokens[:prev_lps] is the longest prefix that is a suffix of
+        # tokens[:i]
+        if tokens[prev_lps] == tokens[i]:
+            # Token match: tokens[:prev_lps+1] is the longest prefix
+            # that is a suffix of tokens[:i+1]
             prev_lps += 1
-            lps[i] = prev_lps
+            # Check if we found a longer valid n-gram.
+            #
+            # Update position whenever prev_lps >= longest_ngram,
+            # as we want the target n-gram at the earliest position
+            # in the original tokens (i.e. the
+            # latest position in the reversed tokens)
+            if prev_lps >= longest_ngram:
+                longest_ngram = prev_lps
+                position = i
+            if i < max_ngram:
+                # Store lps for the first max_ngram prefixes
+                lps[i] = prev_lps
+            if prev_lps == max_ngram:
+                # Once prev_lps reaches max_ngram, reset prev_lps
+                # to lps[max_ngram-1] to avoid matching an n-gram
+                # longer than max_ngram
+                prev_lps = lps[max_ngram - 1]
             i += 1
+        elif prev_lps != 0:
+            # Token mismatch: try the second longest prefix
+            # among all suffixes of tokens[:i],
+            # which is the longest prefix of tokens[:prev_lps]
+            prev_lps = lps[prev_lps - 1]
         else:
-            if prev_lps != 0:
-                prev_lps = lps[prev_lps - 1]
-            else:
-                lps[i] = 0
-                i += 1
-    return lps
-
-
-@jit(nopython=True)
-def _find_subarray_kmp(
-    context_token_ids: np.ndarray,
-    n: int,
-    k: int,
-) -> Optional[np.ndarray]:
-    context_len = context_token_ids.shape[0]
-    assert n > 0
-
-    pattern = context_token_ids[-n:]
-    # Precompute lps array for Y
-    lps = _kmp_lps_array(pattern)
-
-    i = 0
-    j = 0
-    # -n because the last n tokens are used as pattern
-    while i < context_len - n:
-        if context_token_ids[i] == pattern[j]:
+            # Token mismatch, and no shorter nonempty prefix
+            # is a suffix of tokens[:i]
             i += 1
-            j += 1
-            # If we have matched the entire Y
-            if j == n:
-                # Found pattern in context, gather the next K elements
-                return context_token_ids[i:i + k]
-        else:
-            # Mismatch
-            if j != 0:
-                # Use the lps array to avoid re-checking elements
-                j = lps[j - 1]
-            else:
-                i += 1
 
+    if longest_ngram < min_ngram:
+        # No valid n-gram was found
+        return None
-
-    # Y not found
-    return None
+
+    # Flip the position back: in origin_tokens,
+    # origin_tokens[total_token-1-position:total_token-1-position+longest_ngram]
+    # is the matched n-gram, so we should start drafting tokens from
+    # total_token-1-position+longest_ngram
+    start_position = total_token - 1 - position + longest_ngram
+    k = min(k, total_token - start_position)
+    return origin_tokens[start_position:start_position + k]
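The reversed-token LPS scan above runs in O(total_token) time but takes some effort to follow. Below is a naive quadratic reference with the same contract, useful for differential testing against the Numba-compiled version. This is a sketch for cross-checking only; `naive_propose` is a hypothetical helper, not part of this patch:

```python
from typing import Optional

import numpy as np


def naive_propose(origin_tokens: np.ndarray, min_ngram: int, max_ngram: int,
                  max_model_len: int, k: int) -> Optional[np.ndarray]:
    """O(total * max_ngram) reference for
    _find_longest_matched_ngram_and_propose_tokens."""
    total = origin_tokens.shape[0]
    if total < min_ngram:
        return None
    # Do not generate draft tokens beyond the max model length.
    k = min(k, max_model_len - total)
    if k <= 0:
        return None
    # Prefer the longest matching n-gram; among equal lengths, the earliest.
    for n in range(min(max_ngram, total - 1), min_ngram - 1, -1):
        suffix = origin_tokens[total - n:]
        for start in range(total - n):  # exclude the suffix matching itself
            if np.array_equal(origin_tokens[start:start + n], suffix):
                end = start + n
                # Draft up to k tokens that followed the matched n-gram.
                return origin_tokens[end:end + min(k, total - end)]
    return None
```

For example, `naive_propose(np.array([1, 2, 3, 4, 1, 2, 3]), 2, 2, 1024, 3)` returns `[4, 1, 2]`, matching the unit tests above; comparing both functions over random token arrays is a cheap way to gain confidence in the optimized scan.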