Mirror of https://git.datalinker.icu/vllm-project/vllm.git — synced 2026-04-26 01:37:04 +08:00.
[V1][Spec Decode] Update N-gram Proposer Interface (#15750)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
70ad3f9e98
commit
63375f0cdb
@ -10,14 +10,21 @@ from vllm.config import VllmConfig
|
|||||||
class NgramProposer:
    """N-gram prompt-lookup proposer for speculative decoding.

    The matching parameters (``min_n``, ``max_n``, ``k``) are read once
    from ``vllm_config.speculative_config`` at construction time, so
    ``propose`` only needs the context token IDs.
    """

    def __init__(self, vllm_config: VllmConfig):
        """Initialize the proposer and warm up the JIT-compiled kernel.

        Args:
            vllm_config: Engine configuration; its ``speculative_config``
                supplies the n-gram lookup bounds and the draft length.
        """
        self.vllm_config = vllm_config
        # Minimum length of the n-gram to match.
        self.min_n = vllm_config.speculative_config.prompt_lookup_min
        # Maximum length of the n-gram to match.
        self.max_n = vllm_config.speculative_config.prompt_lookup_max
        # Number of tokens follow the match. If there are less than k
        # tokens follow the match, we will return the maximum amount of
        # tokens until the end.
        self.k = vllm_config.speculative_config.num_speculative_tokens

        # Trigger Numba JIT compilation for N-gram proposer.
        # This usually takes less than 1 second.
        self.propose(np.zeros(1024, dtype=np.int32))

    def propose(
        self,
        context_token_ids: np.ndarray,
    ) -> Optional[np.ndarray]:
        """Proposes the next sequence of tokens based on n-gram pattern
        matching in the context. The function finds matches of the last n
        tokens in the context (for n from ``self.max_n`` down to
        ``self.min_n``) and returns the tokens that followed the earliest
        such match.

        Args:
            context_token_ids: Numpy array of token IDs representing the
                context sequence.

        Returns:
            np.ndarray: The sequence of tokens that followed
                the matched n-gram in the context.
            None: If no matching n-gram pattern is found.

        Example:
            If context_token_ids = [1,2,3,4,2,3], min_n = 2, max_n = 3, and
            k = 4:
            - The last 3 tokens [4,2,3] have no earlier occurrence.
            - The last 2 tokens [2,3] match at positions 1-2, so the
              tokens that follow, [4,2,3], are returned even though
              k = 4, because
              we only have three tokens after the match.
        """
        # TODO(woosuk): Optimize this.
        # Prefer the longest matching n-gram: try max_n first, shrink
        # down to min_n, and return on the first hit.
        for n in range(self.max_n, self.min_n - 1, -1):
            result = _find_subarray_kmp(context_token_ids, n, self.k)
            if result is not None:
                return result
        return None
|
||||||
|
|||||||
@ -1246,11 +1246,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
end_idx = start_idx + num_sampled_ids
|
end_idx = start_idx + num_sampled_ids
|
||||||
self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
|
self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
|
||||||
drafter_output = self.drafter.propose(
|
drafter_output = self.drafter.propose(
|
||||||
self.input_batch.token_ids_cpu[i, :end_idx],
|
self.input_batch.token_ids_cpu[i, :end_idx])
|
||||||
self.speculative_config.prompt_lookup_min,
|
|
||||||
self.speculative_config.prompt_lookup_max,
|
|
||||||
self.speculative_config.num_speculative_tokens,
|
|
||||||
)
|
|
||||||
if drafter_output is None or len(drafter_output) == 0:
|
if drafter_output is None or len(drafter_output) == 0:
|
||||||
draft_token_ids.append([])
|
draft_token_ids.append([])
|
||||||
else:
|
else:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user