diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py
index 8f6d20d11ff3d..7e548bb48b57c 100644
--- a/vllm/v1/spec_decode/ngram_proposer.py
+++ b/vllm/v1/spec_decode/ngram_proposer.py
@@ -10,14 +10,21 @@ from vllm.config import VllmConfig
 
 
 class NgramProposer:
 
     def __init__(self, vllm_config: VllmConfig):
-        self.vllm_config = vllm_config
+        # Minimum length of the n-gram to match.
+        self.min_n = vllm_config.speculative_config.prompt_lookup_min
+        # Maximum length of the n-gram to match.
+        self.max_n = vllm_config.speculative_config.prompt_lookup_max
+        # Number of tokens that follow the match. If fewer than k
+        # tokens follow the match, return as many tokens as are
+        # available up to the end of the context.
+        self.k = vllm_config.speculative_config.num_speculative_tokens
+        # Trigger Numba JIT compilation for the N-gram proposer.
+        # This usually takes less than 1 second.
+        self.propose(np.zeros(1024, dtype=np.int32))
 
     def propose(
         self,
         context_token_ids: np.ndarray,
-        min_n: int,
-        max_n: int,
-        k: int,
     ) -> Optional[np.ndarray]:
         """Proposes the next sequence of tokens based on n-gram pattern
         matching in the context. The function finds matches of the last n
@@ -27,17 +34,12 @@ class NgramProposer:
         Args:
             context_token_ids: Numpy array of token IDs representing the
                                context sequence.
-            min_n: Minimum length of the n-gram to match.
-            max_n: Maximum length of the n-gram to match.
-            k: Number of tokens follow the match. If there are less
-               than k tokens follow the match, we will return
-               the maximum amount of tokens until the end.
-
+
         Returns:
             np.ndarray: The sequence of tokens that followed
                         the matched n-gram in the context.
             None: If no matching n-gram pattern is found.
-
+
         Example:
             If context_token_ids = [1,2,3,4,2,3], min_n = 2, max_n = 3, and
             k = 4:
@@ -49,8 +51,8 @@ class NgramProposer:
             we only have three tokens after the match.
         """
         # TODO(woosuk): Optimize this.
-        for n in range(max_n, min_n - 1, -1):
-            result = _find_subarray_kmp(context_token_ids, n, k)
+        for n in range(self.max_n, self.min_n - 1, -1):
+            result = _find_subarray_kmp(context_token_ids, n, self.k)
             if result is not None:
                 return result
         return None
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 513806332efe3..82b07c6cd3272 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1246,11 +1246,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             end_idx = start_idx + num_sampled_ids
             self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
             drafter_output = self.drafter.propose(
-                self.input_batch.token_ids_cpu[i, :end_idx],
-                self.speculative_config.prompt_lookup_min,
-                self.speculative_config.prompt_lookup_max,
-                self.speculative_config.num_speculative_tokens,
-            )
+                self.input_batch.token_ids_cpu[i, :end_idx])
             if drafter_output is None or len(drafter_output) == 0:
                 draft_token_ids.append([])
             else:
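
For reference, below is a minimal standalone sketch of the prompt-lookup scheme the propose docstring describes: find the longest suffix n-gram (between min_n and max_n tokens) that recurs earlier in the context, then return up to k tokens that follow the match. naive_propose is a hypothetical linear-scan helper written for illustration, not vLLM's Numba-accelerated _find_subarray_kmp.

import numpy as np
from typing import Optional


def naive_propose(context: np.ndarray, min_n: int, max_n: int,
                  k: int) -> Optional[np.ndarray]:
    """Linear-scan reference for n-gram prompt lookup (hypothetical)."""
    # Prefer the longest suffix n-gram; fall back to shorter ones.
    for n in range(max_n, min_n - 1, -1):
        suffix = context[-n:]
        # Look for an earlier occurrence of the suffix n-gram.
        for start in range(len(context) - n):
            if np.array_equal(context[start:start + n], suffix):
                # Return up to k tokens that follow the match; fewer if
                # the context ends first.
                follow = context[start + n:start + n + k]
                if len(follow) > 0:
                    return follow
    return None


# Mirrors the docstring example: the suffix [2, 3] recurs at index 1,
# and only three tokens follow that match, so [4, 2, 3] is proposed.
print(naive_propose(np.array([1, 2, 3, 4, 2, 3]), min_n=2, max_n=3, k=4))
# -> [4 2 3]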