[VLM] Simplify post-processing of replacement info (#12269)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-01-22 08:48:13 +08:00 committed by GitHub
parent 09ccc9c8f7
commit df76e5af26
10 changed files with 175 additions and 208 deletions

View File

@ -35,7 +35,7 @@ def _test_processing_correctness(
task="auto",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=True,
trust_remote_code=model_info.trust_remote_code,
seed=0,
dtype="float16",
revision=None,

View File

@ -261,7 +261,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code=True),
"Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"), # noqa: E501
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3"),
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3",
trust_remote_code=True),
# [Encoder-decoder]
"MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501
"WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501

View File

@ -7,12 +7,16 @@ import pytest
from vllm.config import ModelConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.processing import (PlaceholderInfo, PromptReplacement,
# yapf conflicts with isort for this block
# yapf: disable
from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
PromptReplacement,
find_mm_placeholders,
find_text_matches, find_token_matches,
iter_token_matches,
replace_text_matches,
replace_token_matches)
# yapf: enable
from vllm.multimodal.profiling import MultiModalProfiler
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
@ -433,19 +437,19 @@ def test_find_replace_tokens(
[1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
{
"pattern_1": [
PlaceholderInfo(
PlaceholderFeaturesInfo(
modality="pattern_1",
item_idx=0,
start_idx=6,
replacement=[32000, 32000],
tokens=[32000, 32000],
),
],
"pattern_4": [
PlaceholderInfo(
PlaceholderFeaturesInfo(
modality="pattern_4",
item_idx=0,
start_idx=3,
replacement=[32000],
tokens=[32000],
),
],
}
@ -455,25 +459,25 @@ def test_find_replace_tokens(
[1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
{
"pattern_1": [
PlaceholderInfo(
PlaceholderFeaturesInfo(
modality="pattern_1",
item_idx=0,
start_idx=1,
replacement=[32000, 32000],
tokens=[32000, 32000],
),
PlaceholderInfo(
PlaceholderFeaturesInfo(
modality="pattern_1",
item_idx=1,
start_idx=5,
replacement=[32000, 32000],
tokens=[32000, 32000],
),
],
"pattern_3": [
PlaceholderInfo(
PlaceholderFeaturesInfo(
modality="pattern_3",
item_idx=0,
start_idx=7,
replacement=[1550, 918, 1550],
tokens=[1550, 918, 1550],
),
],
# No match for pattern_4 as it has lower priority than pattern_1
@ -483,33 +487,33 @@ def test_find_replace_tokens(
[1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
{
"pattern_1": [
PlaceholderInfo(
PlaceholderFeaturesInfo(
modality="pattern_1",
item_idx=0,
start_idx=1,
replacement=[32000, 32000],
tokens=[32000, 32000],
),
PlaceholderInfo(
PlaceholderFeaturesInfo(
modality="pattern_1",
item_idx=1,
start_idx=3,
replacement=[32000, 32000],
tokens=[32000, 32000],
),
],
"pattern_4": [
PlaceholderInfo(
PlaceholderFeaturesInfo(
modality="pattern_4",
item_idx=0,
start_idx=5,
replacement=[32000],
tokens=[32000],
),
],
"pattern_3": [
PlaceholderInfo(
PlaceholderFeaturesInfo(
modality="pattern_3",
item_idx=0,
start_idx=6,
replacement=[1550, 918, 1550],
tokens=[1550, 918, 1550],
),
],
}

View File

@ -342,13 +342,7 @@ class AriaProcessingInfo(BaseProcessingInfo):
return self.get_hf_config().vision_config
def get_hf_processor(self):
processor = self.ctx.get_hf_processor(AriaProcessor)
# Patch for https://github.com/huggingface/transformers/issues/35768
processor.tokenizer.image_token = "<|img|>"
processor.image_token = "<|img|>"
return processor
return self.ctx.get_hf_processor(AriaProcessor)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
@ -381,7 +375,7 @@ class AriaDummyInputsBuilder(BaseDummyInputsBuilder[AriaProcessingInfo]):
}
hf_processor = self.info.get_hf_processor()
image_token: str = hf_processor.image_token # type: ignore
image_token: str = hf_processor.tokenizer.image_token # type: ignore
return ProcessorInputs(
prompt_text=image_token * num_images,

View File

@ -14,12 +14,12 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement)
BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@ -481,30 +481,13 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
PromptReplacement(
modality="image",
target="</s>",
replacement="<image>" * num_image_tokens + "</s>",
replacement=PromptReplacementDetails(
full="<image>" * num_image_tokens + "</s>",
features="<image>" * num_image_tokens,
),
)
]
def apply(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputs:
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
# Only <image> tokens should be considered as placeholders,
# so we ignore the trailing bos_token
result["mm_placeholders"] = {
modality: [
PlaceholderRange(offset=p["offset"], length=p["length"] - 1)
for p in ps
]
for modality, ps in result["mm_placeholders"].items()
}
return result
@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor,
info=Blip2ProcessingInfo,

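To illustrate the new mechanism (a minimal sketch, not part of this diff; the token count is assumed), the placeholder reported for each image now covers only the image feature tokens, which is exactly what the deleted apply() override used to enforce by trimming the trailing "</s>" from every placeholder length:

from vllm.multimodal.processing import (PromptReplacement,
                                        PromptReplacementDetails)

num_image_tokens = 4  # assumed value; the real count comes from the BLIP-2 config

repl = PromptReplacement(
    modality="image",
    target="</s>",
    replacement=PromptReplacementDetails(
        # "full" is what gets inserted into the prompt...
        full="<image>" * num_image_tokens + "</s>",
        # ...while "features" is what gets reported as the placeholder range,
        # so the trailing "</s>" is excluded without any post-processing.
        features="<image>" * num_image_tokens,
    ),
)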
View File

@ -28,12 +28,12 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement)
BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@ -141,39 +141,23 @@ class ChameleonMultiModalProcessor(
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_tokens = processor.image_token * self.info.get_num_image_tokens()
return [
PromptReplacement(
modality="image",
target="<image>",
replacement="".join([
processor.image_start_token,
processor.image_token * self.info.get_num_image_tokens(),
processor.image_end_token,
]),
replacement=PromptReplacementDetails(
full="".join([
processor.image_start_token,
image_tokens,
processor.image_end_token,
]),
features=image_tokens,
),
)
]
def apply(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputs:
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
# Only <image> tokens should be considered as placeholders,
# so we ignore the image_start_token and image_end_token
result["mm_placeholders"] = {
modality: [
PlaceholderRange(offset=p["offset"] + 1,
length=p["length"] - 2) for p in ps
]
for modality, ps in result["mm_placeholders"].items()
}
return result
class ChameleonLayerNorm(nn.LayerNorm):

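For Chameleon the feature tokens sit in the middle of the full replacement rather than at its start. A hedged sketch (the token strings and count below are placeholders, not the real processor attributes) of why the offset+1 / length-2 adjustment could be dropped:

from vllm.multimodal.processing import (PromptReplacement,
                                        PromptReplacementDetails)

# Hypothetical stand-ins for processor.image_start_token / image_end_token.
image_start_token = "<img_start>"
image_end_token = "<img_end>"
image_tokens = "<image>" * 3  # assumed count

repl = PromptReplacement(
    modality="image",
    target="<image>",
    replacement=PromptReplacementDetails(
        full=image_start_token + image_tokens + image_end_token,
        features=image_tokens,
    ),
)
# The feature tokens are located inside the full replacement, so the reported
# placeholder starts after image_start_token and stops before image_end_token.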
View File

@ -16,7 +16,7 @@
""" PyTorch Fuyu model."""
import math
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
TypedDict)
import torch
import torch.nn as nn
@ -30,13 +30,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement)
BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@ -215,9 +215,13 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
image_width=image_size.width,
image_height=image_size.height,
)
image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
[_NEWLINE_TOKEN_ID]) * nrows
return (([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows +
[bos_token_id])
return PromptReplacementDetails(
full=image_tokens + [bos_token_id],
features=image_tokens,
)
return [
PromptReplacement(
@ -227,26 +231,6 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
)
]
def apply(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputs:
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
# Only |SPEAKER| (image) tokens should be considered as placeholders,
# so we ignore the trailing bos_token_id
result["mm_placeholders"] = {
modality: [
PlaceholderRange(offset=p["offset"], length=p["length"] - 1)
for p in ps
]
for modality, ps in result["mm_placeholders"].items()
}
return result
@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor,
info=FuyuProcessingInfo,

View File

@ -30,15 +30,19 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo,
BoundPromptReplacement,
PlaceholderInfo, PromptReplacement)
PlaceholderFeaturesInfo,
PromptReplacement,
PromptReplacementDetails)
# yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
@ -437,7 +441,12 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
processor=hf_processor,
)
return [_IMAGE_TOKEN_ID] * num_image_tokens + [bos_token_id]
image_tokens = [_IMAGE_TOKEN_ID] * num_image_tokens
return PromptReplacementDetails(
full=image_tokens + [bos_token_id],
features=image_tokens,
)
num_images = mm_items.get_count("image", strict=False)
@ -454,7 +463,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
token_ids: list[int],
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, Mapping[str, list[PlaceholderInfo]]]:
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
token_ids, text, placeholders = super()._apply_prompt_replacements(
token_ids=token_ids,
mm_prompt_repls=mm_prompt_repls,
@ -467,11 +476,11 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
token_ids = [token_ids[0], *token_ids[2:]]
placeholders = {
modality: [
PlaceholderInfo(
PlaceholderFeaturesInfo(
modality=p.modality,
item_idx=p.item_idx,
start_idx=p.start_idx - 1,
replacement=p.replacement,
tokens=p.tokens,
) for p in ps
]
for modality, ps in placeholders.items()
@ -479,26 +488,6 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
return token_ids, text, placeholders
def apply(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputs:
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
# Only <|image|> tokens should be considered as placeholders,
# so we ignore the trailing bos_token_id
result["mm_placeholders"] = {
modality: [
PlaceholderRange(offset=p["offset"], length=p["length"] - 1)
for p in ps
]
for modality, ps in result["mm_placeholders"].items()
}
return result
@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor,
info=Phi3VProcessingInfo,

View File

@ -36,13 +36,13 @@ from vllm.config import VllmConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement)
BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@ -216,11 +216,16 @@ class Qwen2AudioMultiModalProcessor(
f"The audio {audio} (len={len(audio)}) is too short "
"to be represented inside the model")
return "".join([
audio_bos_token,
audio_token * num_placeholders,
audio_eos_token,
])
audio_tokens = audio_token * num_placeholders
return PromptReplacementDetails(
full="".join([
audio_bos_token,
audio_tokens,
audio_eos_token,
]),
features=audio_tokens,
)
return [
PromptReplacement(
@ -240,26 +245,6 @@ class Qwen2AudioMultiModalProcessor(
# tokens than the number of audio items)
return not hasattr(self.info.get_hf_processor(), "audio_token")
def apply(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputs:
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
# Only <|AUDIO|> tokens should be considered as placeholders,
# so we ignore the audio_bos_token and audio_eos_token
result["mm_placeholders"] = {
modality: [
PlaceholderRange(offset=p["offset"] + 1,
length=p["length"] - 2) for p in ps
]
for modality, ps in result["mm_placeholders"].items()
}
return result
@MULTIMODAL_REGISTRY.register_processor(
Qwen2AudioMultiModalProcessor,

View File

@ -1,7 +1,8 @@
import re
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
Sequence)
from dataclasses import dataclass, field
from functools import lru_cache
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
@ -31,6 +32,24 @@ _S = TypeVar("_S", str, list[int])
_PromptSeq = Union[str, list[int]]
@dataclass
class PromptReplacementDetails:
full: _PromptSeq
"""The full replacement."""
features: _PromptSeq
"""
The part of the replacement that corresponds to placeholder feature tokens.
"""
@staticmethod
def from_seq(seq: _PromptSeq):
return PromptReplacementDetails(full=seq, features=seq)
_PromptRepl = Union[_PromptSeq, PromptReplacementDetails]
@dataclass
class PromptReplacement:
"""
@ -43,8 +62,8 @@ class PromptReplacement:
target: _PromptSeq
"""The token sequence (or text) to find and replace."""
replacement: Union[Callable[[int], _PromptSeq],
_PromptSeq] = field(repr=False)
replacement: Union[Callable[[int], _PromptRepl],
_PromptRepl] = field(repr=False)
"""
Given the index of the processed item within :attr:`modality`,
output the replacement token sequence (or text).
@ -112,6 +131,14 @@ class _BoundPromptSequence:
_text: Optional[str]
_token_ids: Optional[list[int]]
@staticmethod
def from_seq(tokenizer: AnyTokenizer, seq: _PromptSeq):
return _BoundPromptSequence(
tokenizer=tokenizer,
_text=seq if isinstance(seq, str) else None,
_token_ids=seq if isinstance(seq, list) else None,
)
def __post_init__(self) -> None:
if self._text is None and self._token_ids is None:
raise ValueError("At least one of 'text' and 'token_ids' must be "
@ -134,6 +161,12 @@ class _BoundPromptSequence:
return self._token_ids
@dataclass
class _BoundPromptReplacementGroup:
full: _BoundPromptSequence
features: _BoundPromptSequence
@dataclass
class BoundPromptReplacement:
"""
@ -145,24 +178,18 @@ class BoundPromptReplacement:
modality: str
_target: _PromptSeq
_replacement: Union[Callable[[int], _PromptSeq],
_PromptSeq] = field(repr=False)
_replacement: Union[Callable[[int], _PromptRepl],
_PromptRepl] = field(repr=False)
def __post_init__(self) -> None:
self._replacement_cache = dict[int, _BoundPromptSequence]()
self._replacement_cache = dict[int, _BoundPromptReplacementGroup]()
@property
def target(self) -> _BoundPromptSequence:
"""The token sequence (or text) to find and replace."""
target = self._target
return _BoundPromptSequence.from_seq(self.tokenizer, self._target)
return _BoundPromptSequence(
tokenizer=self.tokenizer,
_text=target if isinstance(target, str) else None,
_token_ids=target if isinstance(target, list) else None,
)
def get_replacement(self, item_idx: int) -> _BoundPromptSequence:
def get_replacement(self, item_idx: int) -> _BoundPromptReplacementGroup:
"""
Given the index of the processed item within :attr:`modality`,
output the replacement token sequence (or text).
@ -177,10 +204,16 @@ class BoundPromptReplacement:
else:
cache_key = None
bound_replacement = _BoundPromptSequence(
tokenizer=self.tokenizer,
_text=replacement if isinstance(replacement, str) else None,
_token_ids=replacement if isinstance(replacement, list) else None,
if not isinstance(replacement, PromptReplacementDetails):
replacement = PromptReplacementDetails.from_seq(replacement)
bound_full = _BoundPromptSequence.from_seq(self.tokenizer,
replacement.full)
bound_features = _BoundPromptSequence.from_seq(self.tokenizer,
replacement.features)
bound_replacement = _BoundPromptReplacementGroup(
full=bound_full,
features=bound_features,
)
if cache_key is not None:
@ -197,7 +230,7 @@ class _TokenMatch(NamedTuple):
def iter_token_matches(
token_ids: list[int],
match_ids: list[int],
) -> Iterable[_TokenMatch]:
) -> Generator[_TokenMatch]:
"""
Yield each occurrence of :code:`match_ids` in :code:`token_ids`.
@ -272,15 +305,15 @@ class _PromptReplacementTextMatch(_PromptReplacementMatch):
@dataclass
class PlaceholderInfo:
class PlaceholderFeaturesInfo:
modality: str
item_idx: int
start_idx: int
replacement: list[int]
tokens: list[int]
@property
def length(self) -> int:
return len(self.replacement)
return len(self.tokens)
def to_range(self) -> PlaceholderRange:
return PlaceholderRange(
@ -362,10 +395,10 @@ def _replace_matches(
replacement = repl_info.get_replacement(item_idx)
if isinstance(prompt, str):
repl_seq = replacement.text
repl_seq = replacement.full.text
out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq)
else:
repl_seq = replacement.token_ids
repl_seq = replacement.full.token_ids
out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq)
prev_end_idx = end_idx
@ -408,7 +441,7 @@ def _iter_placeholders(
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
prompt: list[int],
mm_item_counts: Mapping[str, int],
) -> Iterable[PlaceholderInfo]:
) -> Iterable[PlaceholderFeaturesInfo]:
"""
Yield each set of placeholder tokens found in :code:`prompt`.
@ -432,23 +465,33 @@ def _iter_placeholders(
for repl_info in modality_repls:
replacement = repl_info.get_replacement(item_idx)
repl_tokens = replacement.token_ids
repl_len = len(repl_tokens)
end_idx = start_idx + repl_len
repl_tokens_full = replacement.full.token_ids
repl_len_full = len(repl_tokens_full)
end_idx_full = start_idx + repl_len_full
if repl_len == 0 or end_idx > prompt_len:
if repl_len_full == 0 or end_idx_full > prompt_len:
continue
if prompt[start_idx:end_idx] == repl_tokens:
yield PlaceholderInfo(
modality=modality,
item_idx=item_idx,
start_idx=start_idx,
replacement=repl_tokens,
)
if prompt[start_idx:end_idx_full] == repl_tokens_full:
repl_tokens_feat = replacement.features.token_ids
try:
match = next(
iter_token_matches(repl_tokens_full,
repl_tokens_feat))
yield PlaceholderFeaturesInfo(
modality=modality,
item_idx=item_idx,
start_idx=start_idx + match.start_idx,
tokens=repl_tokens_feat,
)
except StopIteration:
raise AssertionError(
f"{repl_tokens_feat=} should be a "
f"subsequence of {repl_tokens_full=}") from None
# Exclude overlapping matches
start_idx = end_idx
start_idx = end_idx_full
item_idx_by_modality[modality] += 1
found = True
break
@ -464,7 +507,7 @@ def find_mm_placeholders(
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
prompt: list[int],
mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[PlaceholderInfo]]:
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
it = _iter_placeholders(mm_prompt_repls, prompt, mm_item_counts)
return dict(full_groupby_modality(it))
@ -679,7 +722,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
new_token_ids: list[int],
mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[PlaceholderInfo]]:
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
return find_mm_placeholders(mm_prompt_repls, new_token_ids,
mm_item_counts)
@ -948,7 +991,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
token_ids: list[int],
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, Mapping[str, list[PlaceholderInfo]]]:
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
tokenizer = self.info.get_tokenizer()
mm_token_matches = {
@ -1037,7 +1080,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def _validate_mm_placeholders(
self,
mm_placeholders: Mapping[str, list[PlaceholderInfo]],
mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
mm_item_counts: Mapping[str, int],
*,
allow_missing: bool = False,
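
Taken together, a hedged end-to-end sketch of the new flow (token ids are invented for illustration): a bare sequence passed as a replacement is wrapped via PromptReplacementDetails.from_seq, while an explicit full/features pair makes the placeholder search report only the feature tokens.

from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
                                        PromptReplacementDetails,
                                        iter_token_matches)

IMG, BOS = 32000, 1  # invented token ids

# Backward compatibility: a bare sequence becomes full == features.
plain = PromptReplacementDetails.from_seq([IMG, IMG])
assert plain.full == plain.features == [IMG, IMG]

# Explicit details: only "features" is reported as the placeholder.
details = PromptReplacementDetails(full=[IMG, IMG, BOS], features=[IMG, IMG])

# _iter_placeholders locates the feature tokens inside the full replacement
# with iter_token_matches and records their position and length:
match = next(iter_token_matches(details.full, details.features))
placeholder = PlaceholderFeaturesInfo(
    modality="image",
    item_idx=0,
    start_idx=match.start_idx,  # offset of the features within the replacement
    tokens=details.features,
)
assert placeholder.length == 2  # the trailing BOS token is excluded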