[Bugfix]: v1 engine - consider lora adapters in allowed_token_ids (#17855)
Signed-off-by: Ben Browning <bbrownin@redhat.com>
commit 8132365b74 (parent eea22a56ab)
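In short: the v1 engine validated allowed_token_ids against the base model's vocabulary size, so token ids that only exist in a LoRA adapter's extended vocabulary were rejected even when the request carried a LoRA adapter. A hedged, user-facing sketch of the behavior this fix enables (model name, adapter path, and token ids are illustrative, and the adapter is assumed to define extra tokens):

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)

# Ids >= 32000 only exist once the adapter's added tokens extend the vocab;
# before this fix the v1 engine rejected them as out-of-vocab.
params = SamplingParams(allowed_token_ids=[32000, 32001])
llm.generate(
    "a prompt",
    params,
    lora_request=LoRARequest("sql_adapter", 1, "/path/to/sql_lora_adapter"),
)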
tests/lora/conftest.py
@@ -139,6 +139,12 @@ def dummy_model_gate_up() -> nn.Module:
     return model
 
 
+@pytest.fixture(scope="session")
+def llama_2_7b_base_huggingface_id():
+    # used as a base model for testing with sql lora adapter
+    return "meta-llama/Llama-2-7b-hf"
+
+
 @pytest.fixture(scope="session")
 def sql_lora_huggingface_id():
     # huggingface repo id is used to test lora runtime downloading.
@@ -198,6 +204,12 @@ def qwen2vl_lora_files():
     return snapshot_download(repo_id="jeeejeee/qwen2-vl-lora-pokemon")
 
 
+@pytest.fixture(scope="session")
+def qwen25vl_base_huggingface_id():
+    # used as a base model for testing with qwen25vl lora adapter
+    return "Qwen/Qwen2.5-VL-3B-Instruct"
+
+
 @pytest.fixture(scope="session")
 def qwen25vl_lora_files():
     return snapshot_download(repo_id="jeeejeee/qwen25-vl-lora-pokemon")
@@ -261,8 +273,8 @@ def run_with_both_engines_lora(request, monkeypatch):
 @pytest.fixture
 def reset_default_device():
     """
     Some tests, such as `test_punica_ops.py`, explicitly set the
     default device, which can affect subsequent tests. Adding this fixture
     helps avoid this problem.
     """
     original_device = torch.get_default_device()
tests/lora/test_lora_allowed_token_ids.py (new file, 134 lines)
@@ -0,0 +1,134 @@
# SPDX-License-Identifier: Apache-2.0

import pytest

from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                         VllmConfig)
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.v1.engine.processor import Processor


def test_allowed_token_ids_with_lora_vocab(llama_2_7b_base_huggingface_id,
                                           sql_lora_files):
    """
    Test that we properly resolve the range of allowed token ids for lora
    adapters that define additional tokens.
    """

    # Setup a base model compatible with the sql_lora_files adapter and
    # a known number of tokens in the base model.
    model_config = ModelConfig(
        model=llama_2_7b_base_huggingface_id,
        tokenizer=llama_2_7b_base_huggingface_id,
        tokenizer_mode="auto",
    )

    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=CacheConfig(),
        device_config=DeviceConfig(),
        lora_config=LoRAConfig(),
    )

    tokenizer = init_tokenizer_from_configs(
        model_config=vllm_config.model_config,
        scheduler_config=vllm_config.scheduler_config,
        lora_config=vllm_config.lora_config)
    processor = Processor(vllm_config, tokenizer)

    lora_request = LoRARequest("1", 1, str(sql_lora_files))
    request_id = "1"
    prompt = "a prompt"

    # tokens added in the lora adapter should not raise an error
    lora_token_ids = [32000, 32001, 32002, 32003]
    processor.process_inputs(
        request_id,
        prompt,
        params=SamplingParams(allowed_token_ids=lora_token_ids),
        lora_request=lora_request)

    # tokens in the base model should not raise an error
    base_token_ids = [1000, 1001, 1002, 1003]
    processor.process_inputs(
        request_id,
        prompt,
        params=SamplingParams(allowed_token_ids=base_token_ids),
        lora_request=lora_request)

    # tokens not in the lora adapter should raise an error
    invalid_token_ids = [35000, 35001, 35002, 35003]
    with pytest.raises(ValueError):
        processor.process_inputs(
            request_id,
            prompt,
            params=SamplingParams(allowed_token_ids=invalid_token_ids),
            lora_request=lora_request)

    # tokens in the lora adapter with no lora request should raise an error
    with pytest.raises(ValueError):
        processor.process_inputs(
            request_id,
            prompt,
            params=SamplingParams(allowed_token_ids=lora_token_ids),
        )


def test_allowed_token_ids_with_lora_adapter_no_vocab(
        qwen25vl_base_huggingface_id, qwen25vl_lora_files):
    """
    Test that we properly resolve the range of allowed token ids for lora
    adapters that do not define additional tokens.
    """

    # Setup a base model compatible with the qwen25vl_lora_files adapter and
    # a known number of tokens in the base model.
    model_config = ModelConfig(
        model=qwen25vl_base_huggingface_id,
        tokenizer=qwen25vl_base_huggingface_id,
        tokenizer_mode="auto",
    )

    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=CacheConfig(),
        device_config=DeviceConfig(),
        lora_config=LoRAConfig(),
    )

    tokenizer = init_tokenizer_from_configs(
        model_config=vllm_config.model_config,
        scheduler_config=vllm_config.scheduler_config,
        lora_config=vllm_config.lora_config)
    processor = Processor(vllm_config, tokenizer)

    lora_request = LoRARequest("1", 1, str(qwen25vl_lora_files))
    request_id = "1"
    prompt = "a prompt"

    # tokens in the base model should not raise an error
    base_token_ids = [1000, 1001, 1002, 1003]
    processor.process_inputs(
        request_id,
        prompt,
        params=SamplingParams(allowed_token_ids=base_token_ids),
        lora_request=lora_request)

    # tokens in the base model with no lora request should not raise an error
    base_token_ids = [1000, 1001, 1002, 1003]
    processor.process_inputs(
        request_id,
        prompt,
        params=SamplingParams(allowed_token_ids=base_token_ids),
    )

    # tokens not in the base model should raise an error
    invalid_token_ids = [200000, 200001, 200002, 200003]
    with pytest.raises(ValueError):
        processor.process_inputs(
            request_id,
            prompt,
            params=SamplingParams(allowed_token_ids=invalid_token_ids),
            lora_request=lora_request)
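These tests build a Processor directly from a VllmConfig rather than starting a full engine, so they exercise only the input-validation path. Running just this file with something like pytest tests/lora/test_lora_allowed_token_ids.py should work, assuming Hugging Face access to the gated Llama-2 checkpoint and enough disk space for the Qwen2.5-VL download.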
vllm/v1/engine/processor.py
@@ -74,6 +74,7 @@ class Processor:
     def _validate_sampling_params(
         self,
         params: SamplingParams,
+        lora_request: Optional[LoRARequest],
     ) -> None:
         self._validate_structured_output(params)
         self._validate_logit_bias(params)
@@ -82,7 +83,8 @@ class Processor:
             return
         if not params.allowed_token_ids:
             raise ValueError("allowed_token_ids is not None and empty!")
-        vocab_size = self.model_config.get_vocab_size()
+        tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
+        vocab_size = len(tokenizer)
         if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
             raise ValueError(
                 "allowed_token_ids contains out-of-vocab token id!")
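Why len(tokenizer) rather than self.model_config.get_vocab_size(): the tokenizer group hands back a per-adapter tokenizer, and an adapter that ships extra tokens makes that tokenizer longer than the base vocabulary. A minimal sketch of the effect, assuming the transformers library and access to the gated Llama-2 checkpoint (the added token is a stand-in for whatever an adapter defines):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
print(len(tok))            # 32000: valid ids are 0..31999
tok.add_tokens(["[SQL]"])  # stand-in for tokens a LoRA adapter might add
print(len(tok))            # 32001: id 32000 is now a valid allowed_token_id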
@@ -122,6 +124,7 @@ class Processor:
     def _validate_params(
         self,
         params: Union[SamplingParams, PoolingParams],
+        lora_request: Optional[LoRARequest],
     ):
         """
         Validate supported SamplingParam.
@@ -132,7 +135,7 @@ class Processor:
             raise ValueError("V1 does not yet support Pooling models.")
 
         self._validate_logprobs(params)
-        self._validate_sampling_params(params)
+        self._validate_sampling_params(params, lora_request)
         self._validate_supported_sampling_params(params)
 
     def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
@@ -207,7 +210,7 @@ class Processor:
         # TODO(woosuk): Support pooling models.
         # TODO(woosuk): Support encoder-decoder models.
         self._validate_lora(lora_request)
-        self._validate_params(params)
+        self._validate_params(params, lora_request)
         if priority != 0:
             raise ValueError("V1 does not support priority yet.")
         if trace_headers is not None: