[Bugfix]: v1 engine - consider lora adapters in allowed_token_ids (#17855)

Signed-off-by: Ben Browning <bbrownin@redhat.com>
Ben Browning authored on 2025-05-11 03:53:58 -04:00 (committed by GitHub)
parent eea22a56ab
commit 8132365b74
3 changed files with 154 additions and 5 deletions
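
Summary of the fix: before this change, the v1 engine validated `allowed_token_ids` against the base model's vocab size only, so token ids introduced by a LoRA adapter's extra vocabulary were rejected even when the request carried that adapter. The fix resolves the tokenizer for the request's LoRA adapter and validates against its length instead. Below is a minimal, self-contained sketch of the before/after check (simplified, not the actual vLLM code); the numbers mirror the Llama-2 + sql-lora test case added in this commit, where the adapter defines extra tokens at ids 32000-32003.

# Simplified illustration of the validation change; the real logic is in
# the Processor diff below.
def validate_old(allowed_token_ids: list[int], base_vocab_size: int) -> None:
    # Old behavior: only the base model's vocab range is accepted.
    if not all(0 <= tid < base_vocab_size for tid in allowed_token_ids):
        raise ValueError("allowed_token_ids contains out-of-vocab token id!")

def validate_new(allowed_token_ids: list[int], lora_vocab_size: int) -> None:
    # New behavior: the limit comes from the tokenizer resolved for the
    # request's LoRA adapter, so adapter-added tokens pass the check.
    if not all(0 <= tid < lora_vocab_size for tid in allowed_token_ids):
        raise ValueError("allowed_token_ids contains out-of-vocab token id!")

# Llama-2 has a 32000-token base vocab; the adapter's vocab covers ids
# up to at least 32003 (see the new test below).
lora_added_ids = [32000, 32001, 32002, 32003]
validate_new(lora_added_ids, lora_vocab_size=32004)      # accepted after the fix
try:
    validate_old(lora_added_ids, base_vocab_size=32000)  # rejected before the fix
except ValueError as err:
    print(err)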

File 1 of 3 — LoRA test fixtures (conftest):

@@ -139,6 +139,12 @@ def dummy_model_gate_up() -> nn.Module:
     return model


+@pytest.fixture(scope="session")
+def llama_2_7b_base_huggingface_id():
+    # used as a base model for testing with sql lora adapter
+    return "meta-llama/Llama-2-7b-hf"
+
+
 @pytest.fixture(scope="session")
 def sql_lora_huggingface_id():
     # huggingface repo id is used to test lora runtime downloading.

@@ -198,6 +204,12 @@ def qwen2vl_lora_files():
     return snapshot_download(repo_id="jeeejeee/qwen2-vl-lora-pokemon")


+@pytest.fixture(scope="session")
+def qwen25vl_base_huggingface_id():
+    # used as a base model for testing with qwen25vl lora adapter
+    return "Qwen/Qwen2.5-VL-3B-Instruct"
+
+
 @pytest.fixture(scope="session")
 def qwen25vl_lora_files():
     return snapshot_download(repo_id="jeeejeee/qwen25-vl-lora-pokemon")

@@ -261,8 +273,8 @@ def run_with_both_engines_lora(request, monkeypatch):
 @pytest.fixture
 def reset_default_device():
     """
     Some tests, such as `test_punica_ops.py`, explicitly set the
     default device, which can affect subsequent tests. Adding this fixture
     helps avoid this problem.
     """
     original_device = torch.get_default_device()

File 2 of 3 — new test file for allowed_token_ids with LoRA adapters:

@@ -0,0 +1,134 @@
# SPDX-License-Identifier: Apache-2.0

import pytest

from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                         VllmConfig)
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.v1.engine.processor import Processor


def test_allowed_token_ids_with_lora_vocab(llama_2_7b_base_huggingface_id,
                                           sql_lora_files):
    """
    Test that we properly resolve the range of allowed token ids for lora
    adapters that define additional tokens.
    """

    # Setup a base model compatible with the sql_lora_files adapter and
    # a known number of tokens in the base model.
    model_config = ModelConfig(
        model=llama_2_7b_base_huggingface_id,
        tokenizer=llama_2_7b_base_huggingface_id,
        tokenizer_mode="auto",
    )
    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=CacheConfig(),
        device_config=DeviceConfig(),
        lora_config=LoRAConfig(),
    )

    tokenizer = init_tokenizer_from_configs(
        model_config=vllm_config.model_config,
        scheduler_config=vllm_config.scheduler_config,
        lora_config=vllm_config.lora_config)
    processor = Processor(vllm_config, tokenizer)

    lora_request = LoRARequest("1", 1, str(sql_lora_files))
    request_id = "1"
    prompt = "a prompt"

    # tokens added in the lora adapter should not raise an error
    lora_token_ids = [32000, 32001, 32002, 32003]
    processor.process_inputs(
        request_id,
        prompt,
        params=SamplingParams(allowed_token_ids=lora_token_ids),
        lora_request=lora_request)

    # tokens in the base model should not raise an error
    base_token_ids = [1000, 1001, 1002, 1003]
    processor.process_inputs(
        request_id,
        prompt,
        params=SamplingParams(allowed_token_ids=base_token_ids),
        lora_request=lora_request)

    # tokens not in the lora adapter should raise an error
    invalid_token_ids = [35000, 35001, 35002, 35003]
    with pytest.raises(ValueError):
        processor.process_inputs(
            request_id,
            prompt,
            params=SamplingParams(allowed_token_ids=invalid_token_ids),
            lora_request=lora_request)

    # tokens in the lora adapter with no lora request should raise an error
    with pytest.raises(ValueError):
        processor.process_inputs(
            request_id,
            prompt,
            params=SamplingParams(allowed_token_ids=lora_token_ids),
        )


def test_allowed_token_ids_with_lora_adapter_no_vocab(
        qwen25vl_base_huggingface_id, qwen25vl_lora_files):
    """
    Test that we properly resolve the range of allowed token ids for lora
    adapters that do not define additional tokens.
    """

    # Setup a base model compatible with the qwen25vl_lora_files adapter and
    # a known number of tokens in the base model.
    model_config = ModelConfig(
        model=qwen25vl_base_huggingface_id,
        tokenizer=qwen25vl_base_huggingface_id,
        tokenizer_mode="auto",
    )
    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=CacheConfig(),
        device_config=DeviceConfig(),
        lora_config=LoRAConfig(),
    )

    tokenizer = init_tokenizer_from_configs(
        model_config=vllm_config.model_config,
        scheduler_config=vllm_config.scheduler_config,
        lora_config=vllm_config.lora_config)
    processor = Processor(vllm_config, tokenizer)

    lora_request = LoRARequest("1", 1, str(qwen25vl_lora_files))
    request_id = "1"
    prompt = "a prompt"

    # tokens in the base model should not raise an error
    base_token_ids = [1000, 1001, 1002, 1003]
    processor.process_inputs(
        request_id,
        prompt,
        params=SamplingParams(allowed_token_ids=base_token_ids),
        lora_request=lora_request)

    # tokens in the base model with no lora request should not raise an error
    base_token_ids = [1000, 1001, 1002, 1003]
    processor.process_inputs(
        request_id,
        prompt,
        params=SamplingParams(allowed_token_ids=base_token_ids),
    )

    # tokens not in the base model should raise an error
    invalid_token_ids = [200000, 200001, 200002, 200003]
    with pytest.raises(ValueError):
        processor.process_inputs(
            request_id,
            prompt,
            params=SamplingParams(allowed_token_ids=invalid_token_ids),
            lora_request=lora_request)
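
For completeness, an illustrative offline-inference snippet exercising the same path the tests above cover. This is a hypothetical usage sketch: the adapter path is a placeholder, and the token ids assume, as in the first test, that the adapter extends Llama-2's vocabulary with ids 32000-32003.

# Hypothetical usage sketch; the adapter path below is a placeholder.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)

# These ids exist only in the adapter's extended vocabulary; before this fix
# the v1 engine rejected them as out-of-vocab even when a LoRA request was given.
params = SamplingParams(allowed_token_ids=[32000, 32001, 32002, 32003])

outputs = llm.generate(
    "a prompt",
    params,
    lora_request=LoRARequest("sql_adapter", 1, "/path/to/sql-lora-adapter"),
)
print(outputs[0].outputs[0].text)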

File 3 of 3 — vllm/v1/engine/processor.py:

@@ -74,6 +74,7 @@ class Processor:
     def _validate_sampling_params(
         self,
         params: SamplingParams,
+        lora_request: Optional[LoRARequest],
     ) -> None:
         self._validate_structured_output(params)
         self._validate_logit_bias(params)
@@ -82,7 +83,8 @@ class Processor:
             return
         if not params.allowed_token_ids:
             raise ValueError("allowed_token_ids is not None and empty!")
-        vocab_size = self.model_config.get_vocab_size()
+        tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
+        vocab_size = len(tokenizer)
         if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
             raise ValueError(
                 "allowed_token_ids contains out-of-vocab token id!")
@@ -122,6 +124,7 @@ class Processor:
     def _validate_params(
         self,
         params: Union[SamplingParams, PoolingParams],
+        lora_request: Optional[LoRARequest],
     ):
         """
         Validate supported SamplingParam.
@@ -132,7 +135,7 @@ class Processor:
             raise ValueError("V1 does not yet support Pooling models.")

         self._validate_logprobs(params)
-        self._validate_sampling_params(params)
+        self._validate_sampling_params(params, lora_request)
         self._validate_supported_sampling_params(params)

     def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
@@ -207,7 +210,7 @@ class Processor:
         # TODO(woosuk): Support pooling models.
         # TODO(woosuk): Support encoder-decoder models.
         self._validate_lora(lora_request)
-        self._validate_params(params)
+        self._validate_params(params, lora_request)
         if priority != 0:
             raise ValueError("V1 does not support priority yet.")
         if trace_headers is not None:
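
Note on the design: `get_lora_tokenizer(lora_request)` resolves the tokenizer for the given adapter, and when no LoRA request is supplied the validation falls back to the base tokenizer's length, so requests without an adapter keep the original base-vocab behavior. The "no lora request" cases in the new tests rely on exactly this: adapter-added ids are still rejected when the request carries no adapter, while base-model ids continue to pass.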