[VLM] Keep track of whether prompt replacements have been applied (#13215)

Cyrus Leung 2025-02-14 20:20:46 +08:00 committed by GitHub
parent 556ef7f714
commit 4da1f667e9
10 changed files with 372 additions and 328 deletions

View File

@ -484,6 +484,14 @@ class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):
class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):
def _hf_processor_applies_repl(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> bool:
return False
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,

View File

@ -294,7 +294,7 @@ class PixtralHFMultiModalProcessor(
pixel_values = processed_outputs.get("pixel_values")
if pixel_values is not None:
# Before/after https://github.com/huggingface/transformers/pull/35122
if Version(TRANSFORMERS_VERSION) <= Version("4.48.2"):
if Version(TRANSFORMERS_VERSION) <= Version("4.48.3"):
images = mm_data["images"]
assert isinstance(images, list)
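
The version gate above relies on packaging-style version comparison, which orders multi-part versions correctly; a quick standalone illustration (treating 4.49.0 as a release that already contains huggingface/transformers#35122 is an assumption, not something stated in this diff):

    from packaging.version import Version

    # 4.48.3 still takes the legacy branch (the bound is inclusive)...
    assert Version("4.48.3") <= Version("4.48.3")
    # ...while a later release such as 4.49.0, which presumably ships
    # huggingface/transformers#35122, takes the new code path.
    assert not Version("4.49.0") <= Version("4.48.3")
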
@ -819,7 +819,6 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
prompt_ids,
mm_item_counts,
)
self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
mm_placeholder_ranges = {

View File

@ -299,36 +299,69 @@ class LlavaOnevisionMultiModalProcessor(
mm_kwargs=mm_kwargs,
)
processor = self.info.get_hf_processor()
video_token = processor.video_token
# LLaVA-OneVision processor doesn't support multiple videos
# with different sizes when converting back to tensors
text_image_outputs = super()._call_hf_processor(
# So, we process each component separately
# NOTE: No prompt replacement is applied in this case
processor = self.info.get_hf_processor()
image_token = processor.image_token
video_token = processor.video_token
text_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_data={},
mm_kwargs=mm_kwargs,
)
images = mm_data.pop("images", [])
assert isinstance(images, list)
if images:
processor_outputs = super()._call_hf_processor(
prompt=image_token * len(images),
mm_data={"images": images},
mm_kwargs=mm_kwargs,
)
image_outputs = {
k: v
for k, v in processor_outputs.items()
if k in ("pixel_values", "image_sizes")
}
else:
image_outputs = {}
pixel_values_videos = []
for video in videos:
item_processor_data = dict(prompt=video_token, videos=video)
item_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=item_processor_data,
prompt=video_token,
mm_data={"videos": video},
mm_kwargs=mm_kwargs,
)
pixel_values_videos.append(
item_outputs.pop("pixel_values_videos")[0])
pixel_values_videos.append(item_outputs["pixel_values_videos"][0])
video_outputs = {"pixel_values_videos": pixel_values_videos}
combined_outputs = dict(
**text_image_outputs,
pixel_values_videos=pixel_values_videos,
**text_outputs,
**image_outputs,
**video_outputs,
)
return BatchFeature(combined_outputs)
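
The per-video loop above exists because videos of different sizes yield feature tensors of different shapes, which cannot be stacked into a single batch tensor; a standalone sketch of that constraint (the shapes are illustrative assumptions, not taken from the model):

    import torch

    pixel_values_videos = [
        torch.zeros(8, 3, 224, 224),   # 8-frame video at 224x224
        torch.zeros(16, 3, 384, 384),  # 16-frame video at 384x384
    ]
    # torch.stack(pixel_values_videos) would raise a RuntimeError here because
    # the shapes differ, so the processor keeps the per-video tensors in a list
    # instead of one batched tensor.
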
def _hf_processor_applies_repl(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> bool:
base_result = super()._hf_processor_applies_repl(
prompt_text=prompt_text,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
return base_result and mm_items.get_count("video", strict=False) == 0
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,

View File

@ -27,8 +27,8 @@ from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union)
import torch
import torch.types
from torch import nn
from transformers import BatchFeature
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.whisper.modeling_whisper import (
ACT2FN, WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder)
@ -37,23 +37,21 @@ from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm.multimodal.parse import (ModalityData, ModalityDataItems,
MultiModalDataItems, MultiModalDataParser,
VideoItem)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptReplacement)
from vllm.multimodal.parse import (AudioItem, DictEmbeddingItems, ModalityData,
ModalityDataItems, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import PromptReplacement
from vllm.multimodal.profiling import ProcessorInputs
from vllm.sequence import IntermediateTensors
from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
MiniCPMVEmbeddingItems, MiniCPMVMultiModalDataParser,
MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo)
MiniCPMVMultiModalDataParser,
MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo,
_minicpmv_field_config)
from .utils import AutoWeightsLoader, maybe_prefix
CPU_DEVICE = torch.device("cpu")
MiniCPMOEmbeddingItems = MiniCPMVEmbeddingItems
class MiniCPMOAudioFeatureInputs(TypedDict):
type: Literal["audio_features"]
@ -103,28 +101,49 @@ MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,
MiniCPMOAudioEmbeddingInputs]
class MiniCPMOAudioEmbeddingItems(MiniCPMOEmbeddingItems):
def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
audio_num_slices = hf_inputs.get("audio_num_slices", torch.empty(0))
def __init__(self, data: Dict) -> None:
super().__init__(data, "audio")
audio_embeds = self.data.get("audio_embeds", None)
if audio_embeds is None:
raise ValueError("Incorrect type of video_embeds",
"Got type: None")
self.data["audio_embeds"] = audio_embeds
return dict(
**_minicpmv_field_config(hf_inputs),
audio_features=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_num_slices),
audio_feature_lens=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_num_slices),
audio_num_slices=MultiModalFieldConfig.batched("audio"),
audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio"),
audio_embeds=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_num_slices),
)
def get(self, index: int) -> object:
return self.data["audio_embeds"][index]
class MiniCPMOAudioEmbeddingItems(DictEmbeddingItems):
def __init__(
self,
data: Mapping[str, torch.Tensor],
fields_config: Mapping[str, MultiModalFieldConfig],
) -> None:
super().__init__(
data,
modality="image",
fields_config=fields_config,
required_fields={"audio_embeds"},
)
class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):
def _parse_audio_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]],
) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict):
return MiniCPMOAudioEmbeddingItems(data)
return MiniCPMOAudioEmbeddingItems(
data,
fields_config=_minicpmo_field_config(data),
)
return super()._parse_audio_data(data)
@ -167,6 +186,10 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
def get_max_audio_chunks_with_most_features(self) -> int:
return 30
def get_max_audio_tokens(self) -> int:
return self.get_max_audio_tokens_per_chunk(
) * self.get_max_audio_chunks_with_most_features()
def get_audio_len_by_num_chunks(self, num_chunks: int) -> int:
sampling_rate = self.get_default_audio_sampling_rate()
# exclude <audio> </audio>
@ -194,7 +217,8 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
return num_frames
class MiniCPMODummyInputsBuilder(MiniCPMVDummyInputsBuilder):
class MiniCPMODummyInputsBuilder(
MiniCPMVDummyInputsBuilder[MiniCPMOProcessingInfo]):
def get_dummy_processor_inputs(
self, seq_len: int, mm_counts: Mapping[str,
@ -222,8 +246,7 @@ class MiniCPMODummyInputsBuilder(MiniCPMVDummyInputsBuilder):
class MiniCPMOMultiModalProcessor(
MiniCPMVMultiModalProcessor,
BaseMultiModalProcessor[MiniCPMOProcessingInfo]):
MiniCPMVMultiModalProcessor[MiniCPMOProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
return MiniCPMOMultiModalDataParser(
@ -369,21 +392,10 @@ class MiniCPMOMultiModalProcessor(
def _get_mm_fields_config(
self,
hf_inputs,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
audio_num_slices = hf_inputs.get("audio_num_slices", torch.empty(0))
return dict(
**super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs),
audio_features=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_num_slices),
audio_feature_lens=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_num_slices),
audio_num_slices=MultiModalFieldConfig.batched("audio"),
audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio"),
audio_embeds=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_num_slices))
return _minicpmo_field_config(hf_inputs)
class MultiModalProjector(nn.Module):
@ -406,7 +418,7 @@ class MultiModalProjector(nn.Module):
class MiniCPMWhisperEncoderLayer(nn.Module):
def __init__(self, config: WhisperConfig, layer_idx: int = None):
def __init__(self, config: WhisperConfig, layer_idx: int):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = WHISPER_ATTENTION_CLASSES[

View File

@ -35,6 +35,7 @@ import torch.types
from PIL import Image
from torch import nn
from transformers import BatchFeature, PretrainedConfig
from typing_extensions import TypeVar
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
@ -51,9 +52,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, PlaceholderRange)
from vllm.multimodal.parse import (ImageItem, ImageSize, ModalityData,
ModalityDataItems, MultiModalDataItems,
MultiModalDataParser, VideoItem)
from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, ImageSize,
ModalityData, ModalityDataItems,
MultiModalDataItems, MultiModalDataParser,
VideoItem)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
@ -115,93 +117,6 @@ class MiniCPMVImageEmbeddingInputs(TypedDict):
MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
MiniCPMVImageEmbeddingInputs]
class MiniCPMVEmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
dict[str, torch.Tensor]]):
def __init__(self, data: Dict, modality: str) -> None:
super().__init__(data, modality)
def get_processor_data(self) -> Mapping[str, object]:
return self.data
def get_passthrough_data(self) -> Mapping[str, object]:
return {}
def get_count(self) -> int:
return len(self.data[f"{self.modality}_embeds"])
def get(self, index: int) -> Dict[str, torch.Tensor]:
out = {}
for k, v in self.data.items():
out[k] = v[index]
return out
class MiniCPMVImageEmbeddingItems(MiniCPMVEmbeddingItems):
def __init__(self, data: Dict) -> None:
super().__init__(data, "image")
image_embeds = self.data.get("image_embeds", None)
image_sizes = self.data.get("image_sizes", None)
if image_embeds is None:
raise ValueError("In correct type of image_embeds",
"Got type: None")
if not isinstance(image_embeds[0], torch.Tensor):
raise ValueError("In correct type of image_embeds",
f"Got type: {type(image_embeds[0])}")
if image_sizes is None:
raise ValueError(
"In correct type of image_sizes", "Got type: None."
"If you're using `image_size_list`, "
"please rename it to `image_sizes`")
if len(image_embeds[0].shape) == 2:
image_embeds = [image_embeds]
image_sizes = [image_sizes]
self.data["image_embeds"] = image_embeds
self.data["image_sizes"] = image_sizes
def get_image_size(self, index: int) -> ImageSize:
image_size = self.data["image_sizes"][index]
return ImageSize(width=image_size[0], height=image_size[1])
class MiniCPMVVideoEmbeddingItems(MiniCPMVEmbeddingItems):
def __init__(self, data: Dict) -> None:
super().__init__(data, "video")
video_embeds = self.data.get("video_embeds", None)
image_sizes = self.data.get("image_sizes", None)
num_frames = self.data.get("num_frames", None)
if video_embeds is None:
raise ValueError("In correct type of video_embeds",
"Got type: None")
if not isinstance(video_embeds[0], torch.Tensor):
raise ValueError("In correct type of video_embeds",
f"Got type: {type(video_embeds[0])}")
if image_sizes is None:
raise ValueError(
"In correct type of image_sizes", "Got type: None."
"If you're using `image_size_list`, "
"please rename it to `image_sizes`")
if num_frames is None:
raise ValueError("In correct type of numframes", "Got type: None")
if len(video_embeds[0].shape) == 2:
video_embeds = [video_embeds]
image_sizes = [image_sizes]
num_frames = [num_frames]
self.data["video_embeds"] = video_embeds
self.data["image_sizes"] = image_sizes
self.data["num_frames"] = num_frames
def get_frame_size(self, index: int) -> ImageSize:
frame_size = self.data["image_sizes"][index]
return ImageSize(width=frame_size[0], height=frame_size[1])
def get_num_frames(self, index: int) -> int:
return self.data["num_frames"][index]
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
@ -311,6 +226,71 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
return tuple(int(x) for x in version_str.split("."))
def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
image_num_slices = hf_inputs.get("image_num_slices", torch.empty(0))
video_num_slices = hf_inputs.get("video_num_slices", torch.empty(0))
return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_slices),
image_sizes=MultiModalFieldConfig.batched("image"),
tgt_sizes=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_slices),
image_num_slices=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_slices),
video_pixel_values=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_slices),
video_image_sizes=MultiModalFieldConfig.batched("video"),
video_tgt_sizes=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_slices),
video_embeds=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_slices),
video_num_slices=MultiModalFieldConfig.batched("video"),
)
class MiniCPMVImageEmbeddingItems(DictEmbeddingItems):
def __init__(
self,
data: Mapping[str, torch.Tensor],
fields_config: Mapping[str, MultiModalFieldConfig],
) -> None:
super().__init__(
data,
modality="image",
fields_config=fields_config,
required_fields={"image_embeds", "image_sizes"},
)
def get_image_size(self, index: int) -> ImageSize:
image_size = self.get(index)["image_sizes"].tolist()
return ImageSize(width=image_size[0], height=image_size[1])
class MiniCPMVVideoEmbeddingItems(DictEmbeddingItems):
def __init__(
self,
data: Mapping[str, torch.Tensor],
fields_config: Mapping[str, MultiModalFieldConfig],
) -> None:
super().__init__(
data,
modality="video",
fields_config=fields_config,
required_fields={"video_embeds", "video_image_sizes"},
)
def get_frame_size(self, index: int) -> ImageSize:
frame_size = self.get(index)["video_image_sizes"].tolist()
return ImageSize(width=frame_size[0], height=frame_size[1])
def get_num_frames(self, index: int) -> int:
return len(self.get(index)["video_image_sizes"])
class MiniCPMVMultiModalDataParser(MultiModalDataParser):
def _parse_image_data(
@ -318,7 +298,11 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict):
return MiniCPMVImageEmbeddingItems(data)
return MiniCPMVImageEmbeddingItems(
data,
fields_config=_minicpmv_field_config(data),
)
return super()._parse_image_data(data)
def _parse_video_data(
@ -326,7 +310,11 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict):
return MiniCPMVVideoEmbeddingItems(data)
return MiniCPMVVideoEmbeddingItems(
data,
fields_config=_minicpmv_field_config(data),
)
return super()._parse_video_data(data)
@ -392,10 +380,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
return self.get_max_video_frame_tokens(
) * self.get_num_frames_with_most_features(seq_len)
def get_max_audio_tokens(self) -> int:
return self.get_max_audio_tokens_per_chunk(
) * self.get_max_audio_chunks_with_most_features()
def get_slice_query_num(self) -> int:
hf_config = self.get_hf_config()
query_num = getattr(hf_config, "query_num", 64)
@ -476,8 +460,12 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
return ImageSize(width=image_size, height=image_size * num_slices)
class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[MiniCPMVProcessingInfo]
):
_I = TypeVar("_I",
bound=MiniCPMVProcessingInfo,
default=MiniCPMVProcessingInfo)
class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
def get_dummy_processor_inputs(
self,
@ -514,8 +502,7 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[MiniCPMVProcessingInfo]
mm_data=mm_data)
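
The _I TypeVar introduced above uses typing_extensions' support for TypeVar defaults (PEP 696), so existing code that refers to the unparameterized class keeps its old meaning; a toy standalone sketch of the pattern (the class names are made up):

    from typing import Generic

    from typing_extensions import TypeVar

    class Info: ...
    class AudioInfo(Info): ...

    _I = TypeVar("_I", bound=Info, default=Info)

    class Builder(Generic[_I]): ...

    # Type checkers treat a bare `Builder` as `Builder[Info]`, while subclasses
    # and callers can still specialize it explicitly as `Builder[AudioInfo]`.
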
class MiniCPMVMultiModalProcessor(
BaseMultiModalProcessor[MiniCPMVProcessingInfo]):
class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
def _get_data_parser(self) -> MultiModalDataParser:
return MiniCPMVMultiModalDataParser()
@ -675,7 +662,7 @@ class MiniCPMVMultiModalProcessor(
self.info.get_video_max_slice_num()
) * inputs[modality]["num_frames"][index]
else:
raise ValueError(f"UnExpected modality: {modality}")
raise ValueError(f"Unexpected modality: {modality}")
def check_mm_inputs(self, inputs: Dict[str, object],
matches: List[str]) -> None:
@ -700,7 +687,7 @@ class MiniCPMVMultiModalProcessor(
inputs["video"]["video_image_sizes"][index],
inputs["video"]["num_frames"][index])
else:
raise ValueError(f"UnExpected modality: {modality}")
raise ValueError(f"Unexpected modality: {modality}")
def call_base_hf_processor(
self,
@ -742,6 +729,14 @@ class MiniCPMVMultiModalProcessor(
}
}
def _hf_processor_applies_repl(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> bool:
return False
def _get_prompt_replacements(
self, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
@ -770,28 +765,10 @@ class MiniCPMVMultiModalProcessor(
def _get_mm_fields_config(
self,
hf_inputs,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
image_num_slices = hf_inputs.get("image_num_slices", torch.empty(0))
video_num_slices = hf_inputs.get("video_num_slices", torch.empty(0))
return dict(pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_slices),
image_sizes=MultiModalFieldConfig.batched("image"),
tgt_sizes=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_slices),
image_num_slices=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_slices),
video_pixel_values=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_slices),
video_image_sizes=MultiModalFieldConfig.batched("video"),
video_tgt_sizes=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_slices),
video_embeds=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_slices),
video_num_slices=MultiModalFieldConfig.batched("video"))
return _minicpmv_field_config(hf_inputs)
def apply(
self,

View File

@ -243,16 +243,6 @@ class Qwen2AudioMultiModalProcessor(
)
]
def _always_apply_prompt_replacements(self) -> bool:
# Qwen2-Audio processor will start inserting placeholder tokens
# in an upcoming release:
# https://github.com/huggingface/transformers/pull/35534
# NOTE: `_find_placeholders_by_modality` may incorrectly think that HF
# has already performed processing for multi-audio input when the input
# audios are short (the corresponding placeholders may take up fewer
# tokens than the number of audio items)
return not hasattr(self.info.get_hf_processor(), "audio_token")
@MULTIMODAL_REGISTRY.register_processor(
Qwen2AudioMultiModalProcessor,

View File

@ -58,8 +58,9 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalFieldConfig, MultiModalKwargs,
VideoItem)
from vllm.multimodal.parse import (ImageSize, ModalityDataItems,
MultiModalDataItems, MultiModalDataParser)
from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize,
ModalityDataItems, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
@ -657,49 +658,25 @@ class Qwen2VisionTransformer(nn.Module):
return loaded_params
class Qwen2VLEmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
dict[str, torch.Tensor]]):
def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
image_grid_sizes = image_grid_thw.prod(-1)
def __init__(self, data: dict, modality: str) -> None:
super().__init__(data, modality)
video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
video_grid_sizes = video_grid_thw.prod(-1)
grid_thw = data[f"{modality}_grid_thw"]
slice_idxs = [0] + grid_thw.prod(-1).cumsum_(0).tolist()
self._slices = [
slice(slice_idxs[i], slice_idxs[i + 1])
for i in range(len(grid_thw))
]
def get_count(self) -> int:
return len(self.data[f"{self.modality}_grid_thw"])
def get(self, index: int) -> dict[str, torch.Tensor]:
out = {}
for k, v in self.data.items():
if v != f"{self.modality}_grid_thw":
v = v[self._slices[index]]
out[k] = v
return out
def get_processor_data(self) -> Mapping[str, object]:
return {}
def get_passthrough_data(self) -> Mapping[str, object]:
return self.data
class Qwen2VLImageEmbeddingItems(Qwen2VLEmbeddingItems):
def __init__(self, data: dict) -> None:
super().__init__(data, "image")
class Qwen2VLVideoEmbeddingItems(Qwen2VLEmbeddingItems):
def __init__(self, data: dict) -> None:
super().__init__(data, "video")
return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes),
image_embeds=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes),
image_grid_thw=MultiModalFieldConfig.batched("image"),
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes),
video_embeds=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes),
video_grid_thw=MultiModalFieldConfig.batched("video"),
)
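
The flat_from_sizes entries above split one flattened tensor back into per-item chunks according to each item's grid size; a small standalone illustration of how those sizes are derived (the grid values are made up):

    import torch

    image_grid_thw = torch.tensor([[1, 16, 16],   # image 0: 1 x 16 x 16 patches
                                   [1, 8, 12]])   # image 1: 1 x 8 x 12 patches
    image_grid_sizes = image_grid_thw.prod(-1)    # tensor([256, 96])

    # flat_from_sizes("image", image_grid_sizes) assigns rows 0:256 of the
    # flattened pixel_values / image_embeds tensor to image 0 and rows 256:352
    # to image 1, while batched("image") keeps one grid_thw row per image.
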
class Qwen2VLMultiModalDataParser(MultiModalDataParser):
@ -709,7 +686,12 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict):
return Qwen2VLEmbeddingItems(data, modality="image")
return DictEmbeddingItems(
data,
modality="image",
fields_config=_qwen2vl_field_config(data),
required_fields={"image_embeds", "image_grid_thw"},
)
return super()._parse_image_data(data)
@ -718,7 +700,12 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict):
return Qwen2VLEmbeddingItems(data, modality="video")
return DictEmbeddingItems(
data,
modality="video",
fields_config=_qwen2vl_field_config(data),
required_fields={"video_embeds", "video_grid_thw"},
)
return super()._parse_video_data(data)
@ -999,24 +986,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
image_grid_sizes = image_grid_thw.prod(-1)
video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
video_grid_sizes = video_grid_thw.prod(-1)
return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes),
image_embeds=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes),
image_grid_thw=MultiModalFieldConfig.batched("image"),
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes),
video_embeds=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes),
video_grid_thw=MultiModalFieldConfig.batched("video"),
)
return _qwen2vl_field_config(hf_inputs)
@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor,

View File

@ -520,10 +520,7 @@ class QwenVLProcessingInfo(BaseProcessingInfo):
return _get_tokenizer_without_image_pad(tokenizer)
def get_hf_processor(self) -> QwenVLProcessor:
tokenizer = self.ctx.tokenizer
assert isinstance(tokenizer, PreTrainedTokenizer)
return QwenVLProcessor(self.get_hf_config(), tokenizer)
return QwenVLProcessor(self.get_hf_config(), self.get_tokenizer())
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
@ -605,6 +602,14 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
mm_kwargs=mm_kwargs,
)
def _hf_processor_applies_repl(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> bool:
return False
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,

View File

@ -9,13 +9,15 @@ from typing import (TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar,
import numpy as np
import torch
from PIL.Image import Image
from transformers import BatchFeature
from typing_extensions import TypeAlias, TypeGuard, assert_never
from vllm.utils import is_list_of
from .audio import resample_audio
from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem,
ImageItem, ModalityData, MultiModalDataDict, VideoItem)
ImageItem, ModalityData, MultiModalDataDict,
MultiModalFieldConfig, MultiModalKwargs, VideoItem)
_T = TypeVar("_T")
_I = TypeVar("_I")
@ -111,6 +113,60 @@ class EmbeddingItems(ModalityDataItems[Union[torch.Tensor, list[torch.Tensor]],
return len(self.get(item_idx))
class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor],
Mapping[str, torch.Tensor]]):
"""
Base class for data items that are expressed as a dictionary of tensors.
Usually, the dictionary keys correspond to the outputs of HF processor.
"""
def __init__(
self,
data: Mapping[str, torch.Tensor],
modality: str,
fields_config: Mapping[str, MultiModalFieldConfig],
required_fields: set[str],
) -> None:
super().__init__(data, modality)
missing_required_fields = required_fields - fields_config.keys()
if missing_required_fields:
fields = set(fields_config.keys())
msg = f"{required_fields=} should be a subset of {fields=}"
raise ValueError(msg)
missing_required_data_keys = required_fields - data.keys()
if missing_required_data_keys:
data_keys = set(data.keys())
msg = (f"The data should contain the fields: {required_fields}, "
f"but only found the following keys: {data_keys}")
raise ValueError(msg)
self.fields_config = fields_config
self.required_fields = required_fields
self._kwargs = MultiModalKwargs.from_hf_inputs(
BatchFeature(dict(data)),
fields_config,
)
def get_count(self) -> int:
return self._kwargs.get_item_count(self.modality)
def get(self, index: int) -> Mapping[str, torch.Tensor]:
return {
k: v.data
for k, v in self._kwargs.get_item(self.modality, index).items()
}
def get_processor_data(self) -> Mapping[str, object]:
return {}
def get_passthrough_data(self) -> Mapping[str, object]:
return self.data
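
A minimal usage sketch for the new class (the field names, shapes, and field configuration below are illustrative assumptions rather than any particular model's layout):

    import torch

    from vllm.multimodal.inputs import MultiModalFieldConfig
    from vllm.multimodal.parse import DictEmbeddingItems

    data = {
        "image_embeds": torch.zeros(2, 16, 64),  # 2 images, 16 patches, dim 64
        "image_sizes": torch.tensor([[224, 224], [448, 224]]),
    }
    items = DictEmbeddingItems(
        data,
        modality="image",
        fields_config=dict(
            image_embeds=MultiModalFieldConfig.batched("image"),
            image_sizes=MultiModalFieldConfig.batched("image"),
        ),
        required_fields={"image_embeds", "image_sizes"},
    )
    assert items.get_count() == 2
    assert items.get(1)["image_embeds"].shape == (16, 64)
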
class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]):
def __init__(self, data: Sequence[HfAudioItem]) -> None:

View File

@ -23,7 +23,8 @@ from .hasher import MultiModalHasher
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs,
MultiModalKwargsItem, PlaceholderRange)
from .parse import MultiModalDataItems, MultiModalDataParser
from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
MultiModalDataParser)
if TYPE_CHECKING:
from .profiling import BaseDummyInputsBuilder
@ -830,15 +831,34 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_kwargs,
)
def _hf_processor_applies_repl(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> bool:
"""
Return whether the HF processor applies prompt replacements.
For most HF processors, this should be :code:`True` when multi-modal
data items are passed, but :code:`False` when multi-modal embeddings
are passed.
"""
return not any(
isinstance(items, (EmbeddingItems, DictEmbeddingItems))
for items in mm_items.values())
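
As a usage sketch, a processor whose HF counterpart never inserts placeholder tokens pins this hook to False, mirroring the GLM4V and Qwen-VL overrides elsewhere in this commit (the class names below are hypothetical):

    from collections.abc import Mapping

    from vllm.multimodal.parse import MultiModalDataItems
    from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                            BaseProcessingInfo)

    class MyProcessingInfo(BaseProcessingInfo):  # hypothetical model info
        ...

    class MyMultiModalProcessor(BaseMultiModalProcessor[MyProcessingInfo]):

        def _hf_processor_applies_repl(
            self,
            prompt_text: str,
            mm_items: MultiModalDataItems,
            hf_processor_mm_kwargs: Mapping[str, object],
        ) -> bool:
            # This HF processor leaves "<image>" untouched, so vLLM has to
            # insert the placeholder tokens itself.
            return False
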
def _apply_hf_processor_text_mm(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs]:
) -> tuple[list[int], MultiModalKwargs, bool]:
"""
Apply the HF processor on the prompt text and multi-modal data
together.
In addition, return whether prompt replacements have been applied.
"""
processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
@ -856,7 +876,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
)
return prompt_ids, mm_kwargs
is_repl_applied = self._hf_processor_applies_repl(
prompt_text=prompt_text,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
return prompt_ids, mm_kwargs, is_repl_applied
def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]:
"""
@ -866,7 +892,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
correspond to each other, we create dummy multi-modal items
to go along with the text.
"""
prompt_ids, _ = self._apply_hf_processor_text_mm(
prompt_ids, _, _ = self._apply_hf_processor_text_mm(
prompt_text=prompt_text,
mm_items=MultiModalDataItems({}),
hf_processor_mm_kwargs={},
@ -908,7 +934,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_counts,
)
_, mm_kwargs = self._apply_hf_processor_text_mm(
_, mm_kwargs, _ = self._apply_hf_processor_text_mm(
prompt_text=dummy_inputs.prompt_text,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
@ -923,13 +949,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
hf_processor_mm_kwargs: Mapping[str, object],
*,
enable_hf_prompt_replacement: bool,
) -> tuple[list[int], MultiModalKwargs]:
) -> tuple[list[int], MultiModalKwargs, bool]:
"""
Apply the HF processor on the prompt text and multi-modal data.
In addition, return whether prompt replacements have been applied
(for most HF processors, this should be :code:`True`).
Note:
If :code:`enable_hf_prompt_replacement=False`, the prompt should
correspond to the multi-modal items.
If :code:`enable_hf_prompt_replacement=False`, we use HF processor
to perform prompt replacement if available; HF processor requires
that the prompt corresponds to multi-modal items.
"""
if isinstance(prompt, str):
if enable_hf_prompt_replacement:
@ -943,19 +973,19 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
else:
prompt_ids = self._apply_hf_processor_tokens_only(prompt)
mm_missing_kwargs = self._apply_hf_processor_mm_only(
mm_kwargs = self._apply_hf_processor_mm_only(
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
return prompt_ids, mm_missing_kwargs
return prompt_ids, mm_kwargs, False
def _cached_apply_hf_processor(
self,
prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs]:
) -> tuple[list[int], MultiModalKwargs, bool]:
"""
Apply the HF processor on the full prompt text,
caching the results and reusing cached results.
@ -992,8 +1022,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_missing_data_items = self._to_mm_items(mm_missing_data)
# NOTE: `prompt` does not correspond to `mm_missing_data_items`,
# so we need to pass `enable_hf_prompt_replacement=False`
prompt_ids, mm_missing_kwargs = self._apply_hf_processor_main(
# so we can't apply prompt replacements until the new multimodal
# items are combined with the cached multimodal items
(
prompt_ids,
mm_missing_kwargs,
is_repl_applied,
) = self._apply_hf_processor_main(
prompt=prompt,
mm_items=mm_missing_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
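
As a concrete (hypothetical) illustration of the note above: if a prompt references three images and two of them are served from the processor cache, only the one uncached image reaches the HF processor in this call, so any placeholder expansion it performed would not line up with a prompt that mentions three images; the replacements therefore have to be reconciled only after the cached and newly processed items are merged.
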
@ -1036,7 +1071,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)
return prompt_ids, mm_kwargs
return prompt_ids, mm_kwargs, is_repl_applied
def _bind_and_group_repls(
self,
@ -1047,18 +1082,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
it = (prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls)
return dict(full_groupby_modality(it))
def _always_apply_prompt_replacements(self) -> bool:
"""
A flag which can be overridden so that
:meth:`_apply_prompt_replacements` is always called even if we
detect that HF has performed processing via
:meth:`_find_placeholders_by_modality`.
This is useful in cases where :meth:`_find_placeholders_by_modality`
cannot be reliably used to detect whether HF has performed processing.
"""
return False
def _apply_prompt_replacements(
self,
token_ids: list[int],
@ -1155,29 +1178,21 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self,
mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
mm_item_counts: Mapping[str, int],
*,
allow_missing: bool = False,
) -> Mapping[str, int]:
missing_repl_counts = dict[str, int]()
) -> None:
for modality, item_count in mm_item_counts.items():
placeholders = mm_placeholders.get(modality, [])
if len(placeholders) != item_count and not allow_missing:
if len(placeholders) != item_count:
raise RuntimeError(
f"Expected there to be {item_count} prompt replacements "
f"corresponding to {item_count} {modality} items, but only "
f"found {len(placeholders)} prompt replacements! Either "
"the prompt text has missing/incorrect tokens for "
f"corresponding to {item_count} {modality} items, but "
f"instead found {len(placeholders)} prompt replacements! "
"Either the prompt text has missing/incorrect tokens for "
"multi-modal inputs, or there is a problem with your "
"implementation of merged multi-modal processor for this "
"model (usually arising from an inconsistency between "
"`_call_hf_processor` and `_get_prompt_replacements`).")
missing_repl_counts[modality] = item_count - len(placeholders)
return missing_repl_counts
def apply(
self,
prompt: Union[str, list[int]],
@ -1217,7 +1232,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
else:
mm_hashes = None
prompt_ids, mm_kwargs = self._cached_apply_hf_processor(
(
prompt_ids,
mm_kwargs,
is_repl_applied,
) = self._cached_apply_hf_processor(
prompt,
mm_items,
hf_processor_mm_kwargs,
@ -1233,52 +1252,27 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_item_counts = mm_items.get_all_counts()
self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
hf_mm_placeholders = self._find_mm_placeholders(
mm_prompt_repls,
prompt_ids,
mm_item_counts,
)
if self._always_apply_prompt_replacements():
mm_missing_repl_counts = mm_item_counts
mm_missing_repls = dict(mm_prompt_repls)
else:
mm_missing_repl_counts = self._validate_mm_placeholders(
hf_mm_placeholders,
if is_repl_applied:
mm_placeholders = self._find_mm_placeholders(
mm_prompt_repls,
prompt_ids,
mm_item_counts,
allow_missing=True,
)
self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
mm_missing_repls = dict[str, list[BoundPromptReplacement]]()
for modality, missing_repl_count in mm_missing_repl_counts.items():
if missing_repl_count == 0:
mm_missing_repls[modality] = []
elif missing_repl_count == mm_item_counts.get(modality, 0):
mm_missing_repls[modality] = mm_prompt_repls[modality]
else:
raise ValueError("Partial prompt replacement within "
f"{modality=} is not supported")
# If HF processor already inserts placeholder tokens,
# there is no need for us to insert them
if all(len(repls) == 0 for repls in mm_missing_repls.values()):
tokenizer = self.info.get_tokenizer()
prompt = decode_tokens(tokenizer, prompt_ids)
mm_placeholders = hf_mm_placeholders
else:
(
prompt_ids,
prompt,
missing_mm_placeholders,
mm_placeholders,
) = self._apply_prompt_replacements(
prompt_ids,
mm_missing_repls,
mm_missing_repl_counts,
mm_prompt_repls,
mm_item_counts,
)
mm_placeholders = {**hf_mm_placeholders, **missing_mm_placeholders}
self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
mm_placeholder_ranges = {
modality: [item.to_range() for item in placeholders]