mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-17 08:29:08 +08:00
[Bugfix] Fix v1/spec_decode/test_ngram.py (#16895)
Signed-off-by: qizixi <qizixi@meta.com>
This commit is contained in:
parent
fe742aef5a
commit
bb3605db85
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from vllm.config import SpeculativeConfig, VllmConfig
|
||||||
from vllm.v1.spec_decode.ngram_proposer import (NgramProposer,
|
from vllm.v1.spec_decode.ngram_proposer import (NgramProposer,
|
||||||
_find_subarray_kmp,
|
_find_subarray_kmp,
|
||||||
_kmp_lps_array)
|
_kmp_lps_array)
|
||||||
@ -39,50 +40,40 @@ def test_find_subarray_kmp():
|
|||||||
|
|
||||||
|
|
||||||
def test_ngram_proposer():
|
def test_ngram_proposer():
|
||||||
proposer = NgramProposer()
|
|
||||||
|
def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
|
||||||
|
return NgramProposer(vllm_config=VllmConfig(
|
||||||
|
speculative_config=SpeculativeConfig.from_dict(
|
||||||
|
{
|
||||||
|
"prompt_lookup_min": min_n,
|
||||||
|
"prompt_lookup_max": max_n,
|
||||||
|
"num_speculative_tokens": k,
|
||||||
|
"method": "ngram",
|
||||||
|
})))
|
||||||
|
|
||||||
# No match.
|
# No match.
|
||||||
result = proposer.propose(
|
result = ngram_proposer(
|
||||||
context_token_ids=np.array([1, 2, 3, 4, 5]),
|
2, 2, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 5]))
|
||||||
min_n=2,
|
|
||||||
max_n=2,
|
|
||||||
k=2,
|
|
||||||
)
|
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
# No match for 4-gram.
|
# No match for 4-gram.
|
||||||
result = proposer.propose(
|
result = ngram_proposer(
|
||||||
context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]),
|
4, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
|
||||||
min_n=4,
|
|
||||||
max_n=4,
|
|
||||||
k=2,
|
|
||||||
)
|
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
# No match for 4-gram but match for 3-gram.
|
# No match for 4-gram but match for 3-gram.
|
||||||
result = proposer.propose(
|
result = ngram_proposer(
|
||||||
context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]),
|
3, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
|
||||||
min_n=3,
|
|
||||||
max_n=4,
|
|
||||||
k=2,
|
|
||||||
)
|
|
||||||
assert np.array_equal(result, np.array([4, 1]))
|
assert np.array_equal(result, np.array([4, 1]))
|
||||||
|
|
||||||
# Match for both 4-gram and 3-gram.
|
# Match for both 4-gram and 3-gram.
|
||||||
# In this case, the proposer should return the 4-gram match.
|
# In this case, the proposer should return the 4-gram match.
|
||||||
result = proposer.propose(
|
result = ngram_proposer(3, 4, 2).propose(
|
||||||
context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]),
|
context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]))
|
||||||
min_n=3,
|
|
||||||
max_n=4,
|
|
||||||
k=2,
|
|
||||||
)
|
|
||||||
assert np.array_equal(result, np.array([1, 2])) # Not [5, 1]
|
assert np.array_equal(result, np.array([1, 2])) # Not [5, 1]
|
||||||
|
|
||||||
# Match for 2-gram and 3-gram, but not 4-gram.
|
# Match for 2-gram and 3-gram, but not 4-gram.
|
||||||
result = proposer.propose(
|
result = ngram_proposer(
|
||||||
context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]),
|
2, 4,
|
||||||
min_n=2,
|
2).propose(context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]))
|
||||||
max_n=4,
|
|
||||||
k=2,
|
|
||||||
)
|
|
||||||
assert np.array_equal(result, np.array([1, 2])) # Not [5, 2]
|
assert np.array_equal(result, np.array([1, 2])) # Not [5, 2]
|
||||||
|
|||||||
@ -120,7 +120,7 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]:
|
|||||||
def pairwise(iterable):
|
def pairwise(iterable):
|
||||||
"""
|
"""
|
||||||
Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise
|
Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise
|
||||||
|
|
||||||
Can be removed when Python 3.9 support is dropped.
|
Can be removed when Python 3.9 support is dropped.
|
||||||
"""
|
"""
|
||||||
iterator = iter(iterable)
|
iterator = iter(iterable)
|
||||||
@ -266,7 +266,7 @@ class ModelConfig:
|
|||||||
config_format: The config format which shall be loaded.
|
config_format: The config format which shall be loaded.
|
||||||
Defaults to 'auto' which defaults to 'hf'.
|
Defaults to 'auto' which defaults to 'hf'.
|
||||||
hf_token: The token to use as HTTP bearer authorization for remote files.
|
hf_token: The token to use as HTTP bearer authorization for remote files.
|
||||||
If `True`, will use the token generated when running
|
If `True`, will use the token generated when running
|
||||||
`huggingface-cli login` (stored in `~/.huggingface`).
|
`huggingface-cli login` (stored in `~/.huggingface`).
|
||||||
hf_overrides: If a dictionary, contains arguments to be forwarded to the
|
hf_overrides: If a dictionary, contains arguments to be forwarded to the
|
||||||
HuggingFace config. If a callable, it is called to update the
|
HuggingFace config. If a callable, it is called to update the
|
||||||
@ -1624,7 +1624,7 @@ class ParallelConfig:
|
|||||||
"""The full name of the worker class to use. If "auto", the worker class
|
"""The full name of the worker class to use. If "auto", the worker class
|
||||||
will be determined based on the platform."""
|
will be determined based on the platform."""
|
||||||
sd_worker_cls: str = "auto"
|
sd_worker_cls: str = "auto"
|
||||||
"""The full name of the worker class to use for speculative decoding.
|
"""The full name of the worker class to use for speculative decoding.
|
||||||
If "auto", the worker class will be determined based on the platform."""
|
If "auto", the worker class will be determined based on the platform."""
|
||||||
worker_extension_cls: str = ""
|
worker_extension_cls: str = ""
|
||||||
"""The full name of the worker extension class to use. The worker extension
|
"""The full name of the worker extension class to use. The worker extension
|
||||||
@ -1815,13 +1815,13 @@ class SchedulerConfig:
|
|||||||
|
|
||||||
max_num_batched_tokens: int = None # type: ignore
|
max_num_batched_tokens: int = None # type: ignore
|
||||||
"""Maximum number of tokens to be processed in a single iteration.
|
"""Maximum number of tokens to be processed in a single iteration.
|
||||||
|
|
||||||
This config has no static default. If left unspecified by the user, it will
|
This config has no static default. If left unspecified by the user, it will
|
||||||
be set in `EngineArgs.create_engine_config` based on the usage context."""
|
be set in `EngineArgs.create_engine_config` based on the usage context."""
|
||||||
|
|
||||||
max_num_seqs: int = None # type: ignore
|
max_num_seqs: int = None # type: ignore
|
||||||
"""Maximum number of sequences to be processed in a single iteration.
|
"""Maximum number of sequences to be processed in a single iteration.
|
||||||
|
|
||||||
This config has no static default. If left unspecified by the user, it will
|
This config has no static default. If left unspecified by the user, it will
|
||||||
be set in `EngineArgs.create_engine_config` based on the usage context."""
|
be set in `EngineArgs.create_engine_config` based on the usage context."""
|
||||||
|
|
||||||
@ -1867,7 +1867,7 @@ class SchedulerConfig:
|
|||||||
# TODO (ywang96): Make this configurable.
|
# TODO (ywang96): Make this configurable.
|
||||||
max_num_encoder_input_tokens: int = field(init=False)
|
max_num_encoder_input_tokens: int = field(init=False)
|
||||||
"""Multimodal encoder compute budget, only used in V1.
|
"""Multimodal encoder compute budget, only used in V1.
|
||||||
|
|
||||||
NOTE: This is not currently configurable. It will be overridden by
|
NOTE: This is not currently configurable. It will be overridden by
|
||||||
max_num_batched_tokens in case max multimodal embedding size is larger."""
|
max_num_batched_tokens in case max multimodal embedding size is larger."""
|
||||||
|
|
||||||
@ -2306,7 +2306,8 @@ class SpeculativeConfig:
|
|||||||
if self.model is None and self.num_speculative_tokens is not None:
|
if self.model is None and self.num_speculative_tokens is not None:
|
||||||
# TODO(Shangming): Refactor mtp configuration logic when supporting
|
# TODO(Shangming): Refactor mtp configuration logic when supporting
|
||||||
# mtp acceleration for more models besides deepseek_v3
|
# mtp acceleration for more models besides deepseek_v3
|
||||||
if self.target_model_config.hf_text_config.model_type \
|
if self.target_model_config and \
|
||||||
|
self.target_model_config.hf_text_config.model_type \
|
||||||
== "deepseek_v3":
|
== "deepseek_v3":
|
||||||
# use the draft model from the same model:
|
# use the draft model from the same model:
|
||||||
self.model = self.target_model_config.model
|
self.model = self.target_model_config.model
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user