[V1][Spec Decode] Optimize N-gram matching with Numba (#13365)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent c8d70e2437 · commit 4c82229898
@@ -1,6 +1,7 @@
 psutil
 sentencepiece  # Required for LLaMA tokenizer.
 numpy < 2.0.0
+numba == 0.60.0  # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding.
 requests >= 2.26.0
 tqdm
 blake3
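The pin matters because Numba is imported at module load and JIT-compiles the matcher at runtime, so the package must import cleanly on every supported Python. A quick environment check (illustrative only, not part of this commit):

    import sys

    # Hypothetical sanity check: confirm the pinned numba imports under this
    # Python before relying on N-gram speculative decoding.
    try:
        import numba
        print(f"numba {numba.__version__} on Python "
              f"{sys.version_info.major}.{sys.version_info.minor}")
    except ImportError:
        print("numba missing; N-gram speculative decoding unavailable")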
@@ -1,14 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import List, Optional
+from typing import Optional

 import numpy as np
+from numba import jit


 class NgramProposer:

     def __init__(self):
         pass

     def propose(
         self,
         context_token_ids: np.ndarray,
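`from numba import jit` is what the rest of the diff builds on: `@jit(nopython=True)` compiles a function to machine code on its first call and rejects anything that falls back to Python objects, which is why the helpers below trade `List[int]` for `np.ndarray`. A minimal standalone sketch of the pattern (not from this commit):

    import numpy as np
    from numba import jit

    @jit(nopython=True)
    def first_index_of(arr: np.ndarray, value: int) -> int:
        # Plain loops over typed NumPy arrays are exactly what
        # nopython mode compiles well.
        for i in range(arr.shape[0]):
            if arr[i] == value:
                return i
        return -1

    tokens = np.array([5, 7, 9, 7], dtype=np.int32)
    print(first_index_of(tokens, 7))  # -> 1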
@@ -21,7 +19,7 @@ class NgramProposer:
         that match.

         Args:
-            context_token_ids: List of token IDs representing the
+            context_token_ids: Numpy array of token IDs representing the
                 context sequence.
             n: Length of the n-gram to match.
             k: Number of tokens follow the match. If there are less
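Behavior is unchanged by this hunk; only the docstring now reflects the array-based signature. A small usage sketch of the documented example (the import path is assumed, not shown in this diff):

    import numpy as np
    from vllm.v1.spec_decode.ngram_proposer import NgramProposer  # assumed path

    proposer = NgramProposer()
    context = np.array([1, 2, 3, 4, 2, 3], dtype=np.int32)
    # The last n=2 tokens [2, 3] also occur earlier; only 3 tokens follow
    # that match, so [4, 2, 3] comes back even though k=4.
    print(proposer.propose(context, n=2, k=4))  # -> [4 2 3]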
@@ -41,66 +39,65 @@
             followed that pattern. Here we will return [4,2,3] because
             we only have three tokens after the match.
         """
-        # TODO: Use c++ to implement the _find_subarray_kmp to
-        # improve the efficiency
-        return self._find_subarray_kmp(context_token_ids, n, k)
-
-    @staticmethod
-    def _kmp_lps_array(pattern: List[int]) -> List[int]:
-        """
-        Build the lps (longest proper prefix which is also suffix)
-        array for the pattern.
-        """
-        lps = [0] * len(pattern)
-        prev_lps = 0  # length of the previous longest prefix suffix
-        i = 1
-
-        while i < len(pattern):
-            if pattern[i] == pattern[prev_lps]:
-                prev_lps += 1
-                lps[i] = prev_lps
-                i += 1
-            else:
-                if prev_lps != 0:
-                    prev_lps = lps[prev_lps - 1]
-                else:
-                    lps[i] = 0
-                    i += 1
-
-        return lps
-
-    @staticmethod
-    def _find_subarray_kmp(
-        context_token_ids: np.ndarray,
-        n: int,
-        k: int,
-    ) -> Optional[np.ndarray]:
-        context_len = context_token_ids.shape[0]
-        assert n > 0
-
-        pattern = context_token_ids[-n:]
-        # Precompute lps array for Y
-        lps = NgramProposer._kmp_lps_array(pattern)
-
-        i = 0
-        j = 0
-        # -n because the last n tokens are used as pattern
-        while i < context_len - n:
-            if context_token_ids[i] == pattern[j]:
-                i += 1
-                j += 1
-
-                # If we have matched the entire Y
-                if j == n:
-                    # Found pattern in context, gather the next K elements
-                    return context_token_ids[i:i + k]
-            else:
-                # Mismatch
-                if j != 0:
-                    # Use the lps array to avoid re-checking elements
-                    j = lps[j - 1]
-                else:
-                    i += 1
-
-        # Y not found
-        return None
+        return _find_subarray_kmp(context_token_ids, n, k)
+
+
+@jit(nopython=True)
+def _kmp_lps_array(pattern: np.ndarray) -> np.ndarray:
+    """
+    Build the lps (longest proper prefix which is also suffix)
+    array for the pattern.
+    """
+    lps = np.zeros(len(pattern), dtype=np.int32)
+    prev_lps = 0  # length of the previous longest prefix suffix
+    i = 1
+
+    while i < len(pattern):
+        if pattern[i] == pattern[prev_lps]:
+            prev_lps += 1
+            lps[i] = prev_lps
+            i += 1
+        else:
+            if prev_lps != 0:
+                prev_lps = lps[prev_lps - 1]
+            else:
+                lps[i] = 0
+                i += 1
+
+    return lps
+
+
+@jit(nopython=True)
+def _find_subarray_kmp(
+    context_token_ids: np.ndarray,
+    n: int,
+    k: int,
+) -> Optional[np.ndarray]:
+    context_len = context_token_ids.shape[0]
+    assert n > 0
+
+    pattern = context_token_ids[-n:]
+    # Precompute lps array for Y
+    lps = _kmp_lps_array(pattern)
+
+    i = 0
+    j = 0
+    # -n because the last n tokens are used as pattern
+    while i < context_len - n:
+        if context_token_ids[i] == pattern[j]:
+            i += 1
+            j += 1
+
+            # If we have matched the entire Y
+            if j == n:
+                # Found pattern in context, gather the next K elements
+                return context_token_ids[i:i + k]
+        else:
+            # Mismatch
+            if j != 0:
+                # Use the lps array to avoid re-checking elements
+                j = lps[j - 1]
+            else:
+                i += 1

+    # Y not found
+    return None
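For readers unfamiliar with KMP: the `lps` table records, for each prefix of the pattern, the length of the longest proper prefix that is also a suffix, which lets the matcher resume after a mismatch without re-scanning. A standalone rerun of the same construction (illustrative, mirrors `_kmp_lps_array` minus the JIT decorator):

    import numpy as np

    def lps_array(pattern: np.ndarray) -> np.ndarray:
        # Same failure-function construction as _kmp_lps_array above.
        lps = np.zeros(len(pattern), dtype=np.int32)
        prev_lps, i = 0, 1
        while i < len(pattern):
            if pattern[i] == pattern[prev_lps]:
                prev_lps += 1
                lps[i] = prev_lps
                i += 1
            elif prev_lps != 0:
                prev_lps = lps[prev_lps - 1]
            else:
                lps[i] = 0
                i += 1
        return lps

    # Prefix [1] recurs as a suffix at index 2; prefix [1, 2] recurs as a
    # suffix ending at index 4.
    print(lps_array(np.array([1, 2, 1, 1, 2], dtype=np.int32)))  # -> [0 0 1 1 2]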
@@ -120,11 +120,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # Set up speculative decoding.
         self.use_spec_decode = False
         if self.speculative_config:
+            self.use_spec_decode = True
+
             # TODO: find a better way to check if we are using ngram.
             assert self.speculative_config.ngram_prompt_lookup_min, \
                 "Currently, only ngram spec decode is supported in V1."
-            self.drafter = NgramProposer()
-            self.use_spec_decode = True
+            if get_pp_group().is_last_rank:
+                self.drafter = NgramProposer()
+                # Trigger Numba JIT compilation for N-gram proposer.
+                # This usually takes less than 1 second.
+                self.drafter.propose(
+                    np.zeros(1024, dtype=np.int32),
+                    self.speculative_config.ngram_prompt_lookup_min,
+                    self.speculative_config.num_speculative_tokens,
+                )

         # Request states.
         self.requests: Dict[str, CachedRequestState] = {}
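The dummy `propose` call exists because `@jit(nopython=True)` compiles lazily on the first invocation; doing it during model-runner setup moves the roughly one-second compile off the request path. A rough way to observe that one-time cost (a sketch; timings are machine-dependent):

    import time

    import numpy as np
    from numba import jit

    @jit(nopython=True)
    def count_matches(arr: np.ndarray, value: int) -> int:
        total = 0
        for x in arr:
            if x == value:
                total += 1
        return total

    tokens = np.zeros(1024, dtype=np.int32)

    start = time.perf_counter()
    count_matches(tokens, 0)  # first call: compile + run
    print(f"first call:  {time.perf_counter() - start:.3f}s")

    start = time.perf_counter()
    count_matches(tokens, 0)  # later calls: compiled code only
    print(f"second call: {time.perf_counter() - start:.6f}s")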