[Bugfix]: v1 engine - consider lora adapters in allowed_token_ids (#17855)

Signed-off-by: Ben Browning <bbrownin@redhat.com>
This commit is contained in:
Ben Browning 2025-05-11 03:53:58 -04:00 committed by GitHub
parent eea22a56ab
commit 8132365b74
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 154 additions and 5 deletions

View File

@ -139,6 +139,12 @@ def dummy_model_gate_up() -> nn.Module:
return model
@pytest.fixture(scope="session")
def llama_2_7b_base_huggingface_id():
# used as a base model for testing with sql lora adapter
return "meta-llama/Llama-2-7b-hf"
@pytest.fixture(scope="session")
def sql_lora_huggingface_id():
# huggingface repo id is used to test lora runtime downloading.
@ -198,6 +204,12 @@ def qwen2vl_lora_files():
return snapshot_download(repo_id="jeeejeee/qwen2-vl-lora-pokemon")
@pytest.fixture(scope="session")
def qwen25vl_base_huggingface_id():
# used as a base model for testing with qwen25vl lora adapter
return "Qwen/Qwen2.5-VL-3B-Instruct"
@pytest.fixture(scope="session")
def qwen25vl_lora_files():
return snapshot_download(repo_id="jeeejeee/qwen25-vl-lora-pokemon")
@ -261,8 +273,8 @@ def run_with_both_engines_lora(request, monkeypatch):
@pytest.fixture
def reset_default_device():
"""
Some tests, such as `test_punica_ops.py`, explicitly set the
default device, which can affect subsequent tests. Adding this fixture
Some tests, such as `test_punica_ops.py`, explicitly set the
default device, which can affect subsequent tests. Adding this fixture
helps avoid this problem.
"""
original_device = torch.get_default_device()

View File

@ -0,0 +1,134 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
VllmConfig)
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.v1.engine.processor import Processor
def test_allowed_token_ids_with_lora_vocab(llama_2_7b_base_huggingface_id,
sql_lora_files):
"""
Test that we properly resolve the range of allowed token ids for lora
adapters that define additional tokens.
"""
# Setup a base model compatible with the sql_lora_files adapter and
# a known number of tokens in the base model.
model_config = ModelConfig(
model=llama_2_7b_base_huggingface_id,
tokenizer=llama_2_7b_base_huggingface_id,
tokenizer_mode="auto",
)
vllm_config = VllmConfig(
model_config=model_config,
cache_config=CacheConfig(),
device_config=DeviceConfig(),
lora_config=LoRAConfig(),
)
tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
lora_config=vllm_config.lora_config)
processor = Processor(vllm_config, tokenizer)
lora_request = LoRARequest("1", 1, str(sql_lora_files))
request_id = "1"
prompt = "a prompt"
# tokens added in the lora adapter should not raise an error
lora_token_ids = [32000, 32001, 32002, 32003]
processor.process_inputs(
request_id,
prompt,
params=SamplingParams(allowed_token_ids=lora_token_ids),
lora_request=lora_request)
# tokens in the base model should not raise an error
base_token_ids = [1000, 1001, 1002, 1003]
processor.process_inputs(
request_id,
prompt,
params=SamplingParams(allowed_token_ids=base_token_ids),
lora_request=lora_request)
# tokens not in the lora adapter should raise an error
invalid_token_ids = [35000, 35001, 35002, 35003]
with pytest.raises(ValueError):
processor.process_inputs(
request_id,
prompt,
params=SamplingParams(allowed_token_ids=invalid_token_ids),
lora_request=lora_request)
# tokens in the lora adapter with no lora request should raise an error
with pytest.raises(ValueError):
processor.process_inputs(
request_id,
prompt,
params=SamplingParams(allowed_token_ids=lora_token_ids),
)
def test_allowed_token_ids_with_lora_adapter_no_vocab(
qwen25vl_base_huggingface_id, qwen25vl_lora_files):
"""
Test that we properly resolve the range of allowed token ids for lora
adapters that do not define additional tokens.
"""
# Setup a base model compatible with the qwen25vl_lora_files adapter and
# a known number of tokens in the base model.
model_config = ModelConfig(
model=qwen25vl_base_huggingface_id,
tokenizer=qwen25vl_base_huggingface_id,
tokenizer_mode="auto",
)
vllm_config = VllmConfig(
model_config=model_config,
cache_config=CacheConfig(),
device_config=DeviceConfig(),
lora_config=LoRAConfig(),
)
tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
lora_config=vllm_config.lora_config)
processor = Processor(vllm_config, tokenizer)
lora_request = LoRARequest("1", 1, str(qwen25vl_lora_files))
request_id = "1"
prompt = "a prompt"
# tokens in the base model should not raise an error
base_token_ids = [1000, 1001, 1002, 1003]
processor.process_inputs(
request_id,
prompt,
params=SamplingParams(allowed_token_ids=base_token_ids),
lora_request=lora_request)
# tokens in the base model with no lora request should not raise an error
base_token_ids = [1000, 1001, 1002, 1003]
processor.process_inputs(
request_id,
prompt,
params=SamplingParams(allowed_token_ids=base_token_ids),
)
# tokens not in the base model should raise an error
invalid_token_ids = [200000, 200001, 200002, 200003]
with pytest.raises(ValueError):
processor.process_inputs(
request_id,
prompt,
params=SamplingParams(allowed_token_ids=invalid_token_ids),
lora_request=lora_request)

View File

@ -74,6 +74,7 @@ class Processor:
def _validate_sampling_params(
self,
params: SamplingParams,
lora_request: Optional[LoRARequest],
) -> None:
self._validate_structured_output(params)
self._validate_logit_bias(params)
@ -82,7 +83,8 @@ class Processor:
return
if not params.allowed_token_ids:
raise ValueError("allowed_token_ids is not None and empty!")
vocab_size = self.model_config.get_vocab_size()
tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
vocab_size = len(tokenizer)
if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
raise ValueError(
"allowed_token_ids contains out-of-vocab token id!")
@ -122,6 +124,7 @@ class Processor:
def _validate_params(
self,
params: Union[SamplingParams, PoolingParams],
lora_request: Optional[LoRARequest],
):
"""
Validate supported SamplingParam.
@ -132,7 +135,7 @@ class Processor:
raise ValueError("V1 does not yet support Pooling models.")
self._validate_logprobs(params)
self._validate_sampling_params(params)
self._validate_sampling_params(params, lora_request)
self._validate_supported_sampling_params(params)
def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
@ -207,7 +210,7 @@ class Processor:
# TODO(woosuk): Support pooling models.
# TODO(woosuk): Support encoder-decoder models.
self._validate_lora(lora_request)
self._validate_params(params)
self._validate_params(params, lora_request)
if priority != 0:
raise ValueError("V1 does not support priority yet.")
if trace_headers is not None: