[VLM] Simplify post-processing of replacement info (#12269)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-01-22 08:48:13 +08:00 committed by GitHub
parent 09ccc9c8f7
commit df76e5af26
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 175 additions and 208 deletions

View File

@@ -35,7 +35,7 @@ def _test_processing_correctness(
         task="auto",
         tokenizer=model_id,
         tokenizer_mode="auto",
-        trust_remote_code=True,
+        trust_remote_code=model_info.trust_remote_code,
         seed=0,
         dtype="float16",
         revision=None,

View File

@@ -261,7 +261,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
                                                         trust_remote_code=True),
     "Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"),  # noqa: E501
     "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"),  # noqa: E501
-    "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3"),
+    "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3",
+                                     trust_remote_code=True),
     # [Encoder-decoder]
     "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"),  # noqa: E501
     "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"),  # noqa: E501

View File

@@ -7,12 +7,16 @@ import pytest
 from vllm.config import ModelConfig
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.processing import (PlaceholderInfo, PromptReplacement,
+# yapf conflicts with isort for this block
+# yapf: disable
+from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
+                                        PromptReplacement,
                                         find_mm_placeholders,
                                         find_text_matches, find_token_matches,
                                         iter_token_matches,
                                         replace_text_matches,
                                         replace_token_matches)
+# yapf: enable
 from vllm.multimodal.profiling import MultiModalProfiler
 from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -433,19 +437,19 @@ def test_find_replace_tokens(
         [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
         {
             "pattern_1": [
-                PlaceholderInfo(
+                PlaceholderFeaturesInfo(
                     modality="pattern_1",
                     item_idx=0,
                     start_idx=6,
-                    replacement=[32000, 32000],
+                    tokens=[32000, 32000],
                 ),
             ],
             "pattern_4": [
-                PlaceholderInfo(
+                PlaceholderFeaturesInfo(
                     modality="pattern_4",
                     item_idx=0,
                     start_idx=3,
-                    replacement=[32000],
+                    tokens=[32000],
                 ),
             ],
         }
@@ -455,25 +459,25 @@ def test_find_replace_tokens(
         [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
         {
             "pattern_1": [
-                PlaceholderInfo(
+                PlaceholderFeaturesInfo(
                     modality="pattern_1",
                     item_idx=0,
                     start_idx=1,
-                    replacement=[32000, 32000],
+                    tokens=[32000, 32000],
                 ),
-                PlaceholderInfo(
+                PlaceholderFeaturesInfo(
                     modality="pattern_1",
                     item_idx=1,
                     start_idx=5,
-                    replacement=[32000, 32000],
+                    tokens=[32000, 32000],
                 ),
             ],
             "pattern_3": [
-                PlaceholderInfo(
+                PlaceholderFeaturesInfo(
                     modality="pattern_3",
                     item_idx=0,
                     start_idx=7,
-                    replacement=[1550, 918, 1550],
+                    tokens=[1550, 918, 1550],
                 ),
             ],
             # No match for pattern_4 as it has lower priority than pattern_1
@@ -483,33 +487,33 @@ def test_find_replace_tokens(
         [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
         {
             "pattern_1": [
-                PlaceholderInfo(
+                PlaceholderFeaturesInfo(
                     modality="pattern_1",
                     item_idx=0,
                     start_idx=1,
-                    replacement=[32000, 32000],
+                    tokens=[32000, 32000],
                 ),
-                PlaceholderInfo(
+                PlaceholderFeaturesInfo(
                     modality="pattern_1",
                     item_idx=1,
                     start_idx=3,
-                    replacement=[32000, 32000],
+                    tokens=[32000, 32000],
                 ),
             ],
             "pattern_4": [
-                PlaceholderInfo(
+                PlaceholderFeaturesInfo(
                     modality="pattern_4",
                     item_idx=0,
                     start_idx=5,
-                    replacement=[32000],
+                    tokens=[32000],
                 ),
             ],
             "pattern_3": [
-                PlaceholderInfo(
+                PlaceholderFeaturesInfo(
                     modality="pattern_3",
                     item_idx=0,
                     start_idx=6,
-                    replacement=[1550, 918, 1550],
+                    tokens=[1550, 918, 1550],
                 ),
             ],
         }

View File

@@ -342,13 +342,7 @@ class AriaProcessingInfo(BaseProcessingInfo):
         return self.get_hf_config().vision_config

     def get_hf_processor(self):
-        processor = self.ctx.get_hf_processor(AriaProcessor)
-
-        # Patch for https://github.com/huggingface/transformers/issues/35768
-        processor.tokenizer.image_token = "<|img|>"
-        processor.image_token = "<|img|>"
-
-        return processor
+        return self.ctx.get_hf_processor(AriaProcessor)

     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": None}
@@ -381,7 +375,7 @@ class AriaDummyInputsBuilder(BaseDummyInputsBuilder[AriaProcessingInfo]):
         }

         hf_processor = self.info.get_hf_processor()
-        image_token: str = hf_processor.image_token  # type: ignore
+        image_token: str = hf_processor.tokenizer.image_token  # type: ignore

         return ProcessorInputs(
             prompt_text=image_token * num_images,

View File

@@ -14,12 +14,12 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalInputs, MultiModalKwargs,
-                                    NestedTensors, PlaceholderRange)
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+                                    NestedTensors)
 from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement)
+                                        BaseProcessingInfo, PromptReplacement,
+                                        PromptReplacementDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
@@ -481,30 +481,13 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
             PromptReplacement(
                 modality="image",
                 target="</s>",
-                replacement="<image>" * num_image_tokens + "</s>",
+                replacement=PromptReplacementDetails(
+                    full="<image>" * num_image_tokens + "</s>",
+                    features="<image>" * num_image_tokens,
+                ),
             )
         ]

-    def apply(
-        self,
-        prompt: Union[str, list[int]],
-        mm_data: MultiModalDataDict,
-        hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> MultiModalInputs:
-        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
-
-        # Only <image> tokens should be considered as placeholders,
-        # so we ignore the trailing bos_token
-        result["mm_placeholders"] = {
-            modality: [
-                PlaceholderRange(offset=p["offset"], length=p["length"] - 1)
-                for p in ps
-            ]
-            for modality, ps in result["mm_placeholders"].items()
-        }
-
-        return result
-

 @MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor,
                                         info=Blip2ProcessingInfo,

View File

@@ -28,12 +28,12 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalInputs, MultiModalKwargs,
-                                    NestedTensors, PlaceholderRange)
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+                                    NestedTensors)
 from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement)
+                                        BaseProcessingInfo, PromptReplacement,
+                                        PromptReplacementDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
@@ -141,39 +141,23 @@ class ChameleonMultiModalProcessor(
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
         processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

+        image_tokens = processor.image_token * self.info.get_num_image_tokens()
+
         return [
             PromptReplacement(
                 modality="image",
                 target="<image>",
-                replacement="".join([
-                    processor.image_start_token,
-                    processor.image_token * self.info.get_num_image_tokens(),
-                    processor.image_end_token,
-                ]),
+                replacement=PromptReplacementDetails(
+                    full="".join([
+                        processor.image_start_token,
+                        image_tokens,
+                        processor.image_end_token,
+                    ]),
+                    features=image_tokens,
+                ),
             )
         ]

-    def apply(
-        self,
-        prompt: Union[str, list[int]],
-        mm_data: MultiModalDataDict,
-        hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> MultiModalInputs:
-        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
-
-        # Only <image> tokens should be considered as placeholders,
-        # so we ignore the image_start_token and image_end_token
-        result["mm_placeholders"] = {
-            modality: [
-                PlaceholderRange(offset=p["offset"] + 1,
-                                 length=p["length"] - 2) for p in ps
-            ]
-            for modality, ps in result["mm_placeholders"].items()
-        }
-
-        return result
-

 class ChameleonLayerNorm(nn.LayerNorm):

View File

@@ -16,7 +16,7 @@
 """ PyTorch Fuyu model."""
 import math
 from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
-                    TypedDict, Union)
+                    TypedDict)

 import torch
 import torch.nn as nn
@@ -30,13 +30,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.models.persimmon import PersimmonForCausalLM
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalInputs, MultiModalKwargs,
-                                    NestedTensors, PlaceholderRange)
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+                                    NestedTensors)
 from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                    MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement)
+                                        BaseProcessingInfo, PromptReplacement,
+                                        PromptReplacementDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
@@ -215,9 +215,13 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
                 image_width=image_size.width,
                 image_height=image_size.height,
             )

-            return (([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows +
-                    [bos_token_id])
+            image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
+                            [_NEWLINE_TOKEN_ID]) * nrows
+            return PromptReplacementDetails(
+                full=image_tokens + [bos_token_id],
+                features=image_tokens,
+            )

         return [
             PromptReplacement(
@@ -227,26 +231,6 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
             )
         ]

-    def apply(
-        self,
-        prompt: Union[str, list[int]],
-        mm_data: MultiModalDataDict,
-        hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> MultiModalInputs:
-        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
-
-        # Only |SPEAKER| (image) tokens should be considered as placeholders,
-        # so we ignore the trailing bos_token_id
-        result["mm_placeholders"] = {
-            modality: [
-                PlaceholderRange(offset=p["offset"], length=p["length"] - 1)
-                for p in ps
-            ]
-            for modality, ps in result["mm_placeholders"].items()
-        }
-
-        return result
-

 @MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor,
                                         info=FuyuProcessingInfo,

View File

@@ -30,15 +30,19 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalInputs, MultiModalKwargs,
-                                    NestedTensors, PlaceholderRange)
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+                                    NestedTensors)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo,
                                         BoundPromptReplacement,
-                                        PlaceholderInfo, PromptReplacement)
+                                        PlaceholderFeaturesInfo,
+                                        PromptReplacement,
+                                        PromptReplacementDetails)
+# yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of
@@ -437,7 +441,12 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
                 processor=hf_processor,
             )

-            return [_IMAGE_TOKEN_ID] * num_image_tokens + [bos_token_id]
+            image_tokens = [_IMAGE_TOKEN_ID] * num_image_tokens
+
+            return PromptReplacementDetails(
+                full=image_tokens + [bos_token_id],
+                features=image_tokens,
+            )

         num_images = mm_items.get_count("image", strict=False)
@@ -454,7 +463,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
         token_ids: list[int],
         mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
         mm_item_counts: Mapping[str, int],
-    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderInfo]]]:
+    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
         token_ids, text, placeholders = super()._apply_prompt_replacements(
             token_ids=token_ids,
             mm_prompt_repls=mm_prompt_repls,
@@ -467,11 +476,11 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
             token_ids = [token_ids[0], *token_ids[2:]]
             placeholders = {
                 modality: [
-                    PlaceholderInfo(
+                    PlaceholderFeaturesInfo(
                         modality=p.modality,
                         item_idx=p.item_idx,
                         start_idx=p.start_idx - 1,
-                        replacement=p.replacement,
+                        tokens=p.tokens,
                     ) for p in ps
                 ]
                 for modality, ps in placeholders.items()
@@ -479,26 +488,6 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
         return token_ids, text, placeholders

-    def apply(
-        self,
-        prompt: Union[str, list[int]],
-        mm_data: MultiModalDataDict,
-        hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> MultiModalInputs:
-        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
-
-        # Only <|image|> tokens should be considered as placeholders,
-        # so we ignore the trailing bos_token_id
-        result["mm_placeholders"] = {
-            modality: [
-                PlaceholderRange(offset=p["offset"], length=p["length"] - 1)
-                for p in ps
-            ]
-            for modality, ps in result["mm_placeholders"].items()
-        }
-
-        return result
-

 @MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor,
                                         info=Phi3VProcessingInfo,

View File

@@ -36,13 +36,13 @@ from vllm.config import VllmConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalInputs, MultiModalKwargs,
-                                    NestedTensors, PlaceholderRange)
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+                                    NestedTensors)
 from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
                                    MultiModalDataParser)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement)
+                                        BaseProcessingInfo, PromptReplacement,
+                                        PromptReplacementDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
@@ -216,11 +216,16 @@ class Qwen2AudioMultiModalProcessor(
                     f"The audio {audio} (len={len(audio)}) is too short "
                     "to be represented inside the model")

-            return "".join([
-                audio_bos_token,
-                audio_token * num_placeholders,
-                audio_eos_token,
-            ])
+            audio_tokens = audio_token * num_placeholders
+
+            return PromptReplacementDetails(
+                full="".join([
+                    audio_bos_token,
+                    audio_tokens,
+                    audio_eos_token,
+                ]),
+                features=audio_tokens,
+            )

         return [
             PromptReplacement(
@@ -240,26 +245,6 @@
             # tokens than the number of audio items)
             return not hasattr(self.info.get_hf_processor(), "audio_token")

-    def apply(
-        self,
-        prompt: Union[str, list[int]],
-        mm_data: MultiModalDataDict,
-        hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> MultiModalInputs:
-        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
-
-        # Only <|AUDIO|> tokens should be considered as placeholders,
-        # so we ignore the audio_bos_token and audio_eos_token
-        result["mm_placeholders"] = {
-            modality: [
-                PlaceholderRange(offset=p["offset"] + 1,
-                                 length=p["length"] - 2) for p in ps
-            ]
-            for modality, ps in result["mm_placeholders"].items()
-        }
-
-        return result
-

 @MULTIMODAL_REGISTRY.register_processor(
     Qwen2AudioMultiModalProcessor,

View File

@@ -1,7 +1,8 @@
 import re
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
+from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
+                             Sequence)
 from dataclasses import dataclass, field
 from functools import lru_cache
 from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
@@ -31,6 +32,24 @@ _S = TypeVar("_S", str, list[int])
 _PromptSeq = Union[str, list[int]]


+@dataclass
+class PromptReplacementDetails:
+    full: _PromptSeq
+    """The full replacement."""
+
+    features: _PromptSeq
+    """
+    The part of the replacement that corresponds to placeholder feature tokens.
+    """
+
+    @staticmethod
+    def from_seq(seq: _PromptSeq):
+        return PromptReplacementDetails(full=seq, features=seq)
+
+
+_PromptRepl = Union[_PromptSeq, PromptReplacementDetails]
+
+
 @dataclass
 class PromptReplacement:
     """
@@ -43,8 +62,8 @@ class PromptReplacement:
     target: _PromptSeq
     """The token sequence (or text) to find and replace."""

-    replacement: Union[Callable[[int], _PromptSeq],
-                       _PromptSeq] = field(repr=False)
+    replacement: Union[Callable[[int], _PromptRepl],
+                       _PromptRepl] = field(repr=False)
     """
     Given the index of the processed item within :attr:`modality`,
     output the replacement token sequence (or text).
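For orientation, a replacement can now be either a plain text/token sequence or a PromptReplacementDetails pair, where `full` is everything spliced into the prompt and `features` is the subset that counts as placeholder feature tokens. A minimal sketch of both forms follows; the target string and token IDs are illustrative only, not tied to any real model:

    from vllm.multimodal.processing import (PromptReplacement,
                                            PromptReplacementDetails)

    # Illustrative values only.
    IMAGE_TOKEN_ID = 32000
    BOS_TOKEN_ID = 1

    def get_image_replacement(item_idx: int) -> PromptReplacementDetails:
        image_tokens = [IMAGE_TOKEN_ID] * 4  # placeholder feature tokens
        # `full` is inserted into the prompt; only `features` is reported
        # as the placeholder range for this item.
        return PromptReplacementDetails(
            full=image_tokens + [BOS_TOKEN_ID],
            features=image_tokens,
        )

    repl = PromptReplacement(
        modality="image",
        target="<image>",
        replacement=get_image_replacement,
    )

    # Plain sequences still work; they are promoted via from_seq,
    # so full == features and the old behavior is unchanged.
    details = PromptReplacementDetails.from_seq([IMAGE_TOKEN_ID] * 4)
    assert details.full == details.features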
@@ -112,6 +131,14 @@ class _BoundPromptSequence:
     _text: Optional[str]
     _token_ids: Optional[list[int]]

+    @staticmethod
+    def from_seq(tokenizer: AnyTokenizer, seq: _PromptSeq):
+        return _BoundPromptSequence(
+            tokenizer=tokenizer,
+            _text=seq if isinstance(seq, str) else None,
+            _token_ids=seq if isinstance(seq, list) else None,
+        )
+
     def __post_init__(self) -> None:
         if self._text is None and self._token_ids is None:
             raise ValueError("At least one of 'text' and 'token_ids' must be "
@@ -134,6 +161,12 @@ class _BoundPromptSequence:
         return self._token_ids


+@dataclass
+class _BoundPromptReplacementGroup:
+    full: _BoundPromptSequence
+    features: _BoundPromptSequence
+
+
 @dataclass
 class BoundPromptReplacement:
     """
@@ -145,24 +178,18 @@ class BoundPromptReplacement:
     modality: str

     _target: _PromptSeq
-    _replacement: Union[Callable[[int], _PromptSeq],
-                        _PromptSeq] = field(repr=False)
+    _replacement: Union[Callable[[int], _PromptRepl],
+                        _PromptRepl] = field(repr=False)

     def __post_init__(self) -> None:
-        self._replacement_cache = dict[int, _BoundPromptSequence]()
+        self._replacement_cache = dict[int, _BoundPromptReplacementGroup]()

     @property
     def target(self) -> _BoundPromptSequence:
         """The token sequence (or text) to find and replace."""
-        target = self._target
-
-        return _BoundPromptSequence(
-            tokenizer=self.tokenizer,
-            _text=target if isinstance(target, str) else None,
-            _token_ids=target if isinstance(target, list) else None,
-        )
-
-    def get_replacement(self, item_idx: int) -> _BoundPromptSequence:
+        return _BoundPromptSequence.from_seq(self.tokenizer, self._target)
+
+    def get_replacement(self, item_idx: int) -> _BoundPromptReplacementGroup:
         """
         Given the index of the processed item within :attr:`modality`,
         output the replacement token sequence (or text).
@@ -177,10 +204,16 @@
         else:
             cache_key = None

-        bound_replacement = _BoundPromptSequence(
-            tokenizer=self.tokenizer,
-            _text=replacement if isinstance(replacement, str) else None,
-            _token_ids=replacement if isinstance(replacement, list) else None,
+        if not isinstance(replacement, PromptReplacementDetails):
+            replacement = PromptReplacementDetails.from_seq(replacement)
+
+        bound_full = _BoundPromptSequence.from_seq(self.tokenizer,
+                                                   replacement.full)
+        bound_features = _BoundPromptSequence.from_seq(self.tokenizer,
+                                                       replacement.features)
+        bound_replacement = _BoundPromptReplacementGroup(
+            full=bound_full,
+            features=bound_features,
         )

         if cache_key is not None:
@@ -197,7 +230,7 @@ class _TokenMatch(NamedTuple):
 def iter_token_matches(
     token_ids: list[int],
     match_ids: list[int],
-) -> Iterable[_TokenMatch]:
+) -> Generator[_TokenMatch]:
     """
     Yield each occurrence of :code:`match_ids` in :code:`token_ids`.
@@ -272,15 +305,15 @@ class _PromptReplacementTextMatch(_PromptReplacementMatch):
 @dataclass
-class PlaceholderInfo:
+class PlaceholderFeaturesInfo:
     modality: str
     item_idx: int
     start_idx: int
-    replacement: list[int]
+    tokens: list[int]

     @property
     def length(self) -> int:
-        return len(self.replacement)
+        return len(self.tokens)

     def to_range(self) -> PlaceholderRange:
         return PlaceholderRange(
@@ -362,10 +395,10 @@ def _replace_matches(
         replacement = repl_info.get_replacement(item_idx)

         if isinstance(prompt, str):
-            repl_seq = replacement.text
+            repl_seq = replacement.full.text
             out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq)
         else:
-            repl_seq = replacement.token_ids
+            repl_seq = replacement.full.token_ids
             out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq)

         prev_end_idx = end_idx
@@ -408,7 +441,7 @@ def _iter_placeholders(
     mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
     prompt: list[int],
     mm_item_counts: Mapping[str, int],
-) -> Iterable[PlaceholderInfo]:
+) -> Iterable[PlaceholderFeaturesInfo]:
     """
     Yield each set of placeholder tokens found in :code:`prompt`.
@@ -432,23 +465,33 @@
             for repl_info in modality_repls:
                 replacement = repl_info.get_replacement(item_idx)

-                repl_tokens = replacement.token_ids
-                repl_len = len(repl_tokens)
-                end_idx = start_idx + repl_len
+                repl_tokens_full = replacement.full.token_ids
+                repl_len_full = len(repl_tokens_full)
+                end_idx_full = start_idx + repl_len_full

-                if repl_len == 0 or end_idx > prompt_len:
+                if repl_len_full == 0 or end_idx_full > prompt_len:
                     continue

-                if prompt[start_idx:end_idx] == repl_tokens:
-                    yield PlaceholderInfo(
-                        modality=modality,
-                        item_idx=item_idx,
-                        start_idx=start_idx,
-                        replacement=repl_tokens,
-                    )
+                if prompt[start_idx:end_idx_full] == repl_tokens_full:
+                    repl_tokens_feat = replacement.features.token_ids
+
+                    try:
+                        match = next(
+                            iter_token_matches(repl_tokens_full,
+                                               repl_tokens_feat))
+                        yield PlaceholderFeaturesInfo(
+                            modality=modality,
+                            item_idx=item_idx,
+                            start_idx=start_idx + match.start_idx,
+                            tokens=repl_tokens_feat,
+                        )
+                    except StopIteration:
+                        raise AssertionError(
+                            f"{repl_tokens_feat=} should be a "
+                            f"subsequence of {repl_tokens_full=}") from None

                     # Exclude overlapping matches
-                    start_idx = end_idx
+                    start_idx = end_idx_full
                     item_idx_by_modality[modality] += 1
                     found = True
                     break
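To make the bookkeeping above concrete: the feature tokens are located inside the full replacement via iter_token_matches, and the placeholder's start index is shifted by that offset. A self-contained sketch of the same computation, using made-up token IDs and a simplified stand-in rather than the actual vLLM helpers:

    # Stand-in for iter_token_matches: return the index where `feat`
    # first occurs as a contiguous subsequence of `full`.
    def find_subsequence_start(full: list[int], feat: list[int]) -> int:
        for i in range(len(full) - len(feat) + 1):
            if full[i:i + len(feat)] == feat:
                return i
        raise AssertionError(f"{feat=} should be a subsequence of {full=}")

    repl_tokens_full = [32001, 32000, 32000, 32002]  # e.g. bos, image, image, eos
    repl_tokens_feat = [32000, 32000]                # feature placeholder tokens only

    offset = find_subsequence_start(repl_tokens_full, repl_tokens_feat)
    assert offset == 1

    # If the full replacement begins at `start_idx` in the prompt, the
    # placeholder covers len(repl_tokens_feat) tokens starting at
    # `start_idx + offset` -- the range that the removed per-model
    # `apply` overrides used to compute by trimming offsets and lengths.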
@@ -464,7 +507,7 @@ def find_mm_placeholders(
     mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
     prompt: list[int],
     mm_item_counts: Mapping[str, int],
-) -> Mapping[str, list[PlaceholderInfo]]:
+) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
     it = _iter_placeholders(mm_prompt_repls, prompt, mm_item_counts)
     return dict(full_groupby_modality(it))
@@ -679,7 +722,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
         new_token_ids: list[int],
         mm_item_counts: Mapping[str, int],
-    ) -> Mapping[str, list[PlaceholderInfo]]:
+    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
         return find_mm_placeholders(mm_prompt_repls, new_token_ids,
                                     mm_item_counts)
@@ -948,7 +991,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         token_ids: list[int],
         mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
         mm_item_counts: Mapping[str, int],
-    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderInfo]]]:
+    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
         tokenizer = self.info.get_tokenizer()

         mm_token_matches = {
@@ -1037,7 +1080,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
     def _validate_mm_placeholders(
         self,
-        mm_placeholders: Mapping[str, list[PlaceholderInfo]],
+        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
         mm_item_counts: Mapping[str, int],
         *,
         allow_missing: bool = False,