[VLM] Simplify post-processing of replacement info (#12269)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 09ccc9c8f7
commit df76e5af26
@@ -35,7 +35,7 @@ def _test_processing_correctness(
         task="auto",
         tokenizer=model_id,
         tokenizer_mode="auto",
-        trust_remote_code=True,
+        trust_remote_code=model_info.trust_remote_code,
         seed=0,
         dtype="float16",
         revision=None,
@@ -261,7 +261,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
                                            trust_remote_code=True),
     "Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"), # noqa: E501
     "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
-    "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3"),
+    "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3",
+                                     trust_remote_code=True),
     # [Encoder-decoder]
     "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501
     "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501
@@ -7,12 +7,16 @@ import pytest

 from vllm.config import ModelConfig
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.processing import (PlaceholderInfo, PromptReplacement,
+# yapf conflicts with isort for this block
+# yapf: disable
+from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
+                                        PromptReplacement,
                                         find_mm_placeholders,
                                         find_text_matches, find_token_matches,
                                         iter_token_matches,
                                         replace_text_matches,
                                         replace_token_matches)
+# yapf: enable
 from vllm.multimodal.profiling import MultiModalProfiler
 from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -433,19 +437,19 @@ def test_find_replace_tokens(
             [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
             {
                 "pattern_1": [
-                    PlaceholderInfo(
+                    PlaceholderFeaturesInfo(
                         modality="pattern_1",
                         item_idx=0,
                         start_idx=6,
-                        replacement=[32000, 32000],
+                        tokens=[32000, 32000],
                     ),
                 ],
                 "pattern_4": [
-                    PlaceholderInfo(
+                    PlaceholderFeaturesInfo(
                         modality="pattern_4",
                         item_idx=0,
                         start_idx=3,
-                        replacement=[32000],
+                        tokens=[32000],
                     ),
                 ],
             }
@@ -455,25 +459,25 @@ def test_find_replace_tokens(
             [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
             {
                 "pattern_1": [
-                    PlaceholderInfo(
+                    PlaceholderFeaturesInfo(
                         modality="pattern_1",
                         item_idx=0,
                         start_idx=1,
-                        replacement=[32000, 32000],
+                        tokens=[32000, 32000],
                     ),
-                    PlaceholderInfo(
+                    PlaceholderFeaturesInfo(
                         modality="pattern_1",
                         item_idx=1,
                         start_idx=5,
-                        replacement=[32000, 32000],
+                        tokens=[32000, 32000],
                     ),
                 ],
                 "pattern_3": [
-                    PlaceholderInfo(
+                    PlaceholderFeaturesInfo(
                         modality="pattern_3",
                         item_idx=0,
                         start_idx=7,
-                        replacement=[1550, 918, 1550],
+                        tokens=[1550, 918, 1550],
                     ),
                 ],
                 # No match for pattern_4 as it has lower priority than pattern_1
@@ -483,33 +487,33 @@ def test_find_replace_tokens(
             [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
             {
                 "pattern_1": [
-                    PlaceholderInfo(
+                    PlaceholderFeaturesInfo(
                         modality="pattern_1",
                         item_idx=0,
                         start_idx=1,
-                        replacement=[32000, 32000],
+                        tokens=[32000, 32000],
                     ),
-                    PlaceholderInfo(
+                    PlaceholderFeaturesInfo(
                         modality="pattern_1",
                         item_idx=1,
                         start_idx=3,
-                        replacement=[32000, 32000],
+                        tokens=[32000, 32000],
                     ),
                 ],
                 "pattern_4": [
-                    PlaceholderInfo(
+                    PlaceholderFeaturesInfo(
                         modality="pattern_4",
                         item_idx=0,
                         start_idx=5,
-                        replacement=[32000],
+                        tokens=[32000],
                     ),
                 ],
                 "pattern_3": [
-                    PlaceholderInfo(
+                    PlaceholderFeaturesInfo(
                         modality="pattern_3",
                         item_idx=0,
                         start_idx=6,
-                        replacement=[1550, 918, 1550],
+                        tokens=[1550, 918, 1550],
                     ),
                 ],
             }
@@ -342,13 +342,7 @@ class AriaProcessingInfo(BaseProcessingInfo):
         return self.get_hf_config().vision_config

     def get_hf_processor(self):
-        processor = self.ctx.get_hf_processor(AriaProcessor)
-
-        # Patch for https://github.com/huggingface/transformers/issues/35768
-        processor.tokenizer.image_token = "<|img|>"
-        processor.image_token = "<|img|>"
-
-        return processor
+        return self.ctx.get_hf_processor(AriaProcessor)

     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": None}
@@ -381,7 +375,7 @@ class AriaDummyInputsBuilder(BaseDummyInputsBuilder[AriaProcessingInfo]):
         }

         hf_processor = self.info.get_hf_processor()
-        image_token: str = hf_processor.image_token # type: ignore
+        image_token: str = hf_processor.tokenizer.image_token # type: ignore

         return ProcessorInputs(
             prompt_text=image_token * num_images,
@@ -14,12 +14,12 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalInputs, MultiModalKwargs,
-                                    NestedTensors, PlaceholderRange)
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+                                    NestedTensors)
 from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement)
+                                        BaseProcessingInfo, PromptReplacement,
+                                        PromptReplacementDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
@@ -481,30 +481,13 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
             PromptReplacement(
                 modality="image",
                 target="</s>",
-                replacement="<image>" * num_image_tokens + "</s>",
+                replacement=PromptReplacementDetails(
+                    full="<image>" * num_image_tokens + "</s>",
+                    features="<image>" * num_image_tokens,
+                ),
             )
         ]
-
-    def apply(
-        self,
-        prompt: Union[str, list[int]],
-        mm_data: MultiModalDataDict,
-        hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> MultiModalInputs:
-        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
-
-        # Only <image> tokens should be considered as placeholders,
-        # so we ignore the trailing bos_token
-        result["mm_placeholders"] = {
-            modality: [
-                PlaceholderRange(offset=p["offset"], length=p["length"] - 1)
-                for p in ps
-            ]
-            for modality, ps in result["mm_placeholders"].items()
-        }
-
-        return result


 @MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor,
                                         info=Blip2ProcessingInfo,
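
The apply() override removed above trimmed the trailing "</s>" out of each reported placeholder by hand. A standalone sketch (plain Python, illustrative token strings and count, not vLLM code) of why expressing the replacement as full vs. features makes that bookkeeping unnecessary:

# Illustrative sketch: the placeholder derived from `features` equals the
# old hand-trimmed range (full length minus the trailing "</s>").
num_image_tokens = 32                                  # assumed count
full = ["<image>"] * num_image_tokens + ["</s>"]       # inserted into the prompt
features = ["<image>"] * num_image_tokens              # reported as the placeholder

old_range = {"offset": 0, "length": len(full) - 1}     # what apply() used to compute
new_range = {"offset": 0, "length": len(features)}     # what the core now derives
assert old_range == new_range
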
@@ -28,12 +28,12 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalInputs, MultiModalKwargs,
-                                    NestedTensors, PlaceholderRange)
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+                                    NestedTensors)
 from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement)
+                                        BaseProcessingInfo, PromptReplacement,
+                                        PromptReplacementDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
@@ -141,39 +141,23 @@ class ChameleonMultiModalProcessor(
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
         processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+        image_tokens = processor.image_token * self.info.get_num_image_tokens()

         return [
             PromptReplacement(
                 modality="image",
                 target="<image>",
-                replacement="".join([
-                    processor.image_start_token,
-                    processor.image_token * self.info.get_num_image_tokens(),
-                    processor.image_end_token,
-                ]),
+                replacement=PromptReplacementDetails(
+                    full="".join([
+                        processor.image_start_token,
+                        image_tokens,
+                        processor.image_end_token,
+                    ]),
+                    features=image_tokens,
+                ),
             )
         ]
-
-    def apply(
-        self,
-        prompt: Union[str, list[int]],
-        mm_data: MultiModalDataDict,
-        hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> MultiModalInputs:
-        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
-
-        # Only <image> tokens should be considered as placeholders,
-        # so we ignore the image_start_token and image_end_token
-        result["mm_placeholders"] = {
-            modality: [
-                PlaceholderRange(offset=p["offset"] + 1,
-                                 length=p["length"] - 2) for p in ps
-            ]
-            for modality, ps in result["mm_placeholders"].items()
-        }
-
-        return result


 class ChameleonLayerNorm(nn.LayerNorm):

@@ -16,7 +16,7 @@
 """ PyTorch Fuyu model."""
 import math
 from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
-                    TypedDict, Union)
+                    TypedDict)

 import torch
 import torch.nn as nn
@@ -30,13 +30,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.models.persimmon import PersimmonForCausalLM
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalInputs, MultiModalKwargs,
-                                    NestedTensors, PlaceholderRange)
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+                                    NestedTensors)
 from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                    MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement)
+                                        BaseProcessingInfo, PromptReplacement,
+                                        PromptReplacementDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
@@ -215,9 +215,13 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
                 image_width=image_size.width,
                 image_height=image_size.height,
             )
+            image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
+                            [_NEWLINE_TOKEN_ID]) * nrows

-            return (([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows +
-                    [bos_token_id])
+            return PromptReplacementDetails(
+                full=image_tokens + [bos_token_id],
+                features=image_tokens,
+            )

         return [
             PromptReplacement(
@@ -227,26 +231,6 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
             )
         ]
-
-    def apply(
-        self,
-        prompt: Union[str, list[int]],
-        mm_data: MultiModalDataDict,
-        hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> MultiModalInputs:
-        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
-
-        # Only |SPEAKER| (image) tokens should be considered as placeholders,
-        # so we ignore the trailing bos_token_id
-        result["mm_placeholders"] = {
-            modality: [
-                PlaceholderRange(offset=p["offset"], length=p["length"] - 1)
-                for p in ps
-            ]
-            for modality, ps in result["mm_placeholders"].items()
-        }
-
-        return result


 @MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor,
                                         info=FuyuProcessingInfo,
@@ -30,15 +30,19 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalInputs, MultiModalKwargs,
-                                    NestedTensors, PlaceholderRange)
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+                                    NestedTensors)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo,
                                         BoundPromptReplacement,
-                                        PlaceholderInfo, PromptReplacement)
+                                        PlaceholderFeaturesInfo,
+                                        PromptReplacement,
+                                        PromptReplacementDetails)
+# yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of
@@ -437,7 +441,12 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
                 processor=hf_processor,
             )

-            return [_IMAGE_TOKEN_ID] * num_image_tokens + [bos_token_id]
+            image_tokens = [_IMAGE_TOKEN_ID] * num_image_tokens
+
+            return PromptReplacementDetails(
+                full=image_tokens + [bos_token_id],
+                features=image_tokens,
+            )

         num_images = mm_items.get_count("image", strict=False)

@@ -454,7 +463,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
         token_ids: list[int],
         mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
         mm_item_counts: Mapping[str, int],
-    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderInfo]]]:
+    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
         token_ids, text, placeholders = super()._apply_prompt_replacements(
             token_ids=token_ids,
             mm_prompt_repls=mm_prompt_repls,
@@ -467,11 +476,11 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
         token_ids = [token_ids[0], *token_ids[2:]]
         placeholders = {
             modality: [
-                PlaceholderInfo(
+                PlaceholderFeaturesInfo(
                     modality=p.modality,
                     item_idx=p.item_idx,
                     start_idx=p.start_idx - 1,
-                    replacement=p.replacement,
+                    tokens=p.tokens,
                 ) for p in ps
             ]
             for modality, ps in placeholders.items()
@@ -479,26 +488,6 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):

         return token_ids, text, placeholders
-
-    def apply(
-        self,
-        prompt: Union[str, list[int]],
-        mm_data: MultiModalDataDict,
-        hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> MultiModalInputs:
-        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
-
-        # Only <|image|> tokens should be considered as placeholders,
-        # so we ignore the trailing bos_token_id
-        result["mm_placeholders"] = {
-            modality: [
-                PlaceholderRange(offset=p["offset"], length=p["length"] - 1)
-                for p in ps
-            ]
-            for modality, ps in result["mm_placeholders"].items()
-        }
-
-        return result


 @MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor,
                                         info=Phi3VProcessingInfo,
@@ -36,13 +36,13 @@ from vllm.config import VllmConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalInputs, MultiModalKwargs,
-                                    NestedTensors, PlaceholderRange)
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+                                    NestedTensors)
 from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
                                    MultiModalDataParser)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement)
+                                        BaseProcessingInfo, PromptReplacement,
+                                        PromptReplacementDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
@@ -216,11 +216,16 @@ class Qwen2AudioMultiModalProcessor(
                     f"The audio {audio} (len={len(audio)}) is too short "
                     "to be represented inside the model")

-            return "".join([
-                audio_bos_token,
-                audio_token * num_placeholders,
-                audio_eos_token,
-            ])
+            audio_tokens = audio_token * num_placeholders
+
+            return PromptReplacementDetails(
+                full="".join([
+                    audio_bos_token,
+                    audio_tokens,
+                    audio_eos_token,
+                ]),
+                features=audio_tokens,
+            )

         return [
             PromptReplacement(
@@ -240,26 +245,6 @@ class Qwen2AudioMultiModalProcessor(
         # tokens than the number of audio items)
         return not hasattr(self.info.get_hf_processor(), "audio_token")
-
-    def apply(
-        self,
-        prompt: Union[str, list[int]],
-        mm_data: MultiModalDataDict,
-        hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> MultiModalInputs:
-        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
-
-        # Only <|AUDIO|> tokens should be considered as placeholders,
-        # so we ignore the audio_bos_token and audio_eos_token
-        result["mm_placeholders"] = {
-            modality: [
-                PlaceholderRange(offset=p["offset"] + 1,
-                                 length=p["length"] - 2) for p in ps
-            ]
-            for modality, ps in result["mm_placeholders"].items()
-        }
-
-        return result


 @MULTIMODAL_REGISTRY.register_processor(
     Qwen2AudioMultiModalProcessor,
@@ -1,7 +1,8 @@
 import re
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
+from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
+                             Sequence)
 from dataclasses import dataclass, field
 from functools import lru_cache
 from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
@@ -31,6 +32,24 @@ _S = TypeVar("_S", str, list[int])
 _PromptSeq = Union[str, list[int]]


+@dataclass
+class PromptReplacementDetails:
+    full: _PromptSeq
+    """The full replacement."""
+
+    features: _PromptSeq
+    """
+    The part of the replacement that corresponds to placeholder feature tokens.
+    """
+
+    @staticmethod
+    def from_seq(seq: _PromptSeq):
+        return PromptReplacementDetails(full=seq, features=seq)
+
+
+_PromptRepl = Union[_PromptSeq, PromptReplacementDetails]
+
+
 @dataclass
 class PromptReplacement:
     """
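
A hedged usage sketch of the new PromptReplacementDetails (the target string and token count below are illustrative, not taken from any particular model): a processor states which part of the inserted text stands for placeholder feature tokens, and a bare string or token list is still accepted because get_replacement normalizes it with from_seq, so full equals features in that case.

from vllm.multimodal.processing import (PromptReplacement,
                                        PromptReplacementDetails)

NUM_IMAGE_TOKENS = 576  # illustrative constant

repl = PromptReplacement(
    modality="image",
    target="</s>",
    replacement=PromptReplacementDetails(
        # Everything that gets inserted into the prompt...
        full="<image>" * NUM_IMAGE_TOKENS + "</s>",
        # ...but only this part is tracked as the placeholder.
        features="<image>" * NUM_IMAGE_TOKENS,
    ),
)

# Plain sequences keep working as before: they are wrapped with
# PromptReplacementDetails.from_seq(seq) when the replacement is bound.
legacy = PromptReplacement(
    modality="image",
    target="</s>",
    replacement="<image>" * NUM_IMAGE_TOKENS + "</s>",
)
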
@@ -43,8 +62,8 @@ class PromptReplacement:
     target: _PromptSeq
     """The token sequence (or text) to find and replace."""

-    replacement: Union[Callable[[int], _PromptSeq],
-                       _PromptSeq] = field(repr=False)
+    replacement: Union[Callable[[int], _PromptRepl],
+                       _PromptRepl] = field(repr=False)
     """
     Given the index of the processed item within :attr:`modality`,
     output the replacement token sequence (or text).
@@ -112,6 +131,14 @@ class _BoundPromptSequence:
     _text: Optional[str]
     _token_ids: Optional[list[int]]

+    @staticmethod
+    def from_seq(tokenizer: AnyTokenizer, seq: _PromptSeq):
+        return _BoundPromptSequence(
+            tokenizer=tokenizer,
+            _text=seq if isinstance(seq, str) else None,
+            _token_ids=seq if isinstance(seq, list) else None,
+        )
+
     def __post_init__(self) -> None:
         if self._text is None and self._token_ids is None:
             raise ValueError("At least one of 'text' and 'token_ids' must be "
@@ -134,6 +161,12 @@ class _BoundPromptSequence:
         return self._token_ids


+@dataclass
+class _BoundPromptReplacementGroup:
+    full: _BoundPromptSequence
+    features: _BoundPromptSequence
+
+
 @dataclass
 class BoundPromptReplacement:
     """
@@ -145,24 +178,18 @@ class BoundPromptReplacement:
     modality: str

     _target: _PromptSeq
-    _replacement: Union[Callable[[int], _PromptSeq],
-                        _PromptSeq] = field(repr=False)
+    _replacement: Union[Callable[[int], _PromptRepl],
+                        _PromptRepl] = field(repr=False)

     def __post_init__(self) -> None:
-        self._replacement_cache = dict[int, _BoundPromptSequence]()
+        self._replacement_cache = dict[int, _BoundPromptReplacementGroup]()

     @property
     def target(self) -> _BoundPromptSequence:
         """The token sequence (or text) to find and replace."""
-        target = self._target
-
-        return _BoundPromptSequence(
-            tokenizer=self.tokenizer,
-            _text=target if isinstance(target, str) else None,
-            _token_ids=target if isinstance(target, list) else None,
-        )
+        return _BoundPromptSequence.from_seq(self.tokenizer, self._target)

-    def get_replacement(self, item_idx: int) -> _BoundPromptSequence:
+    def get_replacement(self, item_idx: int) -> _BoundPromptReplacementGroup:
         """
         Given the index of the processed item within :attr:`modality`,
         output the replacement token sequence (or text).
@@ -177,10 +204,16 @@ class BoundPromptReplacement:
         else:
             cache_key = None

-        bound_replacement = _BoundPromptSequence(
-            tokenizer=self.tokenizer,
-            _text=replacement if isinstance(replacement, str) else None,
-            _token_ids=replacement if isinstance(replacement, list) else None,
+        if not isinstance(replacement, PromptReplacementDetails):
+            replacement = PromptReplacementDetails.from_seq(replacement)
+
+        bound_full = _BoundPromptSequence.from_seq(self.tokenizer,
+                                                   replacement.full)
+        bound_features = _BoundPromptSequence.from_seq(self.tokenizer,
+                                                       replacement.features)
+        bound_replacement = _BoundPromptReplacementGroup(
+            full=bound_full,
+            features=bound_features,
         )

         if cache_key is not None:
@@ -197,7 +230,7 @@ class _TokenMatch(NamedTuple):
 def iter_token_matches(
     token_ids: list[int],
     match_ids: list[int],
-) -> Iterable[_TokenMatch]:
+) -> Generator[_TokenMatch]:
     """
     Yield each occurrence of :code:`match_ids` in :code:`token_ids`.

@@ -272,15 +305,15 @@ class _PromptReplacementTextMatch(_PromptReplacementMatch):


 @dataclass
-class PlaceholderInfo:
+class PlaceholderFeaturesInfo:
     modality: str
     item_idx: int
     start_idx: int
-    replacement: list[int]
+    tokens: list[int]

     @property
     def length(self) -> int:
-        return len(self.replacement)
+        return len(self.tokens)

     def to_range(self) -> PlaceholderRange:
         return PlaceholderRange(
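
With the rename, the dataclass records only the feature tokens (the replacement field becomes tokens), so length and to_range() describe exactly the span that multimodal embeddings will occupy. A small hedged sketch with illustrative token ids:

from vllm.multimodal.processing import PlaceholderFeaturesInfo

info = PlaceholderFeaturesInfo(
    modality="image",
    item_idx=0,
    start_idx=3,
    tokens=[32000, 32000],  # feature tokens only, no BOS/EOS wrappers
)
assert info.length == 2     # counts feature tokens, not the full replacement
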
@@ -362,10 +395,10 @@ def _replace_matches(
         replacement = repl_info.get_replacement(item_idx)

         if isinstance(prompt, str):
-            repl_seq = replacement.text
+            repl_seq = replacement.full.text
             out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq)
         else:
-            repl_seq = replacement.token_ids
+            repl_seq = replacement.full.token_ids
             out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq)

         prev_end_idx = end_idx
@@ -408,7 +441,7 @@ def _iter_placeholders(
     mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
     prompt: list[int],
     mm_item_counts: Mapping[str, int],
-) -> Iterable[PlaceholderInfo]:
+) -> Iterable[PlaceholderFeaturesInfo]:
     """
     Yield each set of placeholder tokens found in :code:`prompt`.

@@ -432,23 +465,33 @@ def _iter_placeholders(

             for repl_info in modality_repls:
                 replacement = repl_info.get_replacement(item_idx)
-                repl_tokens = replacement.token_ids
-                repl_len = len(repl_tokens)
-                end_idx = start_idx + repl_len
+                repl_tokens_full = replacement.full.token_ids
+                repl_len_full = len(repl_tokens_full)
+                end_idx_full = start_idx + repl_len_full

-                if repl_len == 0 or end_idx > prompt_len:
+                if repl_len_full == 0 or end_idx_full > prompt_len:
                     continue

-                if prompt[start_idx:end_idx] == repl_tokens:
-                    yield PlaceholderInfo(
-                        modality=modality,
-                        item_idx=item_idx,
-                        start_idx=start_idx,
-                        replacement=repl_tokens,
-                    )
+                if prompt[start_idx:end_idx_full] == repl_tokens_full:
+                    repl_tokens_feat = replacement.features.token_ids
+
+                    try:
+                        match = next(
+                            iter_token_matches(repl_tokens_full,
+                                               repl_tokens_feat))
+                        yield PlaceholderFeaturesInfo(
+                            modality=modality,
+                            item_idx=item_idx,
+                            start_idx=start_idx + match.start_idx,
+                            tokens=repl_tokens_feat,
+                        )
+                    except StopIteration:
+                        raise AssertionError(
+                            f"{repl_tokens_feat=} should be a "
+                            f"subsequence of {repl_tokens_full=}") from None

                     # Exclude overlapping matches
-                    start_idx = end_idx
+                    start_idx = end_idx_full
                     item_idx_by_modality[modality] += 1
                     found = True
                     break
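
The new lookup locates the features token ids as a subsequence of the full replacement (via iter_token_matches) and shifts the reported start index by that offset. A standalone sketch of the same arithmetic in plain Python (illustrative token ids, no vLLM imports):

def find_subsequence(haystack: list[int], needle: list[int]) -> int:
    """Return the first index where `needle` occurs inside `haystack`."""
    n = len(needle)
    for i in range(len(haystack) - n + 1):
        if haystack[i:i + n] == needle:
            return i
    raise AssertionError("features must appear inside the full replacement")

IMG, BOS = 32000, 1                   # illustrative token ids
repl_full = [IMG, IMG, IMG, BOS]      # inserted into the prompt
repl_feat = [IMG, IMG, IMG]           # tracked as the placeholder

start_idx = 5                         # where the full replacement matched
offset = find_subsequence(repl_full, repl_feat)
print(start_idx + offset, len(repl_feat))  # 5 3: placeholder covers the feature span only
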
@@ -464,7 +507,7 @@ def find_mm_placeholders(
     mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
     prompt: list[int],
     mm_item_counts: Mapping[str, int],
-) -> Mapping[str, list[PlaceholderInfo]]:
+) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
     it = _iter_placeholders(mm_prompt_repls, prompt, mm_item_counts)
     return dict(full_groupby_modality(it))

@@ -679,7 +722,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
         new_token_ids: list[int],
         mm_item_counts: Mapping[str, int],
-    ) -> Mapping[str, list[PlaceholderInfo]]:
+    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
         return find_mm_placeholders(mm_prompt_repls, new_token_ids,
                                     mm_item_counts)

@@ -948,7 +991,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         token_ids: list[int],
         mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
         mm_item_counts: Mapping[str, int],
-    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderInfo]]]:
+    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
         tokenizer = self.info.get_tokenizer()

         mm_token_matches = {
@@ -1037,7 +1080,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):

     def _validate_mm_placeholders(
         self,
-        mm_placeholders: Mapping[str, list[PlaceholderInfo]],
+        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
         mm_item_counts: Mapping[str, int],
         *,
         allow_missing: bool = False,