mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-02 05:11:18 +08:00
[V1][Spec Decode] Optimize N-gram matching with Numba (#13365)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
c8d70e2437
commit
4c82229898
@ -1,6 +1,7 @@
|
|||||||
psutil
|
psutil
|
||||||
sentencepiece # Required for LLaMA tokenizer.
|
sentencepiece # Required for LLaMA tokenizer.
|
||||||
numpy < 2.0.0
|
numpy < 2.0.0
|
||||||
|
numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding.
|
||||||
requests >= 2.26.0
|
requests >= 2.26.0
|
||||||
tqdm
|
tqdm
|
||||||
blake3
|
blake3
|
||||||
|
|||||||
@ -1,14 +1,12 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from numba import jit
|
||||||
|
|
||||||
|
|
||||||
class NgramProposer:
|
class NgramProposer:
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def propose(
|
def propose(
|
||||||
self,
|
self,
|
||||||
context_token_ids: np.ndarray,
|
context_token_ids: np.ndarray,
|
||||||
@ -21,7 +19,7 @@ class NgramProposer:
|
|||||||
that match.
|
that match.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
context_token_ids: List of token IDs representing the
|
context_token_ids: Numpy array of token IDs representing the
|
||||||
context sequence.
|
context sequence.
|
||||||
n: Length of the n-gram to match.
|
n: Length of the n-gram to match.
|
||||||
k: Number of tokens follow the match. If there are less
|
k: Number of tokens follow the match. If there are less
|
||||||
@ -41,66 +39,65 @@ class NgramProposer:
|
|||||||
followed that pattern. Here we will return [4,2,3] because
|
followed that pattern. Here we will return [4,2,3] because
|
||||||
we only have three tokens after the match.
|
we only have three tokens after the match.
|
||||||
"""
|
"""
|
||||||
# TODO: Use c++ to implement the _find_subarray_kmp to
|
return _find_subarray_kmp(context_token_ids, n, k)
|
||||||
# improve the efficiency
|
|
||||||
return self._find_subarray_kmp(context_token_ids, n, k)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _kmp_lps_array(pattern: List[int]) -> List[int]:
|
|
||||||
"""
|
|
||||||
Build the lps (longest proper prefix which is also suffix)
|
|
||||||
array for the pattern.
|
|
||||||
"""
|
|
||||||
lps = [0] * len(pattern)
|
|
||||||
prev_lps = 0 # length of the previous longest prefix suffix
|
|
||||||
i = 1
|
|
||||||
|
|
||||||
while i < len(pattern):
|
@jit(nopython=True)
|
||||||
if pattern[i] == pattern[prev_lps]:
|
def _kmp_lps_array(pattern: np.ndarray) -> np.ndarray:
|
||||||
prev_lps += 1
|
"""
|
||||||
lps[i] = prev_lps
|
Build the lps (longest proper prefix which is also suffix)
|
||||||
i += 1
|
array for the pattern.
|
||||||
|
"""
|
||||||
|
lps = np.zeros(len(pattern), dtype=np.int32)
|
||||||
|
prev_lps = 0 # length of the previous longest prefix suffix
|
||||||
|
i = 1
|
||||||
|
|
||||||
|
while i < len(pattern):
|
||||||
|
if pattern[i] == pattern[prev_lps]:
|
||||||
|
prev_lps += 1
|
||||||
|
lps[i] = prev_lps
|
||||||
|
i += 1
|
||||||
|
else:
|
||||||
|
if prev_lps != 0:
|
||||||
|
prev_lps = lps[prev_lps - 1]
|
||||||
else:
|
else:
|
||||||
if prev_lps != 0:
|
lps[i] = 0
|
||||||
prev_lps = lps[prev_lps - 1]
|
|
||||||
else:
|
|
||||||
lps[i] = 0
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
return lps
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _find_subarray_kmp(
|
|
||||||
context_token_ids: np.ndarray,
|
|
||||||
n: int,
|
|
||||||
k: int,
|
|
||||||
) -> Optional[np.ndarray]:
|
|
||||||
context_len = context_token_ids.shape[0]
|
|
||||||
assert n > 0
|
|
||||||
|
|
||||||
pattern = context_token_ids[-n:]
|
|
||||||
# Precompute lps array for Y
|
|
||||||
lps = NgramProposer._kmp_lps_array(pattern)
|
|
||||||
|
|
||||||
i = 0
|
|
||||||
j = 0
|
|
||||||
# -n because the last n tokens are used as pattern
|
|
||||||
while i < context_len - n:
|
|
||||||
if context_token_ids[i] == pattern[j]:
|
|
||||||
i += 1
|
i += 1
|
||||||
j += 1
|
return lps
|
||||||
|
|
||||||
# If we have matched the entire Y
|
|
||||||
if j == n:
|
@jit(nopython=True)
|
||||||
# Found pattern in context, gather the next K elements
|
def _find_subarray_kmp(
|
||||||
return context_token_ids[i:i + k]
|
context_token_ids: np.ndarray,
|
||||||
|
n: int,
|
||||||
|
k: int,
|
||||||
|
) -> Optional[np.ndarray]:
|
||||||
|
context_len = context_token_ids.shape[0]
|
||||||
|
assert n > 0
|
||||||
|
|
||||||
|
pattern = context_token_ids[-n:]
|
||||||
|
# Precompute lps array for Y
|
||||||
|
lps = _kmp_lps_array(pattern)
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
j = 0
|
||||||
|
# -n because the last n tokens are used as pattern
|
||||||
|
while i < context_len - n:
|
||||||
|
if context_token_ids[i] == pattern[j]:
|
||||||
|
i += 1
|
||||||
|
j += 1
|
||||||
|
|
||||||
|
# If we have matched the entire Y
|
||||||
|
if j == n:
|
||||||
|
# Found pattern in context, gather the next K elements
|
||||||
|
return context_token_ids[i:i + k]
|
||||||
|
else:
|
||||||
|
# Mismatch
|
||||||
|
if j != 0:
|
||||||
|
# Use the lps array to avoid re-checking elements
|
||||||
|
j = lps[j - 1]
|
||||||
else:
|
else:
|
||||||
# Mismatch
|
i += 1
|
||||||
if j != 0:
|
|
||||||
# Use the lps array to avoid re-checking elements
|
|
||||||
j = lps[j - 1]
|
|
||||||
else:
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
# Y not found
|
# Y not found
|
||||||
return None
|
return None
|
||||||
|
|||||||
@ -120,11 +120,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# Set up speculative decoding.
|
# Set up speculative decoding.
|
||||||
self.use_spec_decode = False
|
self.use_spec_decode = False
|
||||||
if self.speculative_config:
|
if self.speculative_config:
|
||||||
|
self.use_spec_decode = True
|
||||||
|
|
||||||
# TODO: find a better way to check if we are using ngram.
|
# TODO: find a better way to check if we are using ngram.
|
||||||
assert self.speculative_config.ngram_prompt_lookup_min, \
|
assert self.speculative_config.ngram_prompt_lookup_min, \
|
||||||
"Currently, only ngram spec decode is supported in V1."
|
"Currently, only ngram spec decode is supported in V1."
|
||||||
self.drafter = NgramProposer()
|
if get_pp_group().is_last_rank:
|
||||||
self.use_spec_decode = True
|
self.drafter = NgramProposer()
|
||||||
|
# Trigger Numba JIT compilation for N-gram proposer.
|
||||||
|
# This usually takes less than 1 second.
|
||||||
|
self.drafter.propose(
|
||||||
|
np.zeros(1024, dtype=np.int32),
|
||||||
|
self.speculative_config.ngram_prompt_lookup_min,
|
||||||
|
self.speculative_config.num_speculative_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
# Request states.
|
# Request states.
|
||||||
self.requests: Dict[str, CachedRequestState] = {}
|
self.requests: Dict[str, CachedRequestState] = {}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user