diff --git a/tests/test_inputs.py b/tests/test_inputs.py
index 50a273016ab80..77379cc8de904 100644
--- a/tests/test_inputs.py
+++ b/tests/test_inputs.py
@@ -7,7 +7,6 @@ from vllm.config import ModelConfig
 from vllm.inputs import zip_enc_dec_prompts
 from vllm.inputs.parse import parse_raw_prompts
 from vllm.inputs.preprocess import InputPreprocessor
-from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
 
 pytestmark = pytest.mark.cpu_test
 
@@ -107,8 +106,7 @@ def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
 )
 def test_preprocessor_text_no_mm_inputs(model_id, prompt):
     model_config = ModelConfig(model=model_id)
-    tokenizer = init_tokenizer_from_configs(model_config)
-    input_preprocessor = InputPreprocessor(model_config, tokenizer)
+    input_preprocessor = InputPreprocessor(model_config)
 
     with pytest.raises(ValueError, match="does not support multimodal inputs"):
         input_preprocessor.preprocess(prompt)
@@ -129,8 +127,8 @@ def test_preprocessor_text_no_mm_inputs(model_id, prompt):
 )
 def test_preprocessor_always_mm_code_path(model_id, prompt):
     model_config = ModelConfig(model=model_id)
-    tokenizer = init_tokenizer_from_configs(model_config)
-    input_preprocessor = InputPreprocessor(model_config, tokenizer)
+    input_preprocessor = InputPreprocessor(model_config)
+    tokenizer = input_preprocessor.tokenizer
 
     # HF processor adds sep token
     sep_token_id = tokenizer.vocab[tokenizer.sep_token]
diff --git a/tests/v1/engine/test_processor_multi_modal_uuids.py b/tests/v1/engine/test_processor_multi_modal_uuids.py
index 9c29c42f5465a..2f73756ff6152 100644
--- a/tests/v1/engine/test_processor_multi_modal_uuids.py
+++ b/tests/v1/engine/test_processor_multi_modal_uuids.py
@@ -65,9 +65,7 @@ def _mk_processor(
         device_config=DeviceConfig(device="cpu"),
     )
 
-    # Pass tokenizer=None; InputPreprocessor handles None when
-    # skip_tokenizer_init is True.
-    return Processor(vllm_config, tokenizer=None)  # type: ignore[arg-type]
+    return Processor(vllm_config)
 
 
 def test_multi_modal_uuids_length_mismatch_raises(monkeypatch):
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index d98318f856f2d..9afbf8b7e1b85 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -74,7 +74,6 @@ from vllm.transformers_utils.tokenizer import (
     AnyTokenizer,
     MistralTokenizer,
     get_cached_tokenizer,
-    init_tokenizer_from_configs,
 )
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import Counter, Device, as_iter, is_list_of
@@ -367,11 +366,8 @@ class LLM:
     def _get_processor(self) -> Processor:
         if not hasattr(self, "_processor"):
             vllm_config = self.llm_engine.vllm_config
-            if self.model_config.skip_tokenizer_init:
-                tokenizer = None
-            else:
-                tokenizer = init_tokenizer_from_configs(self.model_config)
-            self._processor = Processor(vllm_config, tokenizer)
+            self._processor = Processor(vllm_config)
+
         return self._processor
 
     def get_default_sampling_params(self) -> SamplingParams:
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 596ae3fcdc3c4..6ddde23b4a343 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -16,7 +16,6 @@ from starlette.datastructures import Headers
 from typing_extensions import TypeIs
 
 from vllm.entrypoints.utils import _validate_truncation_size
-from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.processor import Processor
@@ -272,11 +271,8 @@ class OpenAIServing:
     async def _get_processor(self) -> Processor:
         if not hasattr(self, "_processor"):
             vllm_config = await self.engine_client.get_vllm_config()
-            if self.model_config.skip_tokenizer_init:
-                tokenizer = None
-            else:
-                tokenizer = init_tokenizer_from_configs(self.model_config)
-            self._processor = Processor(vllm_config, tokenizer)
+            self._processor = Processor(vllm_config)
+
         return self._processor
 
     def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer:
diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py
index d1f55740149ae..00f30e483693e 100644
--- a/vllm/inputs/preprocess.py
+++ b/vllm/inputs/preprocess.py
@@ -17,7 +17,8 @@ from vllm.multimodal.inputs import (
     MultiModalUUIDDict,
 )
 from vllm.multimodal.processing import BaseMultiModalProcessor
-from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
+from vllm.utils.jsontree import json_iter_leaves
 
 from .data import (
     DecoderOnlyInputs,
@@ -44,17 +45,20 @@ class InputPreprocessor:
     def __init__(
         self,
         model_config: ModelConfig,
-        tokenizer: Optional[AnyTokenizer],
         mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
         mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None,
     ) -> None:
         super().__init__()
 
         self.model_config = model_config
-        self.tokenizer = tokenizer
         self.mm_registry = mm_registry
         self.mm_processor_cache = mm_processor_cache
 
+        if model_config.skip_tokenizer_init:
+            self.tokenizer = None
+        else:
+            self.tokenizer = init_tokenizer_from_configs(model_config)
+
     def get_tokenizer(self) -> AnyTokenizer:
         if self.tokenizer is None:
             raise ValueError(
@@ -273,7 +277,10 @@ class InputPreprocessor:
         mm_hashes = mm_input["mm_hashes"]
 
         # Validate that all mm items have a string as their hash
-        if not contains_only_strings(mm_hashes):
+        contains_only_strings = all(
+            isinstance(leaf, str) for leaf in json_iter_leaves(mm_hashes)
+        )
+        if not contains_only_strings:
             raise ValueError(
                 f"mm_hashes must contain only strings, got: {mm_hashes}. "
                 "This is likely due to an incorrect custom implementation of "
@@ -693,15 +700,3 @@ class InputPreprocessor:
     def clear_cache(self) -> None:
         if self.mm_processor_cache is not None:
             self.mm_processor_cache.clear_cache()
-
-
-# Helper function to validate that a nested dictionary contains
-# only strings or list of strings as the leaf values.
-def contains_only_strings(obj: object):
-    if isinstance(obj, str):
-        return True
-    if isinstance(obj, list):
-        return all(isinstance(x, str) for x in obj)
-    if isinstance(obj, dict):
-        return all(contains_only_strings(v) for v in obj.values())
-    return False
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index ca668bc217e1a..5be1f833e3f63 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -28,7 +28,7 @@ from vllm.sampling_params import SamplingParams
 from vllm.tasks import SupportedTask
 from vllm.tracing import init_tracer
 from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
-from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
+from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import Device, as_list, cancel_task_threadsafe, cdiv, deprecate_kwargs
 from vllm.v1.engine import EngineCoreRequest
@@ -104,20 +104,8 @@ class AsyncLLM(EngineClient):
                 "logger list; enabling logging without default stat loggers"
             )
 
-        if self.model_config.skip_tokenizer_init:
-            self.tokenizer = None
-        else:
-            # Tokenizer (+ ensure liveness if running in another process).
-            self.tokenizer = init_tokenizer_from_configs(
-                model_config=vllm_config.model_config
-            )
-
-        # Processor (converts Inputs --> EngineCoreRequests).
-        self.processor = Processor(
-            vllm_config=vllm_config,
-            tokenizer=self.tokenizer,
-            mm_registry=mm_registry,
-        )
+        self.processor = Processor(vllm_config, mm_registry=mm_registry)
 
         # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
         self.output_processor = OutputProcessor(
@@ -257,6 +245,10 @@ class AsyncLLM(EngineClient):
 
         cancel_task_threadsafe(getattr(self, "output_handler", None))
 
+    @property
+    def tokenizer(self) -> Optional[AnyTokenizer]:
+        return self.processor.tokenizer
+
     async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
         return await self.engine_core.get_supported_tasks_async()
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index 9da25c0662a82..701a625805628 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -23,7 +23,7 @@ from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 from vllm.tasks import SupportedTask
 from vllm.tracing import init_tracer
-from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
+from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import Device
 from vllm.v1.engine import EngineCoreRequest
@@ -95,18 +95,8 @@ class LLMEngine:
         self.dp_group = None
         self.should_execute_dummy_batch = False
 
-        if self.model_config.skip_tokenizer_init:
-            self.tokenizer = None
-        else:
-            # Tokenizer (+ ensure liveness if running in another process).
-            self.tokenizer = init_tokenizer_from_configs(
-                model_config=vllm_config.model_config
-            )
-
-        # Processor (convert Inputs --> EngineCoreRequests)
-        self.processor = Processor(
-            vllm_config=vllm_config, tokenizer=self.tokenizer, mm_registry=mm_registry
-        )
+        self.processor = Processor(vllm_config, mm_registry=mm_registry)
 
         # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
         self.output_processor = OutputProcessor(
@@ -214,6 +204,14 @@ class LLMEngine:
     def validate_outputs(cls, outputs, output_type):
         return outputs
 
+    @property
+    def tokenizer(self) -> Optional[AnyTokenizer]:
+        return self.processor.tokenizer
+
+    @tokenizer.setter
+    def tokenizer(self, tokenizer: Optional[AnyTokenizer]) -> None:
+        self.processor.tokenizer = tokenizer
+
     def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
         return self.engine_core.get_supported_tasks()
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 8a6ac0927e6d8..f39e9c1eea7d4 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -37,15 +37,13 @@ class Processor:
     def __init__(
         self,
         vllm_config: VllmConfig,
-        tokenizer: AnyTokenizer,
         mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
-    ):
+    ) -> None:
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.structured_outputs_config = vllm_config.structured_outputs_config
-        self.tokenizer = tokenizer
 
        self.generation_config_fields = self.model_config.try_get_generation_config()
@@ -54,11 +52,18 @@ class Processor:
         self.input_preprocessor = InputPreprocessor(
             self.model_config,
-            self.tokenizer,
             mm_registry,
             mm_processor_cache=self.mm_processor_cache,
         )
 
+    @property
+    def tokenizer(self) -> Optional[AnyTokenizer]:
+        return self.input_preprocessor.tokenizer
+
+    @tokenizer.setter
+    def tokenizer(self, tokenizer: Optional[AnyTokenizer]) -> None:
+        self.input_preprocessor.tokenizer = tokenizer
+
     def _validate_logprobs(
         self,
         params: SamplingParams,
@@ -511,10 +516,8 @@ class Processor:
         else:
             raise ValueError(f"The {prompt_type} prompt cannot be empty")
 
-        if self.model_config.skip_tokenizer_init:
-            tokenizer = None
-        else:
-            tokenizer = self.tokenizer
+        tokenizer = self.tokenizer
+
         if tokenizer is not None:
             max_input_id = max(prompt_ids or [], default=0)
             # NOTE: tokenizer.max_token_id is the tokenizer’s vocab size while
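
Below is a minimal usage sketch of the constructor pattern after this change, mirroring the updated tests above; the model id is illustrative, and any Hugging Face model id would behave the same way.

from vllm.config import ModelConfig
from vllm.inputs.preprocess import InputPreprocessor

# The model id here is only an example for this sketch.
model_config = ModelConfig(model="facebook/opt-125m")

# No tokenizer argument anymore: InputPreprocessor builds its own tokenizer
# from the model config, or leaves it as None when skip_tokenizer_init is set.
input_preprocessor = InputPreprocessor(model_config)
tokenizer = input_preprocessor.tokenizer  # None if skip_tokenizer_init=True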