From fe3398fab2b11f92b4ada209a85558710c446a36 Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Sat, 29 Nov 2025 22:25:10 +0800
Subject: [PATCH] [Chore] Enable passing `tokenizer=None` into MM processor
 (#29724)

Signed-off-by: DarkLight1337
---
 tests/multimodal/test_processing.py       | 41 +++--------
 vllm/entrypoints/openai/serving_engine.py |  2 +-
 vllm/inputs/preprocess.py                 | 15 +---
 vllm/model_executor/models/glm4_1v.py     |  6 --
 vllm/model_executor/models/qwen3_vl.py    |  3 -
 vllm/model_executor/models/qwen_vl.py     |  2 +-
 vllm/multimodal/processing.py             | 86 +++++++++++++++--------
 vllm/multimodal/registry.py               |  4 +-
 8 files changed, 68 insertions(+), 91 deletions(-)

diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py
index f7fa8da54d54e..262ea42e4d0fa 100644
--- a/tests/multimodal/test_processing.py
+++ b/tests/multimodal/test_processing.py
@@ -3,7 +3,6 @@
 
 import time
 from contextlib import nullcontext
-from typing import cast
 
 import numpy as np
 import pytest
@@ -24,7 +23,6 @@ from vllm.multimodal.processing import (
     replace_token_matches,
 )
 from vllm.multimodal.profiling import MultiModalProfiler
-from vllm.tokenizers import TokenizerLike
 
 from .utils import random_image
 
@@ -238,15 +236,12 @@ def test_find_token_matches(
     expected_by_key,
     update_type,
 ):
-    # Should not be used since there is nothing to convert to token IDs
-    mock_tokenizer = cast(TokenizerLike, object())
-
     prompt_updates = {
         key: update_type(key, target, []).resolve(0)
         for key, target in target_by_key.items()
     }
     result = {
-        key: list(update.iter_token_matches(prompt, mock_tokenizer))
+        key: list(update.iter_token_matches(prompt, tokenizer=None))
         for key, update in prompt_updates.items()
     }
 
@@ -385,15 +380,12 @@ def test_find_text_matches(
     expected_by_key,
     update_type,
 ):
-    # Should not be used since there is nothing to convert to text
-    mock_tokenizer = cast(TokenizerLike, object())
-
     prompt_updates = {
         key: update_type(key, target, []).resolve(0)
         for key, target in target_by_key.items()
     }
     result = {
-        key: list(update.iter_text_matches(prompt, mock_tokenizer))
+        key: list(update.iter_text_matches(prompt, tokenizer=None))
         for key, update in prompt_updates.items()
     }
 
@@ -545,9 +537,6 @@ def test_find_update_text(
     repl_by_key,
     expected_by_update_type_mm_count,
 ):
-    # Should not be used since there is nothing to convert to text
-    mock_tokenizer = cast(TokenizerLike, object())
-
     for (
         update_type,
         expected_by_mm_count,
@@ -564,7 +553,7 @@ def test_find_update_text(
             new_prompt, result = apply_text_matches(
                 prompt,
                 mm_prompt_updates,
-                mock_tokenizer,
+                tokenizer=None,
             )
 
             # Only displayed on error
@@ -750,9 +739,6 @@ def test_find_update_tokens(
     repl_by_key,
     expected_by_update_type_mm_count,
 ):
-    # Should not be used since there is nothing to convert to tokens
-    mock_tokenizer = cast(TokenizerLike, object())
-
     for (
         update_type,
         expected_by_mm_count,
@@ -769,7 +755,7 @@ def test_find_update_tokens(
             new_prompt, result = apply_token_matches(
                 prompt,
                 mm_prompt_updates,
-                mock_tokenizer,
+                tokenizer=None,
             )
 
             # Only displayed on error
@@ -900,15 +886,12 @@ def test_find_mm_placeholders(
     expected,
     update_type,
 ):
-    # Should not be used since there is nothing to convert to tokens
-    mock_tokenizer = cast(TokenizerLike, object())
-
     mm_prompt_updates = {
         key: [[update_type(key, [], repl).resolve(i)] for i in range(3)]
         for key, repl in repl_by_key.items()
     }
 
-    result = find_mm_placeholders(prompt, mm_prompt_updates, mock_tokenizer)
+    result = find_mm_placeholders(prompt, mm_prompt_updates, tokenizer=None)
 
     # Only displayed on error
     print("result:", result)
@@ -1029,12 +1012,9 @@ def test_hf_processor_init_kwargs(
     inference_kwargs,
     expected_kwargs,
 ):
-    # Should not be used since there is nothing to convert to tokens
-    mock_tokenizer = cast(TokenizerLike, object())
-
     ctx = InputProcessingContext(
         model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs),
-        tokenizer=mock_tokenizer,
+        tokenizer=None,
     )
 
     processor = ctx.get_hf_processor(
@@ -1065,12 +1045,9 @@ def test_hf_processor_call_kwargs(
     inference_kwargs,
     expected_kwargs,
 ):
-    # Should not be used since there is nothing to convert to tokens
-    mock_tokenizer = cast(TokenizerLike, object())
-
     ctx = InputProcessingContext(
         model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs),
-        tokenizer=mock_tokenizer,
+        tokenizer=None,
     )
 
     processor = ctx.get_hf_processor(DummyProcessor)  # type: ignore[arg-type]
@@ -1089,8 +1066,6 @@ def test_apply_matches_no_match_exits_quickly():
 
     With the fix, it should exit immediately when no match is found.
     """
-    mock_tokenizer = cast(TokenizerLike, object())
-
     # Create a long prompt with no placeholder
     long_prompt = "x" * 10000
 
@@ -1103,7 +1078,7 @@ def test_apply_matches_no_match_exits_quickly():
     result, _ = _apply_matches(
         long_prompt,
         mm_prompt_updates,
-        mock_tokenizer,
+        tokenizer=None,
     )
 
     elapsed = time.perf_counter() - start
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index e7a632e025103..b6a2478cf8c81 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -337,7 +337,7 @@ class OpenAIServing:
         tokenizer = input_processor.tokenizer
         if tokenizer is None:
             raise ValueError(
-                "You cannot use beam search when `skip_tokenizer_init` is True"
+                "You cannot use beam search when `skip_tokenizer_init=True`"
             )
         eos_token_id: int = tokenizer.eos_token_id  # type: ignore
 
diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py
index 46d1bed38aa85..2893a56b1190f 100644
--- a/vllm/inputs/preprocess.py
+++ b/vllm/inputs/preprocess.py
@@ -62,7 +62,7 @@ class InputPreprocessor:
     def get_tokenizer(self) -> TokenizerLike:
         if self.tokenizer is None:
             raise ValueError(
-                "You cannot pass text prompts when `skip_tokenizer_init` is True"
+                "You cannot pass text prompts when `skip_tokenizer_init=True`"
            )
 
         return self.tokenizer
@@ -228,22 +228,11 @@ class InputPreprocessor:
 
         return tokenizer.encode(prompt, **tokenization_kwargs)
 
-    def _get_mm_tokenizer(self) -> TokenizerLike:
-        # PrithviGeoSpatialMAE needs to be initialized without a tokenizer
-        # while using also multi-modal input
-        if not self.tokenizer:
-            return cast(TokenizerLike, object())  # Dummy
-
-        tokenizer = self.get_tokenizer()
-        return tokenizer
-
     def _get_mm_processor(self) -> BaseMultiModalProcessor:
         if not hasattr(self, "_mm_processor"):
-            tokenizer = self._get_mm_tokenizer()
-
             self._mm_processor = self.mm_registry.create_processor(
                 self.model_config,
-                tokenizer=tokenizer,
+                tokenizer=self.tokenizer,
                 cache=self.mm_processor_cache,
             )
 
diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py
index fe238861ecceb..5ba3c0a35928d 100644
--- a/vllm/model_executor/models/glm4_1v.py
+++ b/vllm/model_executor/models/glm4_1v.py
@@ -866,12 +866,6 @@ class Glm4vVisionTransformer(nn.Module):
 
 
 class Glm4vProcessingInfo(BaseProcessingInfo):
-    def get_hf_config(self):
-        return self.ctx.get_hf_config()
-
-    def get_tokenizer(self):
-        return self.ctx.tokenizer
-
     def get_supported_mm_limits(self) -> Mapping[str, int | None]:
         return {"image": None, "video": 1}
 
diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index f0ba631e66807..1d3929b936a9f 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -615,9 +615,6 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
             **kwargs,
         )
 
-    def get_tokenizer(self):
-        return self.ctx.tokenizer
-
     def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessorFast:
         return self.get_hf_processor(**kwargs).image_processor
 
diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py
index 4906cf441f6fb..55680b8e7ddfd 100644
--- a/vllm/model_executor/models/qwen_vl.py
+++ b/vllm/model_executor/models/qwen_vl.py
@@ -555,7 +555,7 @@ class QwenVLProcessor:
 
 class QwenVLProcessingInfo(BaseProcessingInfo):
     def get_tokenizer(self) -> PreTrainedTokenizer:
-        tokenizer = self.ctx.tokenizer
+        tokenizer = self.ctx.get_tokenizer()
         assert isinstance(tokenizer, PreTrainedTokenizer)
         return _get_tokenizer_without_image_pad(tokenizer)
 
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index aab657b24ba23..912cff2343dd0 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -97,15 +97,37 @@ def _cached_decode(
     )
 
 
-def _seq2text(tokenizer: TokenizerLike, seq: PromptSeq) -> str:
+def _seq2text(
+    tokenizer: TokenizerLike | None,
+    seq: PromptSeq,
+    *,
+    use_cache: bool = True,
+) -> str:
     if isinstance(seq, str):
         return seq
 
+    if tokenizer is None:
+        raise ValueError("You cannot decode tokens when `skip_tokenizer_init=True`")
+
+    if not use_cache:
+        return decode_tokens(tokenizer, seq)
+
     return _cached_decode(tokenizer, tuple(seq))
 
 
-def _seq2tokens(tokenizer: TokenizerLike, seq: PromptSeq) -> list[int]:
+def _seq2tokens(
+    tokenizer: TokenizerLike | None,
+    seq: PromptSeq,
+    *,
+    use_cache: bool = True,
+) -> list[int]:
     if isinstance(seq, str):
+        if tokenizer is None:
+            raise ValueError("You cannot encode text when `skip_tokenizer_init=True`")
+
+        if not use_cache:
+            return encode_tokens(tokenizer, seq, add_special_tokens=False)
+
         return _cached_encode(tokenizer, seq, add_special_tokens=False)
 
     return seq
@@ -114,7 +136,7 @@ class _GetMatchIndex(Protocol):
     def __call__(
         self,
-        tokenizer: TokenizerLike,
+        tokenizer: TokenizerLike | None,
         prompt: PromptSeq,
         start_idx: int = 0,
     ) -> int | None: ...
 
@@ -144,7 +166,7 @@ class PromptIndexTargets:
         """
 
         def get_match_index(
-            tokenizer: TokenizerLike,
+            tokenizer: TokenizerLike | None,
             prompt: PromptSeq,
             start_idx: int = 0,
         ) -> int | None:
@@ -154,13 +176,11 @@ class PromptIndexTargets:
             prefix = seq
 
             if isinstance(prompt, str):
-                if not isinstance(prefix, str):
-                    # Make both `str`
-                    prefix = decode_tokens(tokenizer, prefix)
+                # Make both `str`
+                prefix = _seq2text(tokenizer, prefix, use_cache=False)
             else:
-                if isinstance(prefix, str):
-                    # Make both `list[int]`
-                    prefix = encode_tokens(tokenizer, prefix, add_special_tokens=False)
+                # Make both `list[int]`
+                prefix = _seq2tokens(tokenizer, prefix, use_cache=False)
 
             match_idx = len(prefix)
             return match_idx if prompt[:match_idx] == prefix else None
@@ -200,7 +220,7 @@ class PromptUpdateDetails(Generic[_S]):
     full: _S
     """The full content."""
 
-    is_embed: Callable[[TokenizerLike, PromptSeq], torch.Tensor] | None = None
+    is_embed: Callable[[TokenizerLike | None, PromptSeq], torch.Tensor] | None = None
     """
     Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full],
     return a boolean mask of shape `(len(full),)` indicating which positions
@@ -221,8 +241,8 @@ class PromptUpdateDetails(Generic[_S]):
         seq: _S,
         embed_text: str,
     ) -> "PromptUpdateDetails[_S]":
-        def is_embed(tokenizer: TokenizerLike, full: PromptSeq) -> torch.Tensor:
-            embed_token_ids = encode_tokens(tokenizer, embed_text)
+        def is_embed(tokenizer: TokenizerLike | None, full: PromptSeq) -> torch.Tensor:
+            embed_token_ids = _seq2tokens(tokenizer, embed_text, use_cache=False)
             token_ids = _seq2tokens(tokenizer, full)
 
             return torch.isin(
@@ -237,7 +257,7 @@ class PromptUpdateDetails(Generic[_S]):
         seq: _S,
         embed_token_id: int,
     ) -> "PromptUpdateDetails[_S]":
-        def is_embed(tokenizer: TokenizerLike, full: PromptSeq) -> torch.Tensor:
+        def is_embed(tokenizer: TokenizerLike | None, full: PromptSeq) -> torch.Tensor:
             token_ids = _seq2tokens(tokenizer, full)
 
             return torch.tensor(token_ids) == embed_token_id
@@ -523,7 +543,7 @@ class ResolvedPromptUpdate:
     def iter_token_matches(
         self,
         prompt: list[int],
-        tokenizer: TokenizerLike,
+        tokenizer: TokenizerLike | None,
         *,
         start_idx: int = 0,
     ) -> Generator[PromptTargetMatch]:
@@ -545,7 +565,7 @@ class ResolvedPromptUpdate:
     def iter_text_matches(
         self,
         prompt: str,
-        tokenizer: TokenizerLike,
+        tokenizer: TokenizerLike | None,
         *,
         start_idx: int = 0,
     ) -> Generator[PromptTargetMatch]:
@@ -567,7 +587,7 @@ class ResolvedPromptUpdate:
     def iter_matches(
         self,
         prompt: list[int] | str,
-        tokenizer: TokenizerLike,
+        tokenizer: TokenizerLike | None,
         *,
         start_idx: int = 0,
     ) -> Generator[PromptTargetMatch]:
@@ -676,7 +696,7 @@ _MatchToApply = tuple[tuple[str, int], tuple[PromptTargetMatch, int]]
 def _find_matches(
     prompt: _S,
     mm_prompt_updates: "MultiModalPromptUpdates",
-    tokenizer: TokenizerLike,
+    tokenizer: TokenizerLike | None,
     *,
     prev_end_idx: int = 0,
     current_result: "MultiModalPromptUpdatesApplyResult",
@@ -741,7 +761,7 @@ def _all_items_found(
 def _apply_matches(
     prompt: _S,
     mm_prompt_updates: "MultiModalPromptUpdates",
-    tokenizer: TokenizerLike,
+    tokenizer: TokenizerLike | None,
 ) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]:
     mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()}
 
@@ -807,7 +827,7 @@ def _apply_matches(
 def apply_token_matches(
     prompt: list[int],
     mm_prompt_updates: "MultiModalPromptUpdates",
-    tokenizer: TokenizerLike,
+    tokenizer: TokenizerLike | None,
 ) -> tuple[list[int], "MultiModalPromptUpdatesApplyResult"]:
     """
     Apply the updates in `mm_prompt_updates` to `prompt`.
@@ -824,7 +844,7 @@ def apply_token_matches(
 def apply_text_matches(
     prompt: str,
     mm_prompt_updates: "MultiModalPromptUpdates",
-    tokenizer: TokenizerLike,
+    tokenizer: TokenizerLike | None,
 ) -> tuple[str, "MultiModalPromptUpdatesApplyResult"]:
     """
     Apply the updates in `mm_prompt_updates` to `prompt`.
@@ -841,7 +861,7 @@ def apply_text_matches(
 def _iter_placeholders(
     prompt: list[int],
     mm_prompt_updates: "MultiModalPromptUpdates",
-    tokenizer: TokenizerLike,
+    tokenizer: TokenizerLike | None,
 ) -> Iterable[PlaceholderFeaturesInfo]:
     """
     Yield each set of placeholder tokens found in `prompt`.
@@ -910,7 +930,7 @@ def _iter_placeholders(
 def find_mm_placeholders(
     prompt: list[int],
     mm_prompt_updates: "MultiModalPromptUpdates",
-    tokenizer: TokenizerLike,
+    tokenizer: TokenizerLike | None,
 ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
     it = _iter_placeholders(prompt, mm_prompt_updates, tokenizer)
     return dict(full_groupby_modality(it))
@@ -931,9 +951,17 @@ class InputProcessingContext:
     model_config: ModelConfig
     """The configuration of the model."""
 
-    tokenizer: TokenizerLike
+    tokenizer: TokenizerLike | None
     """The tokenizer used to tokenize the inputs."""
 
+    def get_tokenizer(self) -> TokenizerLike:
+        if self.tokenizer is None:
+            raise ValueError(
+                "You cannot pass text prompts when `skip_tokenizer_init=True`"
+            )
+
+        return self.tokenizer
+
     @overload
     def get_hf_config(self, /) -> PretrainedConfig: ...
 
@@ -1148,7 +1176,7 @@ class BaseProcessingInfo:
         return self.ctx.model_config.model
 
     def get_tokenizer(self) -> TokenizerLike:
-        return self.ctx.tokenizer
+        return self.ctx.get_tokenizer()
 
     def get_hf_config(self) -> PretrainedConfig:
         return self.ctx.get_hf_config()
@@ -1960,15 +1988,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             for update_idxs in match_result.values()
         ):
             new_text, match_result = self._apply_text_matches(
-                decode_tokens(tokenizer, token_ids),
+                _seq2text(tokenizer, token_ids, use_cache=False),
                 mm_prompt_updates,
             )
 
-            new_token_ids = encode_tokens(
-                tokenizer,
-                new_text,
-                add_special_tokens=False,
-            )
+            new_token_ids = _seq2tokens(tokenizer, new_text, use_cache=False)
 
         matched_updates = defaultdict[str, list[Sequence[ResolvedPromptUpdate]]](list)
         for modality, update_idxs in match_result.items():
diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py
index ee90570b24aaf..2fdae46e547b0 100644
--- a/vllm/multimodal/registry.py
+++ b/vllm/multimodal/registry.py
@@ -234,9 +234,7 @@ class MultiModalRegistry:
         model_config: "ModelConfig",
         tokenizer: TokenizerLike | None = None,
     ) -> InputProcessingContext:
-        if model_config.skip_tokenizer_init:
-            tokenizer = cast(TokenizerLike, object())
-        elif tokenizer is None:
+        if tokenizer is None and not model_config.skip_tokenizer_init:
             tokenizer = cached_tokenizer_from_config(model_config)
 
         return InputProcessingContext(model_config, tokenizer)
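
For context, a minimal sketch of what this change enables (not part of the patch; the model id is illustrative, and the import paths follow where `ModelConfig` and `InputProcessingContext` are defined in this patch):

    from vllm.config import ModelConfig
    from vllm.multimodal.processing import InputProcessingContext

    # A multimodal-only model can now build its processing context without a
    # tokenizer, instead of the previous cast(TokenizerLike, object()) dummy.
    ctx = InputProcessingContext(
        model_config=ModelConfig("my-org/multimodal-only-model"),  # hypothetical model id
        tokenizer=None,
    )

    # Text-dependent paths now fail loudly instead of hitting the dummy object:
    # ctx.get_tokenizer() raises
    #   ValueError("You cannot pass text prompts when `skip_tokenizer_init=True`")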