[Bugfix] Fix v1/spec_decode/test_ngram.py (#16895)

Signed-off-by: qizixi <qizixi@meta.com>
This commit is contained in:
qizixi 2025-04-20 20:54:29 -07:00 committed by GitHub
parent fe742aef5a
commit bb3605db85
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 30 additions and 38 deletions

View File

@@ -2,6 +2,7 @@
import numpy as np import numpy as np
from vllm.config import SpeculativeConfig, VllmConfig
from vllm.v1.spec_decode.ngram_proposer import (NgramProposer, from vllm.v1.spec_decode.ngram_proposer import (NgramProposer,
_find_subarray_kmp, _find_subarray_kmp,
_kmp_lps_array) _kmp_lps_array)
@@ -39,50 +40,40 @@ def test_find_subarray_kmp():
def test_ngram_proposer(): def test_ngram_proposer():
proposer = NgramProposer()
def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
return NgramProposer(vllm_config=VllmConfig(
speculative_config=SpeculativeConfig.from_dict(
{
"prompt_lookup_min": min_n,
"prompt_lookup_max": max_n,
"num_speculative_tokens": k,
"method": "ngram",
})))
# No match. # No match.
result = proposer.propose( result = ngram_proposer(
context_token_ids=np.array([1, 2, 3, 4, 5]), 2, 2, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 5]))
min_n=2,
max_n=2,
k=2,
)
assert result is None assert result is None
# No match for 4-gram. # No match for 4-gram.
result = proposer.propose( result = ngram_proposer(
context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]), 4, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
min_n=4,
max_n=4,
k=2,
)
assert result is None assert result is None
# No match for 4-gram but match for 3-gram. # No match for 4-gram but match for 3-gram.
result = proposer.propose( result = ngram_proposer(
context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]), 3, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
min_n=3,
max_n=4,
k=2,
)
assert np.array_equal(result, np.array([4, 1])) assert np.array_equal(result, np.array([4, 1]))
# Match for both 4-gram and 3-gram. # Match for both 4-gram and 3-gram.
# In this case, the proposer should return the 4-gram match. # In this case, the proposer should return the 4-gram match.
result = proposer.propose( result = ngram_proposer(3, 4, 2).propose(
context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]), context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]))
min_n=3,
max_n=4,
k=2,
)
assert np.array_equal(result, np.array([1, 2])) # Not [5, 1] assert np.array_equal(result, np.array([1, 2])) # Not [5, 1]
# Match for 2-gram and 3-gram, but not 4-gram. # Match for 2-gram and 3-gram, but not 4-gram.
result = proposer.propose( result = ngram_proposer(
context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]), 2, 4,
min_n=2, 2).propose(context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]))
max_n=4,
k=2,
)
assert np.array_equal(result, np.array([1, 2])) # Not [5, 2] assert np.array_equal(result, np.array([1, 2])) # Not [5, 2]

View File

@@ -120,7 +120,7 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]:
def pairwise(iterable): def pairwise(iterable):
""" """
Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise
Can be removed when Python 3.9 support is dropped. Can be removed when Python 3.9 support is dropped.
""" """
iterator = iter(iterable) iterator = iter(iterable)
@@ -266,7 +266,7 @@ class ModelConfig:
config_format: The config format which shall be loaded. config_format: The config format which shall be loaded.
Defaults to 'auto' which defaults to 'hf'. Defaults to 'auto' which defaults to 'hf'.
hf_token: The token to use as HTTP bearer authorization for remote files hf_token: The token to use as HTTP bearer authorization for remote files
. If `True`, will use the token generated when running . If `True`, will use the token generated when running
`huggingface-cli login` (stored in `~/.huggingface`). `huggingface-cli login` (stored in `~/.huggingface`).
hf_overrides: If a dictionary, contains arguments to be forwarded to the hf_overrides: If a dictionary, contains arguments to be forwarded to the
HuggingFace config. If a callable, it is called to update the HuggingFace config. If a callable, it is called to update the
@@ -1624,7 +1624,7 @@ class ParallelConfig:
"""The full name of the worker class to use. If "auto", the worker class """The full name of the worker class to use. If "auto", the worker class
will be determined based on the platform.""" will be determined based on the platform."""
sd_worker_cls: str = "auto" sd_worker_cls: str = "auto"
"""The full name of the worker class to use for speculative decoding. """The full name of the worker class to use for speculative decoding.
If "auto", the worker class will be determined based on the platform.""" If "auto", the worker class will be determined based on the platform."""
worker_extension_cls: str = "" worker_extension_cls: str = ""
"""The full name of the worker extension class to use. The worker extension """The full name of the worker extension class to use. The worker extension
@@ -1815,13 +1815,13 @@ class SchedulerConfig:
max_num_batched_tokens: int = None # type: ignore max_num_batched_tokens: int = None # type: ignore
"""Maximum number of tokens to be processed in a single iteration. """Maximum number of tokens to be processed in a single iteration.
This config has no static default. If left unspecified by the user, it will This config has no static default. If left unspecified by the user, it will
be set in `EngineArgs.create_engine_config` based on the usage context.""" be set in `EngineArgs.create_engine_config` based on the usage context."""
max_num_seqs: int = None # type: ignore max_num_seqs: int = None # type: ignore
"""Maximum number of sequences to be processed in a single iteration. """Maximum number of sequences to be processed in a single iteration.
This config has no static default. If left unspecified by the user, it will This config has no static default. If left unspecified by the user, it will
be set in `EngineArgs.create_engine_config` based on the usage context.""" be set in `EngineArgs.create_engine_config` based on the usage context."""
@@ -1867,7 +1867,7 @@ class SchedulerConfig:
# TODO (ywang96): Make this configurable. # TODO (ywang96): Make this configurable.
max_num_encoder_input_tokens: int = field(init=False) max_num_encoder_input_tokens: int = field(init=False)
"""Multimodal encoder compute budget, only used in V1. """Multimodal encoder compute budget, only used in V1.
NOTE: This is not currently configurable. It will be overridden by NOTE: This is not currently configurable. It will be overridden by
max_num_batched_tokens in case max multimodal embedding size is larger.""" max_num_batched_tokens in case max multimodal embedding size is larger."""
@@ -2306,7 +2306,8 @@ class SpeculativeConfig:
if self.model is None and self.num_speculative_tokens is not None: if self.model is None and self.num_speculative_tokens is not None:
# TODO(Shangming): Refactor mtp configuration logic when supporting # TODO(Shangming): Refactor mtp configuration logic when supporting
# mtp acceleration for more models besides deepseek_v3 # mtp acceleration for more models besides deepseek_v3
if self.target_model_config.hf_text_config.model_type \ if self.target_model_config and \
self.target_model_config.hf_text_config.model_type \
== "deepseek_v3": == "deepseek_v3":
# use the draft model from the same model: # use the draft model from the same model:
self.model = self.target_model_config.model self.model = self.target_model_config.model