diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py
index b940f7190bb22..399311ce65bb8 100644
--- a/tests/lora/conftest.py
+++ b/tests/lora/conftest.py
@@ -139,6 +139,12 @@ def dummy_model_gate_up() -> nn.Module:
     return model
 
 
+@pytest.fixture(scope="session")
+def llama_2_7b_base_huggingface_id():
+    # used as a base model for testing with sql lora adapter
+    return "meta-llama/Llama-2-7b-hf"
+
+
 @pytest.fixture(scope="session")
 def sql_lora_huggingface_id():
     # huggingface repo id is used to test lora runtime downloading.
@@ -198,6 +204,12 @@ def qwen2vl_lora_files():
     return snapshot_download(repo_id="jeeejeee/qwen2-vl-lora-pokemon")
 
 
+@pytest.fixture(scope="session")
+def qwen25vl_base_huggingface_id():
+    # used as a base model for testing with qwen25vl lora adapter
+    return "Qwen/Qwen2.5-VL-3B-Instruct"
+
+
 @pytest.fixture(scope="session")
 def qwen25vl_lora_files():
     return snapshot_download(repo_id="jeeejeee/qwen25-vl-lora-pokemon")
@@ -261,8 +273,8 @@ def run_with_both_engines_lora(request, monkeypatch):
 @pytest.fixture
 def reset_default_device():
     """
-    Some tests, such as `test_punica_ops.py`, explicitly set the 
-    default device, which can affect subsequent tests. Adding this fixture 
+    Some tests, such as `test_punica_ops.py`, explicitly set the
+    default device, which can affect subsequent tests. Adding this fixture
     helps avoid this problem.
     """
     original_device = torch.get_default_device()
diff --git a/tests/lora/test_lora_allowed_token_ids.py b/tests/lora/test_lora_allowed_token_ids.py
new file mode 100644
index 0000000000000..094541aef02bb
--- /dev/null
+++ b/tests/lora/test_lora_allowed_token_ids.py
@@ -0,0 +1,134 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+
+from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
+                         VllmConfig)
+from vllm.lora.request import LoRARequest
+from vllm.sampling_params import SamplingParams
+from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
+from vllm.v1.engine.processor import Processor
+
+
+def test_allowed_token_ids_with_lora_vocab(llama_2_7b_base_huggingface_id,
+                                           sql_lora_files):
+    """
+    Test that we properly resolve the range of allowed token ids for lora
+    adapters that define additional tokens.
+    """
+
+    # Setup a base model compatible with the sql_lora_files adapter and
+    # a known number of tokens in the base model.
+    model_config = ModelConfig(
+        model=llama_2_7b_base_huggingface_id,
+        tokenizer=llama_2_7b_base_huggingface_id,
+        tokenizer_mode="auto",
+    )
+
+    vllm_config = VllmConfig(
+        model_config=model_config,
+        cache_config=CacheConfig(),
+        device_config=DeviceConfig(),
+        lora_config=LoRAConfig(),
+    )
+
+    tokenizer = init_tokenizer_from_configs(
+        model_config=vllm_config.model_config,
+        scheduler_config=vllm_config.scheduler_config,
+        lora_config=vllm_config.lora_config)
+    processor = Processor(vllm_config, tokenizer)
+
+    lora_request = LoRARequest("1", 1, str(sql_lora_files))
+    request_id = "1"
+    prompt = "a prompt"
+
+    # tokens added in the lora adapter should not raise an error
+    lora_token_ids = [32000, 32001, 32002, 32003]
+    processor.process_inputs(
+        request_id,
+        prompt,
+        params=SamplingParams(allowed_token_ids=lora_token_ids),
+        lora_request=lora_request)
+
+    # tokens in the base model should not raise an error
+    base_token_ids = [1000, 1001, 1002, 1003]
+    processor.process_inputs(
+        request_id,
+        prompt,
+        params=SamplingParams(allowed_token_ids=base_token_ids),
+        lora_request=lora_request)
+
+    # tokens not in the lora adapter should raise an error
+    invalid_token_ids = [35000, 35001, 35002, 35003]
+    with pytest.raises(ValueError):
+        processor.process_inputs(
+            request_id,
+            prompt,
+            params=SamplingParams(allowed_token_ids=invalid_token_ids),
+            lora_request=lora_request)
+
+    # tokens in the lora adapter with no lora request should raise an error
+    with pytest.raises(ValueError):
+        processor.process_inputs(
+            request_id,
+            prompt,
+            params=SamplingParams(allowed_token_ids=lora_token_ids),
+        )
+
+
+def test_allowed_token_ids_with_lora_adapter_no_vocab(
+        qwen25vl_base_huggingface_id, qwen25vl_lora_files):
+    """
+    Test that we properly resolve the range of allowed token ids for lora
+    adapters that do not define additional tokens.
+    """
+
+    # Setup a base model compatible with the qwen25vl_lora_files adapter and
+    # a known number of tokens in the base model.
+    model_config = ModelConfig(
+        model=qwen25vl_base_huggingface_id,
+        tokenizer=qwen25vl_base_huggingface_id,
+        tokenizer_mode="auto",
+    )
+
+    vllm_config = VllmConfig(
+        model_config=model_config,
+        cache_config=CacheConfig(),
+        device_config=DeviceConfig(),
+        lora_config=LoRAConfig(),
+    )
+
+    tokenizer = init_tokenizer_from_configs(
+        model_config=vllm_config.model_config,
+        scheduler_config=vllm_config.scheduler_config,
+        lora_config=vllm_config.lora_config)
+    processor = Processor(vllm_config, tokenizer)
+
+    lora_request = LoRARequest("1", 1, str(qwen25vl_lora_files))
+    request_id = "1"
+    prompt = "a prompt"
+
+    # tokens in the base model should not raise an error
+    base_token_ids = [1000, 1001, 1002, 1003]
+    processor.process_inputs(
+        request_id,
+        prompt,
+        params=SamplingParams(allowed_token_ids=base_token_ids),
+        lora_request=lora_request)
+
+    # tokens in the base model with no lora request should not raise an error
+    base_token_ids = [1000, 1001, 1002, 1003]
+    processor.process_inputs(
+        request_id,
+        prompt,
+        params=SamplingParams(allowed_token_ids=base_token_ids),
+    )
+
+    # tokens not in the base model should raise an error
+    invalid_token_ids = [200000, 200001, 200002, 200003]
+    with pytest.raises(ValueError):
+        processor.process_inputs(
+            request_id,
+            prompt,
+            params=SamplingParams(allowed_token_ids=invalid_token_ids),
+            lora_request=lora_request)
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 27d70a7814719..2aa19f8bbb572 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -74,6 +74,7 @@ class Processor:
     def _validate_sampling_params(
         self,
         params: SamplingParams,
+        lora_request: Optional[LoRARequest],
     ) -> None:
         self._validate_structured_output(params)
         self._validate_logit_bias(params)
@@ -82,7 +83,8 @@ class Processor:
             return
         if not params.allowed_token_ids:
             raise ValueError("allowed_token_ids is not None and empty!")
-        vocab_size = self.model_config.get_vocab_size()
+        tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
+        vocab_size = len(tokenizer)
         if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
             raise ValueError(
                 "allowed_token_ids contains out-of-vocab token id!")
@@ -122,6 +124,7 @@ class Processor:
     def _validate_params(
         self,
         params: Union[SamplingParams, PoolingParams],
+        lora_request: Optional[LoRARequest],
     ):
         """
         Validate supported SamplingParam.
@@ -132,7 +135,7 @@ class Processor:
             raise ValueError("V1 does not yet support Pooling models.")
 
         self._validate_logprobs(params)
-        self._validate_sampling_params(params)
+        self._validate_sampling_params(params, lora_request)
         self._validate_supported_sampling_params(params)
 
     def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
@@ -207,7 +210,7 @@ class Processor:
         # TODO(woosuk): Support pooling models.
         # TODO(woosuk): Support encoder-decoder models.
         self._validate_lora(lora_request)
-        self._validate_params(params)
+        self._validate_params(params, lora_request)
         if priority != 0:
             raise ValueError("V1 does not support priority yet.")
         if trace_headers is not None: