From 6879cd80ae4dba88121db226d5bfbc6a75b072ba Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Mon, 25 Aug 2025 21:31:57 +0800
Subject: [PATCH] [Refactor] Pass `tokenizer` explicitly instead of binding to
 prompt update (#23542)

Signed-off-by: DarkLight1337
---
 tests/multimodal/test_processing.py      |  21 +--
 vllm/model_executor/models/gemma3_mm.py  |   4 +-
 vllm/model_executor/models/gemma3n_mm.py |   4 +-
 vllm/multimodal/processing.py            | 210 +++++++++--------------
 4 files changed, 95 insertions(+), 144 deletions(-)

diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py
index 3bebe0ab403c..6ce5fcfe644b 100644
--- a/tests/multimodal/test_processing.py
+++ b/tests/multimodal/test_processing.py
@@ -243,7 +243,7 @@ def test_find_token_matches(
     mock_tokenizer = cast(AnyTokenizer, object())
 
     prompt_updates = {
-        key: update_type(key, target, []).resolve(mock_tokenizer, 0)
+        key: update_type(key, target, []).resolve(0)
         for key, target in target_by_key.items()
     }
     result = {
@@ -392,7 +392,7 @@ def test_find_text_matches(
     mock_tokenizer = cast(AnyTokenizer, object())
 
     prompt_updates = {
-        key: update_type(key, target, []).resolve(mock_tokenizer, 0)
+        key: update_type(key, target, []).resolve(0)
         for key, target in target_by_key.items()
     }
     result = {
@@ -559,10 +559,8 @@ def test_find_update_text(
     ) in expected_by_update_type_mm_count.items():
         for mm_count, expected in expected_by_mm_count.items():
             mm_prompt_updates = {
-                key: [[
-                    update_type(key, target,
-                                repl_by_key[key]).resolve(mock_tokenizer, i)
-                ] for i in range(mm_count)]
+                key: [[update_type(key, target, repl_by_key[key]).resolve(i)]
+                      for i in range(mm_count)]
                 for key, target in target_by_key.items()
             }
 
@@ -731,10 +729,8 @@ def test_find_update_tokens(
     ) in expected_by_update_type_mm_count.items():
         for mm_count, expected in expected_by_mm_count.items():
             mm_prompt_updates = {
-                key: [[
-                    update_type(key, target,
-                                repl_by_key[key]).resolve(mock_tokenizer, i)
-                ] for i in range(mm_count)]
+                key: [[update_type(key, target, repl_by_key[key]).resolve(i)]
+                      for i in range(mm_count)]
                 for key, target in target_by_key.items()
             }
 
@@ -879,12 +875,11 @@ def test_find_mm_placeholders(
     mock_tokenizer = cast(AnyTokenizer, object())
 
     mm_prompt_updates = {
-        key: [[update_type(key, [], repl).resolve(mock_tokenizer, i)]
-              for i in range(3)]
+        key: [[update_type(key, [], repl).resolve(i)] for i in range(3)]
         for key, repl in repl_by_key.items()
     }
 
-    result = find_mm_placeholders(prompt, mm_prompt_updates)
+    result = find_mm_placeholders(prompt, mm_prompt_updates, mock_tokenizer)
 
     # Only displayed on error
     print("result:", result)
diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
index 44188ee4db63..f3dc7dde46bd 100644
--- a/vllm/model_executor/models/gemma3_mm.py
+++ b/vllm/model_executor/models/gemma3_mm.py
@@ -28,7 +28,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PlaceholderFeaturesInfo,
                                         PromptReplacement, PromptUpdate,
                                         PromptUpdateDetails,
-                                        find_mm_placeholders,
                                         replace_token_matches)
 # yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
@@ -401,7 +400,8 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
             repl_token_ids.extend(repl_toks)
             repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
 
-        repls = find_mm_placeholders(repl_token_ids, mm_prompt_updates)
+        repls = super()._find_mm_placeholders(repl_token_ids,
+                                              mm_prompt_updates)
 
         return {
             modality: [
diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py
index 042c31ba5cc4..d59dde1560ae 100644
--- a/vllm/model_executor/models/gemma3n_mm.py
+++ b/vllm/model_executor/models/gemma3n_mm.py
@@ -35,7 +35,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PlaceholderFeaturesInfo,
                                         PromptReplacement, PromptUpdate,
                                         PromptUpdateDetails,
-                                        find_mm_placeholders,
                                         replace_token_matches)
 # yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
@@ -318,7 +317,8 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
             repl_token_ids.extend(repl_toks)
             repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
 
-        repls = find_mm_placeholders(repl_token_ids, mm_prompt_updates)
+        repls = super()._find_mm_placeholders(repl_token_ids,
+                                              mm_prompt_updates)
 
         return {
             modality: [
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index 878e83add861..8c225e2a3c08 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -44,6 +44,44 @@
 PromptSeq = Union[str, list[int]]
 """A token sequence (list of token IDs) or text."""
 
+@lru_cache(maxsize=2048)
+def _cached_encode(
+    tokenizer: AnyTokenizer,
+    text: str,
+    *,
+    add_special_tokens: Optional[bool] = None,
+) -> list[int]:
+    return encode_tokens(tokenizer,
+                         text,
+                         add_special_tokens=add_special_tokens)
+
+
+@lru_cache(maxsize=2048)
+def _cached_decode(
+    tokenizer: AnyTokenizer,
+    token_ids: tuple[int, ...],
+    *,
+    skip_special_tokens: Optional[bool] = None,
+) -> str:
+    return decode_tokens(tokenizer,
+                         list(token_ids),
+                         skip_special_tokens=skip_special_tokens)
+
+
+def _seq2text(tokenizer: AnyTokenizer, seq: PromptSeq) -> str:
+    if isinstance(seq, str):
+        return seq
+
+    return _cached_decode(tokenizer, tuple(seq))
+
+
+def _seq2tokens(tokenizer: AnyTokenizer, seq: PromptSeq) -> list[int]:
+    if isinstance(seq, str):
+        return _cached_encode(tokenizer, seq, add_special_tokens=False)
+
+    return seq
+
+
 class _GetMatchIndex(Protocol):
 
     def __call__(
@@ -137,7 +175,8 @@ class PromptUpdateDetails(Generic[_S]):
     full: _S
     """The full content."""
 
-    is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] = None
+    is_embed: Optional[Callable[[AnyTokenizer, PromptSeq],
+                                torch.Tensor]] = None
     """
     Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full],
     return a boolean mask of shape `(len(full),)` indicating which positions
@@ -159,11 +198,12 @@
         embed_text: str,
     ) -> "PromptUpdateDetails[_S]":
 
-        def is_embed(full: "_BoundPromptSequence") -> torch.Tensor:
-            embed_token_ids = encode_tokens(full.tokenizer, embed_text)
+        def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor:
+            embed_token_ids = encode_tokens(tokenizer, embed_text)
+            token_ids = _seq2tokens(tokenizer, full)
 
             return torch.isin(
-                torch.tensor(full.token_ids),
+                torch.tensor(token_ids),
                 torch.tensor(embed_token_ids),
             )
 
@@ -174,10 +214,13 @@
         seq: _S,
         embed_token_id: int,
     ) -> "PromptUpdateDetails[_S]":
-        return PromptUpdateDetails(
-            full=seq,
-            is_embed=lambda f: torch.tensor(f.token_ids) == embed_token_id,
-        )
+
+        def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor:
+            token_ids = _seq2tokens(tokenizer, full)
+
+            return torch.tensor(token_ids) == embed_token_id
+
+        return PromptUpdateDetails(full=seq, is_embed=is_embed)
 
 
 PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails]
@@ -230,25 +273,14 @@ class PromptUpdate(ABC):
         """Defines how to update the prompt."""
         raise NotImplementedError
 
-    def _resolve_target(
-        self,
-        tokenizer: AnyTokenizer,
-        item_idx: int,
-    ) -> Union["_BoundPromptSequence", PromptIndex]:
+    def _resolve_target(self, item_idx: int) -> UpdateTarget:
         target = self.target
         if callable(target):
             target = target(item_idx)
 
-        if isinstance(target, PromptIndex):
-            return target
+        return target
 
-        return _BoundPromptSequence.from_seq(tokenizer, target)
-
-    def _resolve_content(
-        self,
-        tokenizer: AnyTokenizer,
-        item_idx: int,
-    ) -> "_BoundPromptContent":
+    def _resolve_content(self, item_idx: int) -> PromptUpdateDetails:
         content = self.content
         if callable(content):
             content = content(item_idx)
@@ -256,17 +288,9 @@ class PromptUpdate(ABC):
         if not isinstance(content, PromptUpdateDetails):
             content = PromptUpdateDetails.from_seq(content)
 
-        bound_full = _BoundPromptSequence.from_seq(tokenizer, content.full)
-        bound_content = _BoundPromptContent(full=bound_full,
-                                            is_embed=content.is_embed)
+        return content
 
-        return bound_content
-
-    def resolve(
-        self,
-        tokenizer: AnyTokenizer,
-        item_idx: int,
-    ) -> "ResolvedPromptUpdate":
+    def resolve(self, item_idx: int) -> "ResolvedPromptUpdate":
         """
         Given the index of the processed item within
         [`modality`][vllm.multimodal.processing.PromptUpdate.modality],
@@ -276,8 +300,8 @@
             modality=self.modality,
             item_idx=item_idx,
             mode=self.mode,
-            target=self._resolve_target(tokenizer, item_idx),
-            content=self._resolve_content(tokenizer, item_idx),
+            target=self._resolve_target(item_idx),
+            content=self._resolve_content(item_idx),
         )
 
 
@@ -424,30 +448,6 @@ class PromptReplacement(PromptUpdate):
         return UpdateMode.REPLACE
 
 
-@lru_cache(maxsize=2048)
-def _cached_encode(
-    tokenizer: AnyTokenizer,
-    text: str,
-    *,
-    add_special_tokens: Optional[bool] = None,
-) -> list[int]:
-    return encode_tokens(tokenizer,
-                         text,
-                         add_special_tokens=add_special_tokens)
-
-
-@lru_cache(maxsize=2048)
-def _cached_decode(
-    tokenizer: AnyTokenizer,
-    token_ids: tuple[int, ...],
-    *,
-    skip_special_tokens: Optional[bool] = None,
-) -> str:
-    return decode_tokens(tokenizer,
-                         list(token_ids),
-                         skip_special_tokens=skip_special_tokens)
-
-
 class _HasModalityAttr(Protocol):
     modality: str
 
@@ -468,59 +468,6 @@ def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
     return full_groupby(values, key=lambda x: x.modality)
 
 
-@dataclass
-class _BoundPromptSequence:
-    """
-    A [`_PromptSeq`][vllm.multimodal.processing.PromptSeq] bound
-    to a tokenizer to automatically
-    convert between token sequence and text representations.
- """ - tokenizer: AnyTokenizer = field(repr=False) - - _text: Optional[str] - _token_ids: Optional[list[int]] - - @staticmethod - def from_seq( - tokenizer: AnyTokenizer, - seq: PromptSeq, - ) -> "_BoundPromptSequence": - return _BoundPromptSequence( - tokenizer=tokenizer, - _text=seq if isinstance(seq, str) else None, - _token_ids=seq if isinstance(seq, list) else None, - ) - - def __post_init__(self) -> None: - if self._text is None and self._token_ids is None: - raise ValueError("At least one of 'text' and 'token_ids' must be " - "specified") - - @property - def text(self) -> str: - if self._text is None: - assert self._token_ids is not None - self._text = _cached_decode(self.tokenizer, tuple(self._token_ids)) - - return self._text - - @property - def token_ids(self) -> list[int]: - if self._token_ids is None: - assert self._text is not None - self._token_ids = _cached_encode(self.tokenizer, - self._text, - add_special_tokens=False) - - return self._token_ids - - -@dataclass -class _BoundPromptContent: - full: _BoundPromptSequence - is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] - - class PromptTargetMatch(NamedTuple): start_idx: int end_idx: int @@ -542,10 +489,10 @@ class ResolvedPromptUpdate: mode: UpdateMode """Defines how to update the prompt.""" - target: Union[_BoundPromptSequence, PromptIndex] + target: UpdateTarget """The token sequence (or text) to update.""" - content: _BoundPromptContent = field(repr=False) + content: PromptUpdateDetails = field(repr=False) """The placeholder tokens that are part of the update.""" def iter_token_matches( @@ -565,8 +512,10 @@ class ResolvedPromptUpdate: return + target_token_ids = _seq2tokens(tokenizer, target) + for match in iter_token_matches(prompt, - target.token_ids, + target_token_ids, start_idx=start_idx): yield PromptTargetMatch(match.start_idx, match.end_idx) @@ -587,7 +536,9 @@ class ResolvedPromptUpdate: return - for match in re.finditer(re.escape(target.text), prompt, + target_text = _seq2text(tokenizer, target) + + for match in re.finditer(re.escape(target_text), prompt, pos=start_idx): yield PromptTargetMatch(match.start(), match.end()) @@ -779,7 +730,7 @@ def _apply_matches( matched_update = mm_prompt_updates[modality][item_idx][ update_idx] - matched_content = matched_update.content + matched_content = matched_update.content.full if mode == UpdateMode.INSERT: end_idx_to_insert = match.end_idx @@ -789,8 +740,10 @@ def _apply_matches( assert_never(mode) out_seqs.append(prompt[prev_end_idx:end_idx_to_insert]) - out_seqs.append(matched_content.full.text if isinstance( - prompt, str) else matched_content.full.token_ids) + out_seqs.append( + _seq2text(tokenizer, matched_content + ) if isinstance(prompt, str) else _seq2tokens( + tokenizer, matched_content)) out_result[modality][item_idx] = update_idx # Exclude overlapping matches @@ -842,6 +795,7 @@ def apply_text_matches( def _iter_placeholders( prompt: list[int], mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, ) -> Iterable[PlaceholderFeaturesInfo]: """ Yield each set of placeholder tokens found in `prompt`. 
@@ -868,7 +822,7 @@
 
         for update in modality_updates[item_idx]:
             content = update.content
-            content_tokens_full = content.full.token_ids
+            content_tokens_full = _seq2tokens(tokenizer, content.full)
             content_len_full = len(content_tokens_full)
             end_idx_full = start_idx + content_len_full
 
@@ -878,7 +832,8 @@
             if prompt[start_idx:end_idx_full] == content_tokens_full:
                 content_is_embed = content.is_embed
                 if content_is_embed is not None:
-                    content_is_embed = content_is_embed(content.full)
+                    content_is_embed = content_is_embed(
+                        tokenizer, content.full)
 
                 yield PlaceholderFeaturesInfo(
                     modality=modality,
@@ -904,8 +859,9 @@
 def find_mm_placeholders(
     prompt: list[int],
     mm_prompt_updates: "MultiModalPromptUpdates",
+    tokenizer: AnyTokenizer,
 ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
-    it = _iter_placeholders(prompt, mm_prompt_updates)
+    it = _iter_placeholders(prompt, mm_prompt_updates, tokenizer)
     return dict(full_groupby_modality(it))
 
 
@@ -1160,12 +1116,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         prompt_updates: Sequence[PromptUpdate],
         mm_item_counts: Mapping[str, int],
     ) -> MultiModalPromptUpdates:
-        tokenizer = self.info.get_tokenizer()
-
         return {
-            modality:
-            [[update.resolve(tokenizer, item_idx) for update in updates]
-             for item_idx in range(mm_item_counts.get(modality, 0))]
+            modality: [[update.resolve(item_idx) for update in updates]
+                       for item_idx in range(mm_item_counts.get(modality, 0))]
             for modality, updates in full_groupby_modality(prompt_updates)
         }
 
@@ -1208,7 +1161,10 @@
         new_token_ids: list[int],
         mm_prompt_updates: MultiModalPromptUpdates,
     ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
-        return find_mm_placeholders(new_token_ids, mm_prompt_updates)
+        tokenizer = self.info.get_tokenizer()
+
+        return find_mm_placeholders(new_token_ids, mm_prompt_updates,
+                                    tokenizer)
 
     def _get_hf_mm_data(
         self,
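
Note (not part of the patch): a minimal sketch of the calling convention this
refactor moves to. `ToyTokenizer` and `ResolvedUpdate` below are hypothetical
stand-ins for `AnyTokenizer` and `ResolvedPromptUpdate`; the point is that
resolved updates stay plain data and the tokenizer is handed to the helper
that needs it, so the `lru_cache` is keyed on the (tokenizer, text) pair
rather than on state bound into every update object.

    from dataclasses import dataclass
    from functools import lru_cache
    from typing import Union

    PromptSeq = Union[str, list[int]]


    class ToyTokenizer:
        """Hypothetical stand-in for AnyTokenizer; encodes chars to code points."""

        def encode(self, text: str) -> list[int]:
            return [ord(c) for c in text]


    @lru_cache(maxsize=2048)
    def _cached_encode(tokenizer: ToyTokenizer, text: str) -> list[int]:
        # Cacheable because the tokenizer arrives as an argument (hashable
        # by identity) instead of living on each resolved update.
        return tokenizer.encode(text)


    def seq2tokens(tokenizer: ToyTokenizer, seq: PromptSeq) -> list[int]:
        # Mirrors _seq2tokens in the patch: encode text on demand, pass
        # token lists through unchanged.
        return _cached_encode(tokenizer, seq) if isinstance(seq, str) else seq


    @dataclass(frozen=True)
    class ResolvedUpdate:
        # Plain data with no tokenizer field; compare the removed
        # _BoundPromptSequence, which carried a tokenizer per instance.
        item_idx: int
        target: PromptSeq


    tokenizer = ToyTokenizer()
    update = ResolvedUpdate(item_idx=0, target="<image>")
    print(seq2tokens(tokenizer, update.target))  # [60, 105, 109, 97, 103, 101, 62]

As in `_apply_matches` and `_iter_placeholders` above, code that needs token
IDs or text now receives the tokenizer at the call site instead of reading it
off the update.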