mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 00:25:01 +08:00
111 lines
5.3 KiB
Python
111 lines
5.3 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import numpy as np
|
|
|
|
from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
|
|
from vllm.v1.spec_decode.ngram_proposer import (
|
|
NgramProposer, _find_longest_matched_ngram_and_propose_tokens)
|
|
|
|
|
|
def test_find_longest_matched_ngram_and_propose_tokens():
    """Check the n-gram lookup helper against hand-computed continuations.

    Every call uses a single n-gram size (``min_ngram == max_ngram``) and
    ``max_model_len=1024``, which is far longer than any input used here.
    """

    def lookup(token_ids, n, num_draft):
        # Thin wrapper pinning the arguments shared by every case below.
        return _find_longest_matched_ngram_and_propose_tokens(
            origin_tokens=token_ids,
            min_ngram=n,
            max_ngram=n,
            max_model_len=1024,
            k=num_draft)

    # The trailing 2-gram [5, 6] does not occur earlier in the sequence,
    # so the expected result is None.
    unmatched = np.array([1, 2, 3, 4, 1, 2, 3, 5, 6])
    assert lookup(unmatched, n=2, num_draft=2) is None

    repeated = np.array([1, 2, 3, 4, 1, 2, 3])
    early_unigram = np.array([1, 3, 6, 2, 3, 4, 1, 2, 3])

    # (tokens, n-gram size, k, expected proposal)
    cases = [
        (repeated, 2, 3, [4, 1, 2]),
        (repeated, 2, 2, [4, 1]),
        (repeated, 1, 3, [4, 1, 2]),
        (repeated, 1, 2, [4, 1]),
        (early_unigram, 2, 3, [4, 1, 2]),
        # Return on the first match: token 3 first appears at index 1,
        # so the proposal is what follows that occurrence.
        (early_unigram, 1, 2, [6, 2]),
    ]
    for token_ids, n, num_draft, expected in cases:
        np.testing.assert_array_equal(lookup(token_ids, n, num_draft),
                                      np.array(expected))
|
|
|
|
|
|
def test_ngram_proposer():
    """Exercise ``NgramProposer.propose`` over several n-gram size ranges.

    Covers: no match at all, no match at the only allowed size, fallback
    from a longer to a shorter n-gram, preference for the longest matching
    n-gram, and preference for the first occurrence among equal-length
    matches.

    Array results are checked with ``np.testing.assert_array_equal`` so a
    failure reports the mismatching elements instead of a bare ``False``.
    """

    def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
        # Dummy model config. Just to set max_model_len.
        model_config = ModelConfig(model="facebook/opt-125m")
        speculative_config = SpeculativeConfig(
            prompt_lookup_min=min_n,
            prompt_lookup_max=max_n,
            num_speculative_tokens=k,
            method="ngram",
        )
        return NgramProposer(vllm_config=VllmConfig(
            model_config=model_config,
            speculative_config=speculative_config))

    # No match.
    result = ngram_proposer(min_n=2, max_n=2, k=2).propose(
        context_token_ids=np.array([1, 2, 3, 4, 5]))
    assert result is None

    # No match for 4-gram.
    result = ngram_proposer(min_n=4, max_n=4, k=2).propose(
        context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
    assert result is None

    # No match for 4-gram but match for 3-gram.
    result = ngram_proposer(min_n=3, max_n=4, k=2).propose(
        context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
    np.testing.assert_array_equal(result, np.array([4, 1]))

    # Match for both 4-gram and 3-gram.
    # In this case, the proposer should return the 4-gram match.
    result = ngram_proposer(min_n=3, max_n=4, k=2).propose(
        context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]))
    np.testing.assert_array_equal(result, np.array([1, 2]))  # Not [5, 1]

    # Match for 2-gram and 3-gram, but not 4-gram.
    result = ngram_proposer(min_n=2, max_n=4, k=2).propose(
        context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]))
    np.testing.assert_array_equal(result, np.array([1, 2]))  # Not [5, 2]

    # Multiple 3-gram matched, but always pick the first one.
    result = ngram_proposer(min_n=3, max_n=3, k=2).propose(
        context_token_ids=np.array(
            [1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]))
    np.testing.assert_array_equal(result, np.array([100, 1]))
|