diff --git a/tests/v1/engine/test_llm_engine.py b/tests/v1/engine/test_llm_engine.py
index f37686317fd14..2848420c22085 100644
--- a/tests/v1/engine/test_llm_engine.py
+++ b/tests/v1/engine/test_llm_engine.py
@@ -213,3 +213,29 @@ def test_engine_metrics(vllm_runner, monkeypatch, example_prompts):
     assert len(num_accepted_tokens_per_pos) == 1
     assert isinstance(num_accepted_tokens_per_pos[0], Vector)
     assert len(num_accepted_tokens_per_pos[0].values) == 5
+
+
+@pytest.mark.parametrize("model", ["meta-llama/Llama-3.2-1B-Instruct"])
+def test_skip_tokenizer_initialization(model: str,
+                                       monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+    # This test checks if the flag skip_tokenizer_init skips the initialization
+    # of tokenizer and detokenizer. The generated output is expected to contain
+    # token ids.
+    llm = LLM(
+        model=model,
+        skip_tokenizer_init=True,
+        enforce_eager=True,
+    )
+    sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)
+
+    with pytest.raises(ValueError, match="cannot pass text prompts when"):
+        llm.generate("abc", sampling_params)
+
+    outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
+                           sampling_params=sampling_params)
+    assert len(outputs) > 0
+    completions = outputs[0].outputs
+    assert len(completions) > 0
+    assert completions[0].text == ""
+    assert completions[0].token_ids
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 0f2f404a130ef..224acc47feb27 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -89,6 +89,10 @@ class Processor:
             return
         if not params.allowed_token_ids:
             raise ValueError("allowed_token_ids is not None and empty!")
+        if self.tokenizer is None:
+            # When skip_tokenizer_init=True, we can't validate token IDs
+            # Skip validation and let the model handle invalid tokens
+            return
         tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
         vocab_size = len(tokenizer)
         if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
@@ -283,8 +287,9 @@ class Processor:
                     len(decoder_inputs["prompt_token_ids"]))
             sampling_params.update_from_generation_config(
                 self.generation_config_fields, eos_token_id)
-            sampling_params.update_from_tokenizer(
-                self.tokenizer.get_lora_tokenizer(lora_request))
+            if self.tokenizer is not None:
+                sampling_params.update_from_tokenizer(
+                    self.tokenizer.get_lora_tokenizer(lora_request))
         else:
             pooling_params = params.clone()
 
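
A minimal usage sketch of the behavior this change enables, assuming vLLM's public LLM / SamplingParams API as exercised by the new test; the model name and token ids below are placeholders:

# Sketch only: text prompts are rejected when the tokenizer is skipped,
# while token-id prompts still produce completions with token_ids.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",  # placeholder model
    skip_tokenizer_init=True,  # no tokenizer/detokenizer is loaded
    enforce_eager=True,
)
params = SamplingParams(max_tokens=8)

# With no tokenizer there is nothing to encode text prompts:
# llm.generate("hello", params)  # raises ValueError ("cannot pass text prompts when ...")

# Token-id prompts work; each completion carries token_ids and an empty .text.
outputs = llm.generate({"prompt_token_ids": [1, 2, 3]}, sampling_params=params)
print(outputs[0].outputs[0].token_ids)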