[Chore] Enable passing tokenizer=None into MM processor (#29724)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent ad7f714d62
commit fe3398fab2
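Below is a minimal, hedged sketch (not part of the commit) of the behaviour this change enables: a multi-modal processing context can now be built with `tokenizer=None`, and only the code paths that actually need text conversion raise an error. The import paths and the model name are assumptions for illustration; the constructor call mirrors the test updates in the diff below.

# Hedged sketch, not from the commit: illustrates passing tokenizer=None.
# Assumed import paths and model id; adjust to the actual package layout.
from vllm.config import ModelConfig
from vllm.multimodal.processing import InputProcessingContext

ctx = InputProcessingContext(
    model_config=ModelConfig("llava-hf/llava-1.5-7b-hf"),  # assumed model id
    tokenizer=None,  # previously a dummy object was cast to TokenizerLike
)

# Token-ID-only paths keep working; text paths fail fast with a clear message.
try:
    ctx.get_tokenizer()
except ValueError as err:
    print(err)  # "You cannot pass text prompts when `skip_tokenizer_init=True`"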
@@ -3,7 +3,6 @@
 
 import time
 from contextlib import nullcontext
-from typing import cast
 
 import numpy as np
 import pytest
@@ -24,7 +23,6 @@ from vllm.multimodal.processing import (
     replace_token_matches,
 )
 from vllm.multimodal.profiling import MultiModalProfiler
-from vllm.tokenizers import TokenizerLike
 
 from .utils import random_image
 
@@ -238,15 +236,12 @@ def test_find_token_matches(
     expected_by_key,
     update_type,
 ):
-    # Should not be used since there is nothing to convert to token IDs
-    mock_tokenizer = cast(TokenizerLike, object())
-
     prompt_updates = {
         key: update_type(key, target, []).resolve(0)
         for key, target in target_by_key.items()
     }
     result = {
-        key: list(update.iter_token_matches(prompt, mock_tokenizer))
+        key: list(update.iter_token_matches(prompt, tokenizer=None))
         for key, update in prompt_updates.items()
     }
 
@@ -385,15 +380,12 @@ def test_find_text_matches(
     expected_by_key,
     update_type,
 ):
-    # Should not be used since there is nothing to convert to text
-    mock_tokenizer = cast(TokenizerLike, object())
-
     prompt_updates = {
         key: update_type(key, target, []).resolve(0)
         for key, target in target_by_key.items()
     }
     result = {
-        key: list(update.iter_text_matches(prompt, mock_tokenizer))
+        key: list(update.iter_text_matches(prompt, tokenizer=None))
         for key, update in prompt_updates.items()
     }
 
@@ -545,9 +537,6 @@ def test_find_update_text(
     repl_by_key,
     expected_by_update_type_mm_count,
 ):
-    # Should not be used since there is nothing to convert to text
-    mock_tokenizer = cast(TokenizerLike, object())
-
     for (
         update_type,
         expected_by_mm_count,
@@ -564,7 +553,7 @@ def test_find_update_text(
         new_prompt, result = apply_text_matches(
             prompt,
             mm_prompt_updates,
-            mock_tokenizer,
+            tokenizer=None,
         )
 
         # Only displayed on error
@@ -750,9 +739,6 @@ def test_find_update_tokens(
     repl_by_key,
     expected_by_update_type_mm_count,
 ):
-    # Should not be used since there is nothing to convert to tokens
-    mock_tokenizer = cast(TokenizerLike, object())
-
     for (
         update_type,
         expected_by_mm_count,
@@ -769,7 +755,7 @@ def test_find_update_tokens(
         new_prompt, result = apply_token_matches(
             prompt,
             mm_prompt_updates,
-            mock_tokenizer,
+            tokenizer=None,
         )
 
         # Only displayed on error
@@ -900,15 +886,12 @@ def test_find_mm_placeholders(
     expected,
     update_type,
 ):
-    # Should not be used since there is nothing to convert to tokens
-    mock_tokenizer = cast(TokenizerLike, object())
-
     mm_prompt_updates = {
         key: [[update_type(key, [], repl).resolve(i)] for i in range(3)]
         for key, repl in repl_by_key.items()
     }
 
-    result = find_mm_placeholders(prompt, mm_prompt_updates, mock_tokenizer)
+    result = find_mm_placeholders(prompt, mm_prompt_updates, tokenizer=None)
 
     # Only displayed on error
     print("result:", result)
@@ -1029,12 +1012,9 @@ def test_hf_processor_init_kwargs(
     inference_kwargs,
     expected_kwargs,
 ):
-    # Should not be used since there is nothing to convert to tokens
-    mock_tokenizer = cast(TokenizerLike, object())
-
     ctx = InputProcessingContext(
         model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs),
-        tokenizer=mock_tokenizer,
+        tokenizer=None,
     )
 
     processor = ctx.get_hf_processor(
@@ -1065,12 +1045,9 @@ def test_hf_processor_call_kwargs(
     inference_kwargs,
     expected_kwargs,
 ):
-    # Should not be used since there is nothing to convert to tokens
-    mock_tokenizer = cast(TokenizerLike, object())
-
     ctx = InputProcessingContext(
         model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs),
-        tokenizer=mock_tokenizer,
+        tokenizer=None,
    )
 
     processor = ctx.get_hf_processor(DummyProcessor)  # type: ignore[arg-type]
@@ -1089,8 +1066,6 @@ def test_apply_matches_no_match_exits_quickly():
 
     With the fix, it should exit immediately when no match is found.
     """
-    mock_tokenizer = cast(TokenizerLike, object())
-
     # Create a long prompt with no placeholder
     long_prompt = "x" * 10000
 
@@ -1103,7 +1078,7 @@ def test_apply_matches_no_match_exits_quickly():
     result, _ = _apply_matches(
         long_prompt,
         mm_prompt_updates,
-        mock_tokenizer,
+        tokenizer=None,
     )
     elapsed = time.perf_counter() - start
 
@@ -337,7 +337,7 @@ class OpenAIServing:
         tokenizer = input_processor.tokenizer
         if tokenizer is None:
             raise ValueError(
-                "You cannot use beam search when `skip_tokenizer_init` is True"
+                "You cannot use beam search when `skip_tokenizer_init=True`"
             )
 
         eos_token_id: int = tokenizer.eos_token_id  # type: ignore
@@ -62,7 +62,7 @@ class InputPreprocessor:
     def get_tokenizer(self) -> TokenizerLike:
         if self.tokenizer is None:
             raise ValueError(
-                "You cannot pass text prompts when `skip_tokenizer_init` is True"
+                "You cannot pass text prompts when `skip_tokenizer_init=True`"
             )
 
         return self.tokenizer
@@ -228,22 +228,11 @@ class InputPreprocessor:
 
         return tokenizer.encode(prompt, **tokenization_kwargs)
 
-    def _get_mm_tokenizer(self) -> TokenizerLike:
-        # PrithviGeoSpatialMAE needs to be initialized without a tokenizer
-        # while using also multi-modal input
-        if not self.tokenizer:
-            return cast(TokenizerLike, object())  # Dummy
-
-        tokenizer = self.get_tokenizer()
-        return tokenizer
-
     def _get_mm_processor(self) -> BaseMultiModalProcessor:
         if not hasattr(self, "_mm_processor"):
-            tokenizer = self._get_mm_tokenizer()
-
             self._mm_processor = self.mm_registry.create_processor(
                 self.model_config,
-                tokenizer=tokenizer,
+                tokenizer=self.tokenizer,
                 cache=self.mm_processor_cache,
             )
 
@@ -866,12 +866,6 @@ class Glm4vVisionTransformer(nn.Module):
 
 
 class Glm4vProcessingInfo(BaseProcessingInfo):
-    def get_hf_config(self):
-        return self.ctx.get_hf_config()
-
-    def get_tokenizer(self):
-        return self.ctx.tokenizer
-
     def get_supported_mm_limits(self) -> Mapping[str, int | None]:
         return {"image": None, "video": 1}
 
@@ -615,9 +615,6 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
             **kwargs,
         )
 
-    def get_tokenizer(self):
-        return self.ctx.tokenizer
-
     def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessorFast:
         return self.get_hf_processor(**kwargs).image_processor
 
@@ -555,7 +555,7 @@ class QwenVLProcessor:
 
 class QwenVLProcessingInfo(BaseProcessingInfo):
     def get_tokenizer(self) -> PreTrainedTokenizer:
-        tokenizer = self.ctx.tokenizer
+        tokenizer = self.ctx.get_tokenizer()
         assert isinstance(tokenizer, PreTrainedTokenizer)
 
         return _get_tokenizer_without_image_pad(tokenizer)
@@ -97,15 +97,37 @@ def _cached_decode(
 )
 
 
-def _seq2text(tokenizer: TokenizerLike, seq: PromptSeq) -> str:
+def _seq2text(
+    tokenizer: TokenizerLike | None,
+    seq: PromptSeq,
+    *,
+    use_cache: bool = True,
+) -> str:
     if isinstance(seq, str):
         return seq
 
+    if tokenizer is None:
+        raise ValueError("You cannot decode tokens when `skip_tokenizer_init=True`")
+
+    if not use_cache:
+        return decode_tokens(tokenizer, seq)
+
     return _cached_decode(tokenizer, tuple(seq))
 
 
-def _seq2tokens(tokenizer: TokenizerLike, seq: PromptSeq) -> list[int]:
+def _seq2tokens(
+    tokenizer: TokenizerLike | None,
+    seq: PromptSeq,
+    *,
+    use_cache: bool = True,
+) -> list[int]:
     if isinstance(seq, str):
+        if tokenizer is None:
+            raise ValueError("You cannot encode text when `skip_tokenizer_init=True`")
+
+        if not use_cache:
+            return encode_tokens(tokenizer, seq, add_special_tokens=False)
+
         return _cached_encode(tokenizer, seq, add_special_tokens=False)
 
     return seq
@@ -114,7 +136,7 @@ def _seq2tokens(tokenizer: TokenizerLike, seq: PromptSeq) -> list[int]:
 class _GetMatchIndex(Protocol):
     def __call__(
         self,
-        tokenizer: TokenizerLike,
+        tokenizer: TokenizerLike | None,
         prompt: PromptSeq,
         start_idx: int = 0,
     ) -> int | None: ...
@@ -144,7 +166,7 @@ class PromptIndexTargets:
     """
 
     def get_match_index(
-        tokenizer: TokenizerLike,
+        tokenizer: TokenizerLike | None,
         prompt: PromptSeq,
         start_idx: int = 0,
     ) -> int | None:
@@ -154,13 +176,11 @@ class PromptIndexTargets:
         prefix = seq
 
         if isinstance(prompt, str):
-            if not isinstance(prefix, str):
-                # Make both `str`
-                prefix = decode_tokens(tokenizer, prefix)
+            # Make both `str`
+            prefix = _seq2text(tokenizer, prefix, use_cache=False)
         else:
-            if isinstance(prefix, str):
-                # Make both `list[int]`
-                prefix = encode_tokens(tokenizer, prefix, add_special_tokens=False)
+            # Make both `list[int]`
+            prefix = _seq2tokens(tokenizer, prefix, use_cache=False)
 
         match_idx = len(prefix)
         return match_idx if prompt[:match_idx] == prefix else None
@@ -200,7 +220,7 @@ class PromptUpdateDetails(Generic[_S]):
     full: _S
     """The full content."""
 
-    is_embed: Callable[[TokenizerLike, PromptSeq], torch.Tensor] | None = None
+    is_embed: Callable[[TokenizerLike | None, PromptSeq], torch.Tensor] | None = None
     """
     Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full],
     return a boolean mask of shape `(len(full),)` indicating which positions
@@ -221,8 +241,8 @@ class PromptUpdateDetails(Generic[_S]):
         seq: _S,
         embed_text: str,
     ) -> "PromptUpdateDetails[_S]":
-        def is_embed(tokenizer: TokenizerLike, full: PromptSeq) -> torch.Tensor:
-            embed_token_ids = encode_tokens(tokenizer, embed_text)
+        def is_embed(tokenizer: TokenizerLike | None, full: PromptSeq) -> torch.Tensor:
+            embed_token_ids = _seq2tokens(tokenizer, embed_text, use_cache=False)
             token_ids = _seq2tokens(tokenizer, full)
 
             return torch.isin(
@@ -237,7 +257,7 @@ class PromptUpdateDetails(Generic[_S]):
         seq: _S,
         embed_token_id: int,
     ) -> "PromptUpdateDetails[_S]":
-        def is_embed(tokenizer: TokenizerLike, full: PromptSeq) -> torch.Tensor:
+        def is_embed(tokenizer: TokenizerLike | None, full: PromptSeq) -> torch.Tensor:
             token_ids = _seq2tokens(tokenizer, full)
 
             return torch.tensor(token_ids) == embed_token_id
@@ -523,7 +543,7 @@ class ResolvedPromptUpdate:
     def iter_token_matches(
         self,
         prompt: list[int],
-        tokenizer: TokenizerLike,
+        tokenizer: TokenizerLike | None,
         *,
         start_idx: int = 0,
     ) -> Generator[PromptTargetMatch]:
@@ -545,7 +565,7 @@ class ResolvedPromptUpdate:
     def iter_text_matches(
         self,
         prompt: str,
-        tokenizer: TokenizerLike,
+        tokenizer: TokenizerLike | None,
         *,
         start_idx: int = 0,
     ) -> Generator[PromptTargetMatch]:
@@ -567,7 +587,7 @@ class ResolvedPromptUpdate:
    def iter_matches(
         self,
         prompt: list[int] | str,
-        tokenizer: TokenizerLike,
+        tokenizer: TokenizerLike | None,
         *,
         start_idx: int = 0,
     ) -> Generator[PromptTargetMatch]:
@@ -676,7 +696,7 @@ _MatchToApply = tuple[tuple[str, int], tuple[PromptTargetMatch, int]]
 def _find_matches(
     prompt: _S,
     mm_prompt_updates: "MultiModalPromptUpdates",
-    tokenizer: TokenizerLike,
+    tokenizer: TokenizerLike | None,
     *,
     prev_end_idx: int = 0,
     current_result: "MultiModalPromptUpdatesApplyResult",
@@ -741,7 +761,7 @@ def _all_items_found(
 def _apply_matches(
     prompt: _S,
     mm_prompt_updates: "MultiModalPromptUpdates",
-    tokenizer: TokenizerLike,
+    tokenizer: TokenizerLike | None,
 ) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]:
     mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()}
 
@@ -807,7 +827,7 @@ def _apply_matches(
 def apply_token_matches(
     prompt: list[int],
     mm_prompt_updates: "MultiModalPromptUpdates",
-    tokenizer: TokenizerLike,
+    tokenizer: TokenizerLike | None,
 ) -> tuple[list[int], "MultiModalPromptUpdatesApplyResult"]:
     """
     Apply the updates in `mm_prompt_updates` to `prompt`.
@@ -824,7 +844,7 @@ def apply_token_matches(
 def apply_text_matches(
     prompt: str,
     mm_prompt_updates: "MultiModalPromptUpdates",
-    tokenizer: TokenizerLike,
+    tokenizer: TokenizerLike | None,
 ) -> tuple[str, "MultiModalPromptUpdatesApplyResult"]:
     """
     Apply the updates in `mm_prompt_updates` to `prompt`.
@@ -841,7 +861,7 @@ def apply_text_matches(
 def _iter_placeholders(
     prompt: list[int],
     mm_prompt_updates: "MultiModalPromptUpdates",
-    tokenizer: TokenizerLike,
+    tokenizer: TokenizerLike | None,
 ) -> Iterable[PlaceholderFeaturesInfo]:
     """
     Yield each set of placeholder tokens found in `prompt`.
@@ -910,7 +930,7 @@ def _iter_placeholders(
 def find_mm_placeholders(
     prompt: list[int],
     mm_prompt_updates: "MultiModalPromptUpdates",
-    tokenizer: TokenizerLike,
+    tokenizer: TokenizerLike | None,
 ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
     it = _iter_placeholders(prompt, mm_prompt_updates, tokenizer)
     return dict(full_groupby_modality(it))
@@ -931,9 +951,17 @@ class InputProcessingContext:
     model_config: ModelConfig
     """The configuration of the model."""
 
-    tokenizer: TokenizerLike
+    tokenizer: TokenizerLike | None
     """The tokenizer used to tokenize the inputs."""
 
+    def get_tokenizer(self) -> TokenizerLike:
+        if self.tokenizer is None:
+            raise ValueError(
+                "You cannot pass text prompts when `skip_tokenizer_init=True`"
+            )
+
+        return self.tokenizer
+
     @overload
     def get_hf_config(self, /) -> PretrainedConfig: ...
 
@@ -1148,7 +1176,7 @@ class BaseProcessingInfo:
         return self.ctx.model_config.model
 
     def get_tokenizer(self) -> TokenizerLike:
-        return self.ctx.tokenizer
+        return self.ctx.get_tokenizer()
 
     def get_hf_config(self) -> PretrainedConfig:
         return self.ctx.get_hf_config()
@@ -1960,15 +1988,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             for update_idxs in match_result.values()
         ):
             new_text, match_result = self._apply_text_matches(
-                decode_tokens(tokenizer, token_ids),
+                _seq2text(tokenizer, token_ids, use_cache=False),
                 mm_prompt_updates,
             )
 
-            new_token_ids = encode_tokens(
-                tokenizer,
-                new_text,
-                add_special_tokens=False,
-            )
+            new_token_ids = _seq2tokens(tokenizer, new_text, use_cache=False)
 
             matched_updates = defaultdict[str, list[Sequence[ResolvedPromptUpdate]]](list)
             for modality, update_idxs in match_result.items():
@@ -234,9 +234,7 @@ class MultiModalRegistry:
         model_config: "ModelConfig",
         tokenizer: TokenizerLike | None = None,
     ) -> InputProcessingContext:
-        if model_config.skip_tokenizer_init:
-            tokenizer = cast(TokenizerLike, object())
-        elif tokenizer is None:
+        if tokenizer is None and not model_config.skip_tokenizer_init:
             tokenizer = cached_tokenizer_from_config(model_config)
 
         return InputProcessingContext(model_config, tokenizer)