[Bugfix]: v1 engine - consider lora adapters in allowed_token_ids (#17855)

Signed-off-by: Ben Browning <bbrownin@redhat.com>

commit 8132365b74
parent eea22a56ab
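
Summary (an editor's hedged reading of the diff below, not the author's own commit description): the v1 engine's Processor used to validate SamplingParams.allowed_token_ids against the base model's vocabulary size, so token ids that a LoRA adapter adds to the vocabulary were rejected even when the matching adapter was attached to the request. The fix resolves the per-request tokenizer with get_lora_tokenizer(lora_request) and bounds the check by len(tokenizer), which counts adapter-added tokens. Below is a minimal runnable sketch of the adopted rule; _ToyTokenizer is a hypothetical stand-in for the per-request tokenizer, not vLLM code.

from typing import Optional


class _ToyTokenizer:
    """Hypothetical stand-in whose len() counts the base vocabulary plus
    any tokens a LoRA adapter adds, like vLLM's per-request tokenizer."""

    def __init__(self, size: int) -> None:
        self._size = size

    def __len__(self) -> int:
        return self._size


def validate_allowed_token_ids(allowed_token_ids: Optional[list[int]],
                               tokenizer: _ToyTokenizer) -> None:
    # Mirrors the validation in the processor.py hunk further below.
    if allowed_token_ids is None:
        return
    if not allowed_token_ids:
        raise ValueError("allowed_token_ids is not None and empty!")
    # The key change: bound by len(tokenizer), not the base vocab size.
    vocab_size = len(tokenizer)
    if not all(0 <= tid < vocab_size for tid in allowed_token_ids):
        raise ValueError("allowed_token_ids contains out-of-vocab token id!")


# Llama-2 has 32000 base tokens; an adapter defining 4 extra tokens makes
# id 32000 valid, while it would fail against the bare base vocabulary.
validate_allowed_token_ids([32000], _ToyTokenizer(32000 + 4))
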
tests/lora/conftest.py
@@ -139,6 +139,12 @@ def dummy_model_gate_up() -> nn.Module:
     return model
 
 
+@pytest.fixture(scope="session")
+def llama_2_7b_base_huggingface_id():
+    # used as a base model for testing with sql lora adapter
+    return "meta-llama/Llama-2-7b-hf"
+
+
 @pytest.fixture(scope="session")
 def sql_lora_huggingface_id():
     # huggingface repo id is used to test lora runtime downloading.
@@ -198,6 +204,12 @@ def qwen2vl_lora_files():
     return snapshot_download(repo_id="jeeejeee/qwen2-vl-lora-pokemon")
 
 
+@pytest.fixture(scope="session")
+def qwen25vl_base_huggingface_id():
+    # used as a base model for testing with qwen25vl lora adapter
+    return "Qwen/Qwen2.5-VL-3B-Instruct"
+
+
 @pytest.fixture(scope="session")
 def qwen25vl_lora_files():
     return snapshot_download(repo_id="jeeejeee/qwen25-vl-lora-pokemon")
@@ -261,8 +273,8 @@ def run_with_both_engines_lora(request, monkeypatch):
 @pytest.fixture
 def reset_default_device():
     """
     Some tests, such as `test_punica_ops.py`, explicitly set the
     default device, which can affect subsequent tests. Adding this fixture
     helps avoid this problem.
     """
     original_device = torch.get_default_device()
tests/lora/test_lora_allowed_token_ids.py (new file, 134 lines)
@@ -0,0 +1,134 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+
+from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
+                         VllmConfig)
+from vllm.lora.request import LoRARequest
+from vllm.sampling_params import SamplingParams
+from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
+from vllm.v1.engine.processor import Processor
+
+
+def test_allowed_token_ids_with_lora_vocab(llama_2_7b_base_huggingface_id,
+                                           sql_lora_files):
+    """
+    Test that we properly resolve the range of allowed token ids for lora
+    adapters that define additional tokens.
+    """
+
+    # Setup a base model compatible with the sql_lora_files adapter and
+    # a known number of tokens in the base model.
+    model_config = ModelConfig(
+        model=llama_2_7b_base_huggingface_id,
+        tokenizer=llama_2_7b_base_huggingface_id,
+        tokenizer_mode="auto",
+    )
+
+    vllm_config = VllmConfig(
+        model_config=model_config,
+        cache_config=CacheConfig(),
+        device_config=DeviceConfig(),
+        lora_config=LoRAConfig(),
+    )
+
+    tokenizer = init_tokenizer_from_configs(
+        model_config=vllm_config.model_config,
+        scheduler_config=vllm_config.scheduler_config,
+        lora_config=vllm_config.lora_config)
+    processor = Processor(vllm_config, tokenizer)
+
+    lora_request = LoRARequest("1", 1, str(sql_lora_files))
+    request_id = "1"
+    prompt = "a prompt"
+
+    # tokens added in the lora adapter should not raise an error
+    lora_token_ids = [32000, 32001, 32002, 32003]
+    processor.process_inputs(
+        request_id,
+        prompt,
+        params=SamplingParams(allowed_token_ids=lora_token_ids),
+        lora_request=lora_request)
+
+    # tokens in the base model should not raise an error
+    base_token_ids = [1000, 1001, 1002, 1003]
+    processor.process_inputs(
+        request_id,
+        prompt,
+        params=SamplingParams(allowed_token_ids=base_token_ids),
+        lora_request=lora_request)
+
+    # tokens not in the lora adapter should raise an error
+    invalid_token_ids = [35000, 35001, 35002, 35003]
+    with pytest.raises(ValueError):
+        processor.process_inputs(
+            request_id,
+            prompt,
+            params=SamplingParams(allowed_token_ids=invalid_token_ids),
+            lora_request=lora_request)
+
+    # tokens in the lora adapter with no lora request should raise an error
+    with pytest.raises(ValueError):
+        processor.process_inputs(
+            request_id,
+            prompt,
+            params=SamplingParams(allowed_token_ids=lora_token_ids),
+        )
+
+
+def test_allowed_token_ids_with_lora_adapter_no_vocab(
+        qwen25vl_base_huggingface_id, qwen25vl_lora_files):
+    """
+    Test that we properly resolve the range of allowed token ids for lora
+    adapters that do not define additional tokens.
+    """
+
+    # Setup a base model compatible with the qwen25vl_lora_files adapter and
+    # a known number of tokens in the base model.
+    model_config = ModelConfig(
+        model=qwen25vl_base_huggingface_id,
+        tokenizer=qwen25vl_base_huggingface_id,
+        tokenizer_mode="auto",
+    )
+
+    vllm_config = VllmConfig(
+        model_config=model_config,
+        cache_config=CacheConfig(),
+        device_config=DeviceConfig(),
+        lora_config=LoRAConfig(),
+    )
+
+    tokenizer = init_tokenizer_from_configs(
+        model_config=vllm_config.model_config,
+        scheduler_config=vllm_config.scheduler_config,
+        lora_config=vllm_config.lora_config)
+    processor = Processor(vllm_config, tokenizer)
+
+    lora_request = LoRARequest("1", 1, str(qwen25vl_lora_files))
+    request_id = "1"
+    prompt = "a prompt"
+
+    # tokens in the base model should not raise an error
+    base_token_ids = [1000, 1001, 1002, 1003]
+    processor.process_inputs(
+        request_id,
+        prompt,
+        params=SamplingParams(allowed_token_ids=base_token_ids),
+        lora_request=lora_request)
+
+    # tokens in the base model with no lora request should not raise an error
+    base_token_ids = [1000, 1001, 1002, 1003]
+    processor.process_inputs(
+        request_id,
+        prompt,
+        params=SamplingParams(allowed_token_ids=base_token_ids),
+    )
+
+    # tokens not in the base model should raise an error
+    invalid_token_ids = [200000, 200001, 200002, 200003]
+    with pytest.raises(ValueError):
+        processor.process_inputs(
+            request_id,
+            prompt,
+            params=SamplingParams(allowed_token_ids=invalid_token_ids),
+            lora_request=lora_request)
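
A note on the ids the tests pick (editor's gloss from the models' published vocabulary sizes): Llama-2's base vocabulary is 32000 tokens, so ids 32000-32003 become valid only once the sql adapter's added tokens are counted, while 35000-35003 exceed even the extended vocabulary. Qwen2.5-VL's vocabulary is roughly 152k tokens, so ids of 200000 and above are out of range with or without the adapter.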
vllm/v1/engine/processor.py
@@ -74,6 +74,7 @@ class Processor:
     def _validate_sampling_params(
         self,
         params: SamplingParams,
+        lora_request: Optional[LoRARequest],
     ) -> None:
         self._validate_structured_output(params)
         self._validate_logit_bias(params)
@@ -82,7 +83,8 @@ class Processor:
             return
         if not params.allowed_token_ids:
             raise ValueError("allowed_token_ids is not None and empty!")
-        vocab_size = self.model_config.get_vocab_size()
+        tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
+        vocab_size = len(tokenizer)
         if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
             raise ValueError(
                 "allowed_token_ids contains out-of-vocab token id!")
@@ -122,6 +124,7 @@ class Processor:
     def _validate_params(
         self,
         params: Union[SamplingParams, PoolingParams],
+        lora_request: Optional[LoRARequest],
     ):
         """
        Validate supported SamplingParam.
@@ -132,7 +135,7 @@ class Processor:
            raise ValueError("V1 does not yet support Pooling models.")
 
         self._validate_logprobs(params)
-        self._validate_sampling_params(params)
+        self._validate_sampling_params(params, lora_request)
         self._validate_supported_sampling_params(params)
 
     def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
@@ -207,7 +210,7 @@ class Processor:
         # TODO(woosuk): Support pooling models.
         # TODO(woosuk): Support encoder-decoder models.
         self._validate_lora(lora_request)
-        self._validate_params(params)
+        self._validate_params(params, lora_request)
         if priority != 0:
             raise ValueError("V1 does not support priority yet.")
         if trace_headers is not None:
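
For context, a usage-level sketch of what the fix enables. This is hedged: the model name, adapter name, and adapter path are illustrative assumptions, and the token ids assume a Llama-2 adapter that extends the 32000-token base vocabulary. With the matching LoRARequest attached, restricting sampling to adapter-added ids should now pass validation.

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Illustrative: ids 32000-32003 exist only once the adapter's added tokens
# are counted, which the new per-request tokenizer check accounts for.
llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)
outputs = llm.generate(
    "a prompt",
    SamplingParams(allowed_token_ids=[32000, 32001, 32002, 32003]),
    lora_request=LoRARequest("sql_adapter", 1, "/path/to/sql-lora"),
)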