[Chore] Enable passing tokenizer=None into MM processor (#29724)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-11-29 22:25:10 +08:00 committed by GitHub
parent ad7f714d62
commit fe3398fab2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 68 additions and 91 deletions

View File

@ -3,7 +3,6 @@
import time
from contextlib import nullcontext
from typing import cast
import numpy as np
import pytest
@ -24,7 +23,6 @@ from vllm.multimodal.processing import (
replace_token_matches,
)
from vllm.multimodal.profiling import MultiModalProfiler
from vllm.tokenizers import TokenizerLike
from .utils import random_image
@ -238,15 +236,12 @@ def test_find_token_matches(
expected_by_key,
update_type,
):
# Should not be used since there is nothing to convert to token IDs
mock_tokenizer = cast(TokenizerLike, object())
prompt_updates = {
key: update_type(key, target, []).resolve(0)
for key, target in target_by_key.items()
}
result = {
key: list(update.iter_token_matches(prompt, mock_tokenizer))
key: list(update.iter_token_matches(prompt, tokenizer=None))
for key, update in prompt_updates.items()
}
@ -385,15 +380,12 @@ def test_find_text_matches(
expected_by_key,
update_type,
):
# Should not be used since there is nothing to convert to text
mock_tokenizer = cast(TokenizerLike, object())
prompt_updates = {
key: update_type(key, target, []).resolve(0)
for key, target in target_by_key.items()
}
result = {
key: list(update.iter_text_matches(prompt, mock_tokenizer))
key: list(update.iter_text_matches(prompt, tokenizer=None))
for key, update in prompt_updates.items()
}
@ -545,9 +537,6 @@ def test_find_update_text(
repl_by_key,
expected_by_update_type_mm_count,
):
# Should not be used since there is nothing to convert to text
mock_tokenizer = cast(TokenizerLike, object())
for (
update_type,
expected_by_mm_count,
@ -564,7 +553,7 @@ def test_find_update_text(
new_prompt, result = apply_text_matches(
prompt,
mm_prompt_updates,
mock_tokenizer,
tokenizer=None,
)
# Only displayed on error
@ -750,9 +739,6 @@ def test_find_update_tokens(
repl_by_key,
expected_by_update_type_mm_count,
):
# Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(TokenizerLike, object())
for (
update_type,
expected_by_mm_count,
@ -769,7 +755,7 @@ def test_find_update_tokens(
new_prompt, result = apply_token_matches(
prompt,
mm_prompt_updates,
mock_tokenizer,
tokenizer=None,
)
# Only displayed on error
@ -900,15 +886,12 @@ def test_find_mm_placeholders(
expected,
update_type,
):
# Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(TokenizerLike, object())
mm_prompt_updates = {
key: [[update_type(key, [], repl).resolve(i)] for i in range(3)]
for key, repl in repl_by_key.items()
}
result = find_mm_placeholders(prompt, mm_prompt_updates, mock_tokenizer)
result = find_mm_placeholders(prompt, mm_prompt_updates, tokenizer=None)
# Only displayed on error
print("result:", result)
@ -1029,12 +1012,9 @@ def test_hf_processor_init_kwargs(
inference_kwargs,
expected_kwargs,
):
# Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(TokenizerLike, object())
ctx = InputProcessingContext(
model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs),
tokenizer=mock_tokenizer,
tokenizer=None,
)
processor = ctx.get_hf_processor(
@ -1065,12 +1045,9 @@ def test_hf_processor_call_kwargs(
inference_kwargs,
expected_kwargs,
):
# Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(TokenizerLike, object())
ctx = InputProcessingContext(
model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs),
tokenizer=mock_tokenizer,
tokenizer=None,
)
processor = ctx.get_hf_processor(DummyProcessor) # type: ignore[arg-type]
@ -1089,8 +1066,6 @@ def test_apply_matches_no_match_exits_quickly():
With the fix, it should exit immediately when no match is found.
"""
mock_tokenizer = cast(TokenizerLike, object())
# Create a long prompt with no placeholder
long_prompt = "x" * 10000
@ -1103,7 +1078,7 @@ def test_apply_matches_no_match_exits_quickly():
result, _ = _apply_matches(
long_prompt,
mm_prompt_updates,
mock_tokenizer,
tokenizer=None,
)
elapsed = time.perf_counter() - start

View File

@ -337,7 +337,7 @@ class OpenAIServing:
tokenizer = input_processor.tokenizer
if tokenizer is None:
raise ValueError(
"You cannot use beam search when `skip_tokenizer_init` is True"
"You cannot use beam search when `skip_tokenizer_init=True`"
)
eos_token_id: int = tokenizer.eos_token_id # type: ignore

View File

@ -62,7 +62,7 @@ class InputPreprocessor:
def get_tokenizer(self) -> TokenizerLike:
if self.tokenizer is None:
raise ValueError(
"You cannot pass text prompts when `skip_tokenizer_init` is True"
"You cannot pass text prompts when `skip_tokenizer_init=True`"
)
return self.tokenizer
@ -228,22 +228,11 @@ class InputPreprocessor:
return tokenizer.encode(prompt, **tokenization_kwargs)
def _get_mm_tokenizer(self) -> TokenizerLike:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# while using also multi-modal input
if not self.tokenizer:
return cast(TokenizerLike, object()) # Dummy
tokenizer = self.get_tokenizer()
return tokenizer
def _get_mm_processor(self) -> BaseMultiModalProcessor:
if not hasattr(self, "_mm_processor"):
tokenizer = self._get_mm_tokenizer()
self._mm_processor = self.mm_registry.create_processor(
self.model_config,
tokenizer=tokenizer,
tokenizer=self.tokenizer,
cache=self.mm_processor_cache,
)

View File

@ -866,12 +866,6 @@ class Glm4vVisionTransformer(nn.Module):
class Glm4vProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config()
def get_tokenizer(self):
return self.ctx.tokenizer
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None, "video": 1}

View File

@ -615,9 +615,6 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
**kwargs,
)
def get_tokenizer(self):
return self.ctx.tokenizer
def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessorFast:
return self.get_hf_processor(**kwargs).image_processor

View File

@ -555,7 +555,7 @@ class QwenVLProcessor:
class QwenVLProcessingInfo(BaseProcessingInfo):
def get_tokenizer(self) -> PreTrainedTokenizer:
tokenizer = self.ctx.tokenizer
tokenizer = self.ctx.get_tokenizer()
assert isinstance(tokenizer, PreTrainedTokenizer)
return _get_tokenizer_without_image_pad(tokenizer)

View File

@ -97,15 +97,37 @@ def _cached_decode(
)
def _seq2text(tokenizer: TokenizerLike, seq: PromptSeq) -> str:
def _seq2text(
tokenizer: TokenizerLike | None,
seq: PromptSeq,
*,
use_cache: bool = True,
) -> str:
if isinstance(seq, str):
return seq
if tokenizer is None:
raise ValueError("You cannot decode tokens when `skip_tokenizer_init=True`")
if not use_cache:
return decode_tokens(tokenizer, seq)
return _cached_decode(tokenizer, tuple(seq))
def _seq2tokens(tokenizer: TokenizerLike, seq: PromptSeq) -> list[int]:
def _seq2tokens(
tokenizer: TokenizerLike | None,
seq: PromptSeq,
*,
use_cache: bool = True,
) -> list[int]:
if isinstance(seq, str):
if tokenizer is None:
raise ValueError("You cannot encode text when `skip_tokenizer_init=True`")
if not use_cache:
return encode_tokens(tokenizer, seq, add_special_tokens=False)
return _cached_encode(tokenizer, seq, add_special_tokens=False)
return seq
@ -114,7 +136,7 @@ def _seq2tokens(tokenizer: TokenizerLike, seq: PromptSeq) -> list[int]:
class _GetMatchIndex(Protocol):
def __call__(
self,
tokenizer: TokenizerLike,
tokenizer: TokenizerLike | None,
prompt: PromptSeq,
start_idx: int = 0,
) -> int | None: ...
@ -144,7 +166,7 @@ class PromptIndexTargets:
"""
def get_match_index(
tokenizer: TokenizerLike,
tokenizer: TokenizerLike | None,
prompt: PromptSeq,
start_idx: int = 0,
) -> int | None:
@ -154,13 +176,11 @@ class PromptIndexTargets:
prefix = seq
if isinstance(prompt, str):
if not isinstance(prefix, str):
# Make both `str`
prefix = decode_tokens(tokenizer, prefix)
# Make both `str`
prefix = _seq2text(tokenizer, prefix, use_cache=False)
else:
if isinstance(prefix, str):
# Make both `list[int]`
prefix = encode_tokens(tokenizer, prefix, add_special_tokens=False)
# Make both `list[int]`
prefix = _seq2tokens(tokenizer, prefix, use_cache=False)
match_idx = len(prefix)
return match_idx if prompt[:match_idx] == prefix else None
@ -200,7 +220,7 @@ class PromptUpdateDetails(Generic[_S]):
full: _S
"""The full content."""
is_embed: Callable[[TokenizerLike, PromptSeq], torch.Tensor] | None = None
is_embed: Callable[[TokenizerLike | None, PromptSeq], torch.Tensor] | None = None
"""
Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full],
return a boolean mask of shape `(len(full),)` indicating which positions
@ -221,8 +241,8 @@ class PromptUpdateDetails(Generic[_S]):
seq: _S,
embed_text: str,
) -> "PromptUpdateDetails[_S]":
def is_embed(tokenizer: TokenizerLike, full: PromptSeq) -> torch.Tensor:
embed_token_ids = encode_tokens(tokenizer, embed_text)
def is_embed(tokenizer: TokenizerLike | None, full: PromptSeq) -> torch.Tensor:
embed_token_ids = _seq2tokens(tokenizer, embed_text, use_cache=False)
token_ids = _seq2tokens(tokenizer, full)
return torch.isin(
@ -237,7 +257,7 @@ class PromptUpdateDetails(Generic[_S]):
seq: _S,
embed_token_id: int,
) -> "PromptUpdateDetails[_S]":
def is_embed(tokenizer: TokenizerLike, full: PromptSeq) -> torch.Tensor:
def is_embed(tokenizer: TokenizerLike | None, full: PromptSeq) -> torch.Tensor:
token_ids = _seq2tokens(tokenizer, full)
return torch.tensor(token_ids) == embed_token_id
@ -523,7 +543,7 @@ class ResolvedPromptUpdate:
def iter_token_matches(
self,
prompt: list[int],
tokenizer: TokenizerLike,
tokenizer: TokenizerLike | None,
*,
start_idx: int = 0,
) -> Generator[PromptTargetMatch]:
@ -545,7 +565,7 @@ class ResolvedPromptUpdate:
def iter_text_matches(
self,
prompt: str,
tokenizer: TokenizerLike,
tokenizer: TokenizerLike | None,
*,
start_idx: int = 0,
) -> Generator[PromptTargetMatch]:
@ -567,7 +587,7 @@ class ResolvedPromptUpdate:
def iter_matches(
self,
prompt: list[int] | str,
tokenizer: TokenizerLike,
tokenizer: TokenizerLike | None,
*,
start_idx: int = 0,
) -> Generator[PromptTargetMatch]:
@ -676,7 +696,7 @@ _MatchToApply = tuple[tuple[str, int], tuple[PromptTargetMatch, int]]
def _find_matches(
prompt: _S,
mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: TokenizerLike,
tokenizer: TokenizerLike | None,
*,
prev_end_idx: int = 0,
current_result: "MultiModalPromptUpdatesApplyResult",
@ -741,7 +761,7 @@ def _all_items_found(
def _apply_matches(
prompt: _S,
mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: TokenizerLike,
tokenizer: TokenizerLike | None,
) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]:
mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()}
@ -807,7 +827,7 @@ def _apply_matches(
def apply_token_matches(
prompt: list[int],
mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: TokenizerLike,
tokenizer: TokenizerLike | None,
) -> tuple[list[int], "MultiModalPromptUpdatesApplyResult"]:
"""
Apply the updates in `mm_prompt_updates` to `prompt`.
@ -824,7 +844,7 @@ def apply_token_matches(
def apply_text_matches(
prompt: str,
mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: TokenizerLike,
tokenizer: TokenizerLike | None,
) -> tuple[str, "MultiModalPromptUpdatesApplyResult"]:
"""
Apply the updates in `mm_prompt_updates` to `prompt`.
@ -841,7 +861,7 @@ def apply_text_matches(
def _iter_placeholders(
prompt: list[int],
mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: TokenizerLike,
tokenizer: TokenizerLike | None,
) -> Iterable[PlaceholderFeaturesInfo]:
"""
Yield each set of placeholder tokens found in `prompt`.
@ -910,7 +930,7 @@ def _iter_placeholders(
def find_mm_placeholders(
prompt: list[int],
mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: TokenizerLike,
tokenizer: TokenizerLike | None,
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
it = _iter_placeholders(prompt, mm_prompt_updates, tokenizer)
return dict(full_groupby_modality(it))
@ -931,9 +951,17 @@ class InputProcessingContext:
model_config: ModelConfig
"""The configuration of the model."""
tokenizer: TokenizerLike
tokenizer: TokenizerLike | None
"""The tokenizer used to tokenize the inputs."""
def get_tokenizer(self) -> TokenizerLike:
if self.tokenizer is None:
raise ValueError(
"You cannot pass text prompts when `skip_tokenizer_init=True`"
)
return self.tokenizer
@overload
def get_hf_config(self, /) -> PretrainedConfig: ...
@ -1148,7 +1176,7 @@ class BaseProcessingInfo:
return self.ctx.model_config.model
def get_tokenizer(self) -> TokenizerLike:
return self.ctx.tokenizer
return self.ctx.get_tokenizer()
def get_hf_config(self) -> PretrainedConfig:
return self.ctx.get_hf_config()
@ -1960,15 +1988,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
for update_idxs in match_result.values()
):
new_text, match_result = self._apply_text_matches(
decode_tokens(tokenizer, token_ids),
_seq2text(tokenizer, token_ids, use_cache=False),
mm_prompt_updates,
)
new_token_ids = encode_tokens(
tokenizer,
new_text,
add_special_tokens=False,
)
new_token_ids = _seq2tokens(tokenizer, new_text, use_cache=False)
matched_updates = defaultdict[str, list[Sequence[ResolvedPromptUpdate]]](list)
for modality, update_idxs in match_result.items():

View File

@ -234,9 +234,7 @@ class MultiModalRegistry:
model_config: "ModelConfig",
tokenizer: TokenizerLike | None = None,
) -> InputProcessingContext:
if model_config.skip_tokenizer_init:
tokenizer = cast(TokenizerLike, object())
elif tokenizer is None:
if tokenizer is None and not model_config.skip_tokenizer_init:
tokenizer = cached_tokenizer_from_config(model_config)
return InputProcessingContext(model_config, tokenizer)