[V1][Spec Decode] Optimize N-gram matching with Numba (#13365)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-02-18 13:19:58 -08:00 committed by GitHub
parent c8d70e2437
commit 4c82229898
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 70 additions and 63 deletions

View File

@ -1,6 +1,7 @@
psutil
sentencepiece # Required for LLaMA tokenizer.
numpy < 2.0.0
numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding.
requests >= 2.26.0
tqdm
blake3

View File

@ -1,14 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional
from typing import Optional
import numpy as np
from numba import jit
class NgramProposer:
    """Draft-token proposer for N-gram speculative decoding (V1).

    The class itself is stateless; the matching work is delegated to the
    module-level, Numba-jitted ``_find_subarray_kmp`` so the hot loop runs
    as compiled code instead of Python bytecode.
    """

    def __init__(self):
        # No per-instance state is needed; kept so the proposer has a
        # uniform construction interface.
        pass

    def propose(
        self,
        context_token_ids: np.ndarray,
        n: int,
        k: int,
    ) -> Optional[np.ndarray]:
        """Propose the next tokens via N-gram pattern matching.

        Looks for an earlier occurrence of the last ``n`` context tokens
        (the pattern) inside the context and returns the tokens that
        immediately followed that occurrence.

        Args:
            context_token_ids: Numpy array of token IDs representing the
                context sequence.
            n: Length of the n-gram to match.
            k: Number of tokens to return after the match. If fewer than
                ``k`` tokens follow the match, all tokens up to the end
                of the context are returned. E.g. if the pattern is
                followed by only three tokens ``[4, 2, 3]`` and ``k`` is
                larger, ``[4, 2, 3]`` is returned.

        Returns:
            Numpy array of up to ``k`` draft token IDs, or ``None`` if
            the pattern is not found earlier in the context.
        """
        return _find_subarray_kmp(context_token_ids, n, k)
@staticmethod
def _kmp_lps_array(pattern: List[int]) -> List[int]:
"""
Build the lps (longest proper prefix which is also suffix)
array for the pattern.
"""
lps = [0] * len(pattern)
prev_lps = 0 # length of the previous longest prefix suffix
i = 1
while i < len(pattern):
if pattern[i] == pattern[prev_lps]:
prev_lps += 1
lps[i] = prev_lps
i += 1
@jit(nopython=True)
def _kmp_lps_array(pattern: np.ndarray) -> np.ndarray:
"""
Build the lps (longest proper prefix which is also suffix)
array for the pattern.
"""
lps = np.zeros(len(pattern), dtype=np.int32)
prev_lps = 0 # length of the previous longest prefix suffix
i = 1
while i < len(pattern):
if pattern[i] == pattern[prev_lps]:
prev_lps += 1
lps[i] = prev_lps
i += 1
else:
if prev_lps != 0:
prev_lps = lps[prev_lps - 1]
else:
if prev_lps != 0:
prev_lps = lps[prev_lps - 1]
else:
lps[i] = 0
i += 1
return lps
@staticmethod
def _find_subarray_kmp(
context_token_ids: np.ndarray,
n: int,
k: int,
) -> Optional[np.ndarray]:
context_len = context_token_ids.shape[0]
assert n > 0
pattern = context_token_ids[-n:]
# Precompute lps array for Y
lps = NgramProposer._kmp_lps_array(pattern)
i = 0
j = 0
# -n because the last n tokens are used as pattern
while i < context_len - n:
if context_token_ids[i] == pattern[j]:
lps[i] = 0
i += 1
j += 1
return lps
# If we have matched the entire Y
if j == n:
# Found pattern in context, gather the next K elements
return context_token_ids[i:i + k]
@jit(nopython=True)
def _find_subarray_kmp(
context_token_ids: np.ndarray,
n: int,
k: int,
) -> Optional[np.ndarray]:
context_len = context_token_ids.shape[0]
assert n > 0
pattern = context_token_ids[-n:]
# Precompute lps array for Y
lps = _kmp_lps_array(pattern)
i = 0
j = 0
# -n because the last n tokens are used as pattern
while i < context_len - n:
if context_token_ids[i] == pattern[j]:
i += 1
j += 1
# If we have matched the entire Y
if j == n:
# Found pattern in context, gather the next K elements
return context_token_ids[i:i + k]
else:
# Mismatch
if j != 0:
# Use the lps array to avoid re-checking elements
j = lps[j - 1]
else:
# Mismatch
if j != 0:
# Use the lps array to avoid re-checking elements
j = lps[j - 1]
else:
i += 1
i += 1
# Y not found
return None
# Y not found
return None

View File

@ -120,11 +120,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Set up speculative decoding.
self.use_spec_decode = False
if self.speculative_config:
self.use_spec_decode = True
# TODO: find a better way to check if we are using ngram.
assert self.speculative_config.ngram_prompt_lookup_min, \
"Currently, only ngram spec decode is supported in V1."
self.drafter = NgramProposer()
self.use_spec_decode = True
if get_pp_group().is_last_rank:
self.drafter = NgramProposer()
# Trigger Numba JIT compilation for N-gram proposer.
# This usually takes less than 1 second.
self.drafter.propose(
np.zeros(1024, dtype=np.int32),
self.speculative_config.ngram_prompt_lookup_min,
self.speculative_config.num_speculative_tokens,
)
# Request states.
self.requests: Dict[str, CachedRequestState] = {}