Mirror of https://git.datalinker.icu/vllm-project/vllm.git — synced 2025-12-10 09:55:46 +08:00
[Bugfix] Fix v1/spec_decode/test_ngram.py (#16895)
Signed-off-by: qizixi <qizixi@meta.com>
This commit is contained in:
parent
fe742aef5a
commit
bb3605db85
@@ -2,6 +2,7 @@
 import numpy as np
 
+from vllm.config import SpeculativeConfig, VllmConfig
 from vllm.v1.spec_decode.ngram_proposer import (NgramProposer,
                                                 _find_subarray_kmp,
                                                 _kmp_lps_array)
@@ -39,50 +40,40 @@ def test_find_subarray_kmp():
def test_ngram_proposer():
    """End-to-end checks of ``NgramProposer.propose``.

    Post-commit shape of this test: the ngram lookup parameters
    (``prompt_lookup_min``/``prompt_lookup_max``) and the speculation depth
    (``num_speculative_tokens``) are supplied through the proposer's
    ``VllmConfig`` at construction time, not per ``propose()`` call — so a
    fresh proposer is built for every scenario via the local helper below.
    """

    def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
        # Build a proposer whose n-gram window is [min_n, max_n] and which
        # speculates k tokens per match.
        return NgramProposer(vllm_config=VllmConfig(
            speculative_config=SpeculativeConfig.from_dict(
                {
                    "prompt_lookup_min": min_n,
                    "prompt_lookup_max": max_n,
                    "num_speculative_tokens": k,
                    "method": "ngram",
                })))

    # No match.
    result = ngram_proposer(
        2, 2, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 5]))
    assert result is None

    # No match for 4-gram.
    result = ngram_proposer(
        4, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
    assert result is None

    # No match for 4-gram but match for 3-gram.
    result = ngram_proposer(
        3, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
    assert np.array_equal(result, np.array([4, 1]))

    # Match for both 4-gram and 3-gram.
    # In this case, the proposer should return the 4-gram match.
    result = ngram_proposer(3, 4, 2).propose(
        context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]))
    assert np.array_equal(result, np.array([1, 2]))  # Not [5, 1]

    # Match for 2-gram and 3-gram, but not 4-gram.
    result = ngram_proposer(
        2, 4,
        2).propose(context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]))
    assert np.array_equal(result, np.array([1, 2]))  # Not [5, 2]
|
||||
|
||||
@@ -2306,7 +2306,8 @@ class SpeculativeConfig:
         if self.model is None and self.num_speculative_tokens is not None:
             # TODO(Shangming): Refactor mtp configuration logic when supporting
             # mtp acceleration for more models besides deepseek_v3
-            if self.target_model_config.hf_text_config.model_type \
+            if self.target_model_config and \
+                    self.target_model_config.hf_text_config.model_type \
                     == "deepseek_v3":
                 # use the draft model from the same model:
                 self.model = self.target_model_config.model
Loading…
x
Reference in New Issue
Block a user