[Refactor] Pass tokenizer explicitly instead of binding to prompt update (#23542)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-08-25 21:31:57 +08:00 committed by GitHub
parent e269be2ba2
commit 6879cd80ae
4 changed files with 95 additions and 144 deletions
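In short, `PromptUpdate.resolve()` no longer takes a tokenizer, and helpers such as `find_mm_placeholders()` receive the tokenizer as an explicit argument instead. A minimal before/after sketch of the call pattern (the `update`, `prompt_token_ids`, `mm_prompt_updates`, and `tokenizer` values are illustrative placeholders, not part of this diff):

    from vllm.multimodal.processing import find_mm_placeholders

    # Before this commit: the tokenizer was bound into each resolved update.
    resolved = update.resolve(tokenizer, 0)  # PromptUpdate.resolve(tokenizer, item_idx)
    placeholders = find_mm_placeholders(prompt_token_ids, mm_prompt_updates)

    # After this commit: resolution is tokenizer-free; the tokenizer is passed
    # only where token/text conversion actually happens.
    resolved = update.resolve(0)  # PromptUpdate.resolve(item_idx)
    placeholders = find_mm_placeholders(prompt_token_ids, mm_prompt_updates, tokenizer)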


@@ -243,7 +243,7 @@ def test_find_token_matches(
     mock_tokenizer = cast(AnyTokenizer, object())

     prompt_updates = {
-        key: update_type(key, target, []).resolve(mock_tokenizer, 0)
+        key: update_type(key, target, []).resolve(0)
         for key, target in target_by_key.items()
     }
     result = {
@@ -392,7 +392,7 @@ def test_find_text_matches(
     mock_tokenizer = cast(AnyTokenizer, object())

     prompt_updates = {
-        key: update_type(key, target, []).resolve(mock_tokenizer, 0)
+        key: update_type(key, target, []).resolve(0)
         for key, target in target_by_key.items()
     }
     result = {
@@ -559,10 +559,8 @@ def test_find_update_text(
     ) in expected_by_update_type_mm_count.items():
         for mm_count, expected in expected_by_mm_count.items():
             mm_prompt_updates = {
-                key: [[
-                    update_type(key, target,
-                                repl_by_key[key]).resolve(mock_tokenizer, i)
-                ] for i in range(mm_count)]
+                key: [[update_type(key, target, repl_by_key[key]).resolve(i)]
+                     for i in range(mm_count)]
                 for key, target in target_by_key.items()
             }
@@ -731,10 +729,8 @@ def test_find_update_tokens(
     ) in expected_by_update_type_mm_count.items():
         for mm_count, expected in expected_by_mm_count.items():
             mm_prompt_updates = {
-                key: [[
-                    update_type(key, target,
-                                repl_by_key[key]).resolve(mock_tokenizer, i)
-                ] for i in range(mm_count)]
+                key: [[update_type(key, target, repl_by_key[key]).resolve(i)]
+                     for i in range(mm_count)]
                 for key, target in target_by_key.items()
             }
@@ -879,12 +875,11 @@ def test_find_mm_placeholders(
     mock_tokenizer = cast(AnyTokenizer, object())

     mm_prompt_updates = {
-        key: [[update_type(key, [], repl).resolve(mock_tokenizer, i)]
-             for i in range(3)]
+        key: [[update_type(key, [], repl).resolve(i)] for i in range(3)]
         for key, repl in repl_by_key.items()
     }

-    result = find_mm_placeholders(prompt, mm_prompt_updates)
+    result = find_mm_placeholders(prompt, mm_prompt_updates, mock_tokenizer)

     # Only displayed on error
     print("result:", result)


@@ -28,7 +28,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PlaceholderFeaturesInfo,
                                         PromptReplacement, PromptUpdate,
                                         PromptUpdateDetails,
-                                        find_mm_placeholders,
                                         replace_token_matches)
 # yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
@@ -401,7 +400,8 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
             repl_token_ids.extend(repl_toks)
             repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))

-        repls = find_mm_placeholders(repl_token_ids, mm_prompt_updates)
+        repls = super()._find_mm_placeholders(repl_token_ids,
+                                              mm_prompt_updates)

         return {
             modality: [


@@ -35,7 +35,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PlaceholderFeaturesInfo,
                                         PromptReplacement, PromptUpdate,
                                         PromptUpdateDetails,
-                                        find_mm_placeholders,
                                         replace_token_matches)
 # yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
@@ -318,7 +317,8 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
             repl_token_ids.extend(repl_toks)
             repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))

-        repls = find_mm_placeholders(repl_token_ids, mm_prompt_updates)
+        repls = super()._find_mm_placeholders(repl_token_ids,
+                                              mm_prompt_updates)

         return {
             modality: [


@@ -44,6 +44,44 @@ PromptSeq = Union[str, list[int]]
 """A token sequence (list of token IDs) or text."""


+@lru_cache(maxsize=2048)
+def _cached_encode(
+    tokenizer: AnyTokenizer,
+    text: str,
+    *,
+    add_special_tokens: Optional[bool] = None,
+) -> list[int]:
+    return encode_tokens(tokenizer,
+                         text,
+                         add_special_tokens=add_special_tokens)
+
+
+@lru_cache(maxsize=2048)
+def _cached_decode(
+    tokenizer: AnyTokenizer,
+    token_ids: tuple[int, ...],
+    *,
+    skip_special_tokens: Optional[bool] = None,
+) -> str:
+    return decode_tokens(tokenizer,
+                         list(token_ids),
+                         skip_special_tokens=skip_special_tokens)
+
+
+def _seq2text(tokenizer: AnyTokenizer, seq: PromptSeq) -> str:
+    if isinstance(seq, str):
+        return seq
+
+    return _cached_decode(tokenizer, tuple(seq))
+
+
+def _seq2tokens(tokenizer: AnyTokenizer, seq: PromptSeq) -> list[int]:
+    if isinstance(seq, str):
+        return _cached_encode(tokenizer, seq, add_special_tokens=False)
+
+    return seq
+
+
 class _GetMatchIndex(Protocol):

     def __call__(
@@ -137,7 +175,8 @@ class PromptUpdateDetails(Generic[_S]):
     full: _S
     """The full content."""

-    is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] = None
+    is_embed: Optional[Callable[[AnyTokenizer, PromptSeq],
+                                torch.Tensor]] = None
     """
     Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full],
     return a boolean mask of shape `(len(full),)` indicating which positions
@@ -159,11 +198,12 @@ class PromptUpdateDetails(Generic[_S]):
         embed_text: str,
     ) -> "PromptUpdateDetails[_S]":

-        def is_embed(full: "_BoundPromptSequence") -> torch.Tensor:
-            embed_token_ids = encode_tokens(full.tokenizer, embed_text)
+        def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor:
+            embed_token_ids = encode_tokens(tokenizer, embed_text)
+            token_ids = _seq2tokens(tokenizer, full)

             return torch.isin(
-                torch.tensor(full.token_ids),
+                torch.tensor(token_ids),
                 torch.tensor(embed_token_ids),
             )
@@ -174,10 +214,13 @@ class PromptUpdateDetails(Generic[_S]):
         seq: _S,
         embed_token_id: int,
     ) -> "PromptUpdateDetails[_S]":
-        return PromptUpdateDetails(
-            full=seq,
-            is_embed=lambda f: torch.tensor(f.token_ids) == embed_token_id,
-        )
+
+        def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor:
+            token_ids = _seq2tokens(tokenizer, full)
+
+            return torch.tensor(token_ids) == embed_token_id
+
+        return PromptUpdateDetails(full=seq, is_embed=is_embed)


 PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails]
@@ -230,25 +273,14 @@ class PromptUpdate(ABC):
         """Defines how to update the prompt."""
         raise NotImplementedError

-    def _resolve_target(
-        self,
-        tokenizer: AnyTokenizer,
-        item_idx: int,
-    ) -> Union["_BoundPromptSequence", PromptIndex]:
+    def _resolve_target(self, item_idx: int) -> UpdateTarget:
         target = self.target
         if callable(target):
             target = target(item_idx)

-        if isinstance(target, PromptIndex):
-            return target
-
-        return _BoundPromptSequence.from_seq(tokenizer, target)
+        return target

-    def _resolve_content(
-        self,
-        tokenizer: AnyTokenizer,
-        item_idx: int,
-    ) -> "_BoundPromptContent":
+    def _resolve_content(self, item_idx: int) -> PromptUpdateDetails:
         content = self.content
         if callable(content):
             content = content(item_idx)
@@ -256,17 +288,9 @@ class PromptUpdate(ABC):
         if not isinstance(content, PromptUpdateDetails):
             content = PromptUpdateDetails.from_seq(content)

-        bound_full = _BoundPromptSequence.from_seq(tokenizer, content.full)
-        bound_content = _BoundPromptContent(full=bound_full,
-                                            is_embed=content.is_embed)
-        return bound_content
+        return content

-    def resolve(
-        self,
-        tokenizer: AnyTokenizer,
-        item_idx: int,
-    ) -> "ResolvedPromptUpdate":
+    def resolve(self, item_idx: int) -> "ResolvedPromptUpdate":
         """
         Given the index of the processed item within
         [`modality`][vllm.multimodal.processing.PromptUpdate.modality],
@@ -276,8 +300,8 @@ class PromptUpdate(ABC):
             modality=self.modality,
             item_idx=item_idx,
             mode=self.mode,
-            target=self._resolve_target(tokenizer, item_idx),
-            content=self._resolve_content(tokenizer, item_idx),
+            target=self._resolve_target(item_idx),
+            content=self._resolve_content(item_idx),
         )
@@ -424,30 +448,6 @@ class PromptReplacement(PromptUpdate):
         return UpdateMode.REPLACE


-@lru_cache(maxsize=2048)
-def _cached_encode(
-    tokenizer: AnyTokenizer,
-    text: str,
-    *,
-    add_special_tokens: Optional[bool] = None,
-) -> list[int]:
-    return encode_tokens(tokenizer,
-                         text,
-                         add_special_tokens=add_special_tokens)
-
-
-@lru_cache(maxsize=2048)
-def _cached_decode(
-    tokenizer: AnyTokenizer,
-    token_ids: tuple[int, ...],
-    *,
-    skip_special_tokens: Optional[bool] = None,
-) -> str:
-    return decode_tokens(tokenizer,
-                         list(token_ids),
-                         skip_special_tokens=skip_special_tokens)
-
-
 class _HasModalityAttr(Protocol):
     modality: str
@@ -468,59 +468,6 @@ def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
     return full_groupby(values, key=lambda x: x.modality)


-@dataclass
-class _BoundPromptSequence:
-    """
-    A [`_PromptSeq`][vllm.multimodal.processing.PromptSeq] bound
-    to a tokenizer to automatically
-    convert between token sequence and text representations.
-    """
-    tokenizer: AnyTokenizer = field(repr=False)
-
-    _text: Optional[str]
-    _token_ids: Optional[list[int]]
-
-    @staticmethod
-    def from_seq(
-        tokenizer: AnyTokenizer,
-        seq: PromptSeq,
-    ) -> "_BoundPromptSequence":
-        return _BoundPromptSequence(
-            tokenizer=tokenizer,
-            _text=seq if isinstance(seq, str) else None,
-            _token_ids=seq if isinstance(seq, list) else None,
-        )
-
-    def __post_init__(self) -> None:
-        if self._text is None and self._token_ids is None:
-            raise ValueError("At least one of 'text' and 'token_ids' must be "
-                             "specified")
-
-    @property
-    def text(self) -> str:
-        if self._text is None:
-            assert self._token_ids is not None
-            self._text = _cached_decode(self.tokenizer, tuple(self._token_ids))
-
-        return self._text
-
-    @property
-    def token_ids(self) -> list[int]:
-        if self._token_ids is None:
-            assert self._text is not None
-            self._token_ids = _cached_encode(self.tokenizer,
-                                             self._text,
-                                             add_special_tokens=False)
-
-        return self._token_ids
-
-
-@dataclass
-class _BoundPromptContent:
-    full: _BoundPromptSequence
-    is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]]
-
-
 class PromptTargetMatch(NamedTuple):
     start_idx: int
     end_idx: int
@@ -542,10 +489,10 @@ class ResolvedPromptUpdate:
     mode: UpdateMode
     """Defines how to update the prompt."""

-    target: Union[_BoundPromptSequence, PromptIndex]
+    target: UpdateTarget
     """The token sequence (or text) to update."""

-    content: _BoundPromptContent = field(repr=False)
+    content: PromptUpdateDetails = field(repr=False)
     """The placeholder tokens that are part of the update."""

     def iter_token_matches(
@@ -565,8 +512,10 @@ class ResolvedPromptUpdate:
             return

+        target_token_ids = _seq2tokens(tokenizer, target)
+
         for match in iter_token_matches(prompt,
-                                        target.token_ids,
+                                        target_token_ids,
                                         start_idx=start_idx):
             yield PromptTargetMatch(match.start_idx, match.end_idx)
@@ -587,7 +536,9 @@ class ResolvedPromptUpdate:
             return

-        for match in re.finditer(re.escape(target.text), prompt,
+        target_text = _seq2text(tokenizer, target)
+
+        for match in re.finditer(re.escape(target_text), prompt,
                                  pos=start_idx):
             yield PromptTargetMatch(match.start(), match.end())
@@ -779,7 +730,7 @@ def _apply_matches(

             matched_update = mm_prompt_updates[modality][item_idx][
                 update_idx]
-            matched_content = matched_update.content
+            matched_content = matched_update.content.full

             if mode == UpdateMode.INSERT:
                 end_idx_to_insert = match.end_idx
@@ -789,8 +740,10 @@ def _apply_matches(
                 assert_never(mode)

             out_seqs.append(prompt[prev_end_idx:end_idx_to_insert])
-            out_seqs.append(matched_content.full.text if isinstance(
-                prompt, str) else matched_content.full.token_ids)
+            out_seqs.append(
+                _seq2text(tokenizer, matched_content
+                          ) if isinstance(prompt, str) else _seq2tokens(
+                              tokenizer, matched_content))

             out_result[modality][item_idx] = update_idx
             # Exclude overlapping matches
@@ -842,6 +795,7 @@ def apply_text_matches(
 def _iter_placeholders(
     prompt: list[int],
     mm_prompt_updates: "MultiModalPromptUpdates",
+    tokenizer: AnyTokenizer,
 ) -> Iterable[PlaceholderFeaturesInfo]:
     """
     Yield each set of placeholder tokens found in `prompt`.
@@ -868,7 +822,7 @@ def _iter_placeholders(
             for update in modality_updates[item_idx]:
                 content = update.content

-                content_tokens_full = content.full.token_ids
+                content_tokens_full = _seq2tokens(tokenizer, content.full)
                 content_len_full = len(content_tokens_full)
                 end_idx_full = start_idx + content_len_full
@@ -878,7 +832,8 @@ def _iter_placeholders(
                 if prompt[start_idx:end_idx_full] == content_tokens_full:
                     content_is_embed = content.is_embed
                     if content_is_embed is not None:
-                        content_is_embed = content_is_embed(content.full)
+                        content_is_embed = content_is_embed(
+                            tokenizer, content.full)

                     yield PlaceholderFeaturesInfo(
                         modality=modality,
@@ -904,8 +859,9 @@ def _iter_placeholders(
 def find_mm_placeholders(
     prompt: list[int],
     mm_prompt_updates: "MultiModalPromptUpdates",
+    tokenizer: AnyTokenizer,
 ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
-    it = _iter_placeholders(prompt, mm_prompt_updates)
+    it = _iter_placeholders(prompt, mm_prompt_updates, tokenizer)
     return dict(full_groupby_modality(it))
@@ -1160,12 +1116,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         prompt_updates: Sequence[PromptUpdate],
         mm_item_counts: Mapping[str, int],
     ) -> MultiModalPromptUpdates:
-        tokenizer = self.info.get_tokenizer()
-
         return {
-            modality:
-            [[update.resolve(tokenizer, item_idx) for update in updates]
-             for item_idx in range(mm_item_counts.get(modality, 0))]
+            modality: [[update.resolve(item_idx) for update in updates]
+                       for item_idx in range(mm_item_counts.get(modality, 0))]
             for modality, updates in full_groupby_modality(prompt_updates)
         }
@@ -1208,7 +1161,10 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         new_token_ids: list[int],
         mm_prompt_updates: MultiModalPromptUpdates,
     ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
-        return find_mm_placeholders(new_token_ids, mm_prompt_updates)
+        tokenizer = self.info.get_tokenizer()
+
+        return find_mm_placeholders(new_token_ids, mm_prompt_updates,
+                                    tokenizer)

     def _get_hf_mm_data(
         self,