[VLM] Keep track of whether prompt replacements have been applied (#13215)
This commit is contained in: parent 556ef7f714, commit 4da1f667e9
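The hunks below span the model processors (GLM-4V, Pixtral, Mantis/LLaVA, LLaVA-OneVision, MiniCPM-O/V, Qwen2-Audio, Qwen2-VL, Qwen-VL) and the shared multimodal processing framework. The core idea: instead of the old `_always_apply_prompt_replacements()` heuristic, the HF-processor helpers now return a third value, `is_repl_applied`, reporting whether the HF processor already inserted the multi-modal placeholder tokens, and `apply()` branches on that flag. The snippet below is a simplified, self-contained sketch of that control flow; `ToyProcessor` and its helpers are illustrative stand-ins, not vLLM classes.

# Simplified sketch of the control flow introduced by this commit.
# "ToyProcessor" and its helpers are toy stand-ins, not vLLM classes.

class ToyProcessor:
    def _hf_processor_applies_repl(self, prompt_text, mm_items,
                                   hf_processor_mm_kwargs) -> bool:
        # Default rule mirroring the base class: assume HF inserted the
        # placeholders unless precomputed embeddings were passed in.
        return not any(item.get("is_embedding") for item in mm_items)

    def _cached_apply_hf_processor(self, prompt, mm_items, mm_kwargs):
        prompt_ids = [101, 32000, 102]  # pretend tokenized HF output
        is_repl_applied = self._hf_processor_applies_repl(
            prompt, mm_items, mm_kwargs)
        return prompt_ids, {"pixel_values": None}, is_repl_applied

    def _apply_prompt_replacements(self, prompt_ids, mm_items):
        new_ids = prompt_ids + [32000] * len(mm_items)
        placeholders = list(range(len(prompt_ids), len(new_ids)))
        return new_ids, placeholders

    def apply(self, prompt, mm_items, mm_kwargs):
        prompt_ids, out_kwargs, is_repl_applied = \
            self._cached_apply_hf_processor(prompt, mm_items, mm_kwargs)

        if is_repl_applied:
            # HF already inserted placeholder tokens: just locate and
            # validate them.
            placeholders = [i for i, t in enumerate(prompt_ids) if t == 32000]
        else:
            # HF did not insert them (e.g. embeddings were passed), so we
            # apply our own prompt replacements.
            prompt_ids, placeholders = self._apply_prompt_replacements(
                prompt_ids, mm_items)
        return prompt_ids, placeholders


if __name__ == "__main__":
    proc = ToyProcessor()
    print(proc.apply("<image> hi", [{"is_embedding": False}], {}))
    print(proc.apply("<image> hi", [{"is_embedding": True}], {}))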
@@ -484,6 +484,14 @@ class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):
 
 class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):
 
+    def _hf_processor_applies_repl(
+        self,
+        prompt_text: str,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> bool:
+        return False
+
     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
@@ -294,7 +294,7 @@ class PixtralHFMultiModalProcessor(
         pixel_values = processed_outputs.get("pixel_values")
         if pixel_values is not None:
             # Before/after https://github.com/huggingface/transformers/pull/35122
-            if Version(TRANSFORMERS_VERSION) <= Version("4.48.2"):
+            if Version(TRANSFORMERS_VERSION) <= Version("4.48.3"):
                 images = mm_data["images"]
                 assert isinstance(images, list)
 
@@ -819,7 +819,6 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
             prompt_ids,
             mm_item_counts,
         )
-
         self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
 
         mm_placeholder_ranges = {
@@ -299,36 +299,69 @@ class LlavaOnevisionMultiModalProcessor(
                 mm_kwargs=mm_kwargs,
             )
 
-        processor = self.info.get_hf_processor()
-        video_token = processor.video_token
-
         # LLaVA-OneVision processor doesn't support multiple videos
         # with different sizes when converting back to tensors
-        text_image_outputs = super()._call_hf_processor(
+        # So, we process each component separately
+        # NOTE: No prompt replacement is applied in this case
+        processor = self.info.get_hf_processor()
+        image_token = processor.image_token
+        video_token = processor.video_token
+
+        text_outputs = super()._call_hf_processor(
             prompt=prompt,
-            mm_data=mm_data,
+            mm_data={},
             mm_kwargs=mm_kwargs,
         )
 
+        images = mm_data.pop("images", [])
+        assert isinstance(images, list)
+        if images:
+            processor_outputs = super()._call_hf_processor(
+                prompt=image_token * len(images),
+                mm_data={"images": images},
+                mm_kwargs=mm_kwargs,
+            )
+            image_outputs = {
+                k: v
+                for k, v in processor_outputs.items()
+                if k in ("pixel_values", "image_sizes")
+            }
+        else:
+            image_outputs = {}
+
         pixel_values_videos = []
         for video in videos:
-            item_processor_data = dict(prompt=video_token, videos=video)
-
             item_outputs = super()._call_hf_processor(
-                prompt=prompt,
-                mm_data=item_processor_data,
+                prompt=video_token,
+                mm_data={"videos": video},
                 mm_kwargs=mm_kwargs,
            )
 
-            pixel_values_videos.append(
-                item_outputs.pop("pixel_values_videos")[0])
+            pixel_values_videos.append(item_outputs["pixel_values_videos"][0])
+
+        video_outputs = {"pixel_values_videos": pixel_values_videos}
 
         combined_outputs = dict(
-            **text_image_outputs,
-            pixel_values_videos=pixel_values_videos,
+            **text_outputs,
+            **image_outputs,
+            **video_outputs,
         )
         return BatchFeature(combined_outputs)
 
+    def _hf_processor_applies_repl(
+        self,
+        prompt_text: str,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> bool:
+        base_result = super()._hf_processor_applies_repl(
+            prompt_text=prompt_text,
+            mm_items=mm_items,
+            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+        )
+
+        return base_result and mm_items.get_count("video", strict=False) == 0
+
     def _get_prompt_replacements(
         self,
         mm_items: MultiModalDataItems,
@@ -27,8 +27,8 @@ from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
                     Tuple, TypedDict, Union)
 
 import torch
-import torch.types
 from torch import nn
+from transformers import BatchFeature
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.whisper.modeling_whisper import (
     ACT2FN, WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder)
@@ -37,23 +37,21 @@ from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.inputs import MultiModalFieldConfig
-from vllm.multimodal.parse import (ModalityData, ModalityDataItems,
-                                   MultiModalDataItems, MultiModalDataParser,
-                                   VideoItem)
-from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        PromptReplacement)
+from vllm.multimodal.parse import (AudioItem, DictEmbeddingItems, ModalityData,
+                                   ModalityDataItems, MultiModalDataItems,
+                                   MultiModalDataParser)
+from vllm.multimodal.processing import PromptReplacement
 from vllm.multimodal.profiling import ProcessorInputs
 from vllm.sequence import IntermediateTensors
 
 from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
-                       MiniCPMVEmbeddingItems, MiniCPMVMultiModalDataParser,
-                       MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo)
+                       MiniCPMVMultiModalDataParser,
+                       MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo,
+                       _minicpmv_field_config)
 from .utils import AutoWeightsLoader, maybe_prefix
 
 CPU_DEVICE = torch.device("cpu")
 
-MiniCPMOEmbeddingItems = MiniCPMVEmbeddingItems
-
 
 class MiniCPMOAudioFeatureInputs(TypedDict):
     type: Literal["audio_features"]
@@ -103,28 +101,49 @@ MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,
                             MiniCPMOAudioEmbeddingInputs]
 
 
-class MiniCPMOAudioEmbeddingItems(MiniCPMOEmbeddingItems):
+def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
+    audio_num_slices = hf_inputs.get("audio_num_slices", torch.empty(0))
 
-    def __init__(self, data: Dict) -> None:
-        super().__init__(data, "audio")
-        audio_embeds = self.data.get("audio_embeds", None)
-        if audio_embeds is None:
-            raise ValueError("Incorrect type of video_embeds",
-                             "Got type: None")
-        self.data["audio_embeds"] = audio_embeds
+    return dict(
+        **_minicpmv_field_config(hf_inputs),
+        audio_features=MultiModalFieldConfig.flat_from_sizes(
+            "audio", audio_num_slices),
+        audio_feature_lens=MultiModalFieldConfig.flat_from_sizes(
+            "audio", audio_num_slices),
+        audio_num_slices=MultiModalFieldConfig.batched("audio"),
+        audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio"),
+        audio_embeds=MultiModalFieldConfig.flat_from_sizes(
+            "audio", audio_num_slices),
+    )
 
-    def get(self, index: int) -> object:
-        return self.data["audio_embeds"][index]
+
+class MiniCPMOAudioEmbeddingItems(DictEmbeddingItems):
+
+    def __init__(
+        self,
+        data: Mapping[str, torch.Tensor],
+        fields_config: Mapping[str, MultiModalFieldConfig],
+    ) -> None:
+        super().__init__(
+            data,
+            modality="image",
+            fields_config=fields_config,
+            required_fields={"audio_embeds"},
+        )
 
 
 class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):
 
     def _parse_audio_data(
         self,
-        data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
+        data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]],
     ) -> ModalityDataItems[Any, Any]:
         if isinstance(data, dict):
-            return MiniCPMOAudioEmbeddingItems(data)
+            return MiniCPMOAudioEmbeddingItems(
+                data,
+                fields_config=_minicpmo_field_config(data),
+            )
 
         return super()._parse_audio_data(data)
 
@@ -167,6 +186,10 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
     def get_max_audio_chunks_with_most_features(self) -> int:
         return 30
 
+    def get_max_audio_tokens(self) -> int:
+        return self.get_max_audio_tokens_per_chunk(
+        ) * self.get_max_audio_chunks_with_most_features()
+
     def get_audio_len_by_num_chunks(self, num_chunks: int) -> int:
         sampling_rate = self.get_default_audio_sampling_rate()
         # exclude <audio> </audio>
@@ -194,7 +217,8 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
         return num_frames
 
 
-class MiniCPMODummyInputsBuilder(MiniCPMVDummyInputsBuilder):
+class MiniCPMODummyInputsBuilder(
+        MiniCPMVDummyInputsBuilder[MiniCPMOProcessingInfo]):
 
     def get_dummy_processor_inputs(
         self, seq_len: int, mm_counts: Mapping[str,
@@ -222,8 +246,7 @@ class MiniCPMODummyInputsBuilder(MiniCPMVDummyInputsBuilder):
 
 
 class MiniCPMOMultiModalProcessor(
-        MiniCPMVMultiModalProcessor,
-        BaseMultiModalProcessor[MiniCPMOProcessingInfo]):
+        MiniCPMVMultiModalProcessor[MiniCPMOProcessingInfo]):
 
     def _get_data_parser(self) -> MultiModalDataParser:
         return MiniCPMOMultiModalDataParser(
@@ -369,21 +392,10 @@ class MiniCPMOMultiModalProcessor(
 
     def _get_mm_fields_config(
         self,
-        hf_inputs,
+        hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        audio_num_slices = hf_inputs.get("audio_num_slices", torch.empty(0))
-
-        return dict(
-            **super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs),
-            audio_features=MultiModalFieldConfig.flat_from_sizes(
-                "audio", audio_num_slices),
-            audio_feature_lens=MultiModalFieldConfig.flat_from_sizes(
-                "audio", audio_num_slices),
-            audio_num_slices=MultiModalFieldConfig.batched("audio"),
-            audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio"),
-            audio_embeds=MultiModalFieldConfig.flat_from_sizes(
-                "audio", audio_num_slices))
+        return _minicpmo_field_config(hf_inputs)
 
 
 class MultiModalProjector(nn.Module):
@@ -406,7 +418,7 @@ class MultiModalProjector(nn.Module):
 
 class MiniCPMWhisperEncoderLayer(nn.Module):
 
-    def __init__(self, config: WhisperConfig, layer_idx: int = None):
+    def __init__(self, config: WhisperConfig, layer_idx: int):
         super().__init__()
         self.embed_dim = config.d_model
         self.self_attn = WHISPER_ATTENTION_CLASSES[
@@ -35,6 +35,7 @@ import torch.types
 from PIL import Image
 from torch import nn
 from transformers import BatchFeature, PretrainedConfig
+from typing_extensions import TypeVar
 
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
@@ -51,9 +52,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalInputs, PlaceholderRange)
-from vllm.multimodal.parse import (ImageItem, ImageSize, ModalityData,
-                                   ModalityDataItems, MultiModalDataItems,
-                                   MultiModalDataParser, VideoItem)
+from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, ImageSize,
+                                   ModalityData, ModalityDataItems,
+                                   MultiModalDataItems, MultiModalDataParser,
+                                   VideoItem)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
@@ -115,93 +117,6 @@ class MiniCPMVImageEmbeddingInputs(TypedDict):
 MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
                             MiniCPMVImageEmbeddingInputs]
 
 
-class MiniCPMVEmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
-                                               dict[str, torch.Tensor]]):
-
-    def __init__(self, data: Dict, modality: str) -> None:
-        super().__init__(data, modality)
-
-    def get_processor_data(self) -> Mapping[str, object]:
-        return self.data
-
-    def get_passthrough_data(self) -> Mapping[str, object]:
-        return {}
-
-    def get_count(self) -> int:
-        return len(self.data[f"{self.modality}_embeds"])
-
-    def get(self, index: int) -> Dict[str, torch.Tensor]:
-        out = {}
-        for k, v in self.data.items():
-            out[k] = v[index]
-        return out
-
-
-class MiniCPMVImageEmbeddingItems(MiniCPMVEmbeddingItems):
-
-    def __init__(self, data: Dict) -> None:
-        super().__init__(data, "image")
-        image_embeds = self.data.get("image_embeds", None)
-        image_sizes = self.data.get("image_sizes", None)
-        if image_embeds is None:
-            raise ValueError("In correct type of image_embeds",
-                             "Got type: None")
-        if not isinstance(image_embeds[0], torch.Tensor):
-            raise ValueError("In correct type of image_embeds",
-                             f"Got type: {type(image_embeds[0])}")
-        if image_sizes is None:
-            raise ValueError(
-                "In correct type of image_sizes", "Got type: None."
-                "If you're using `image_size_list`, "
-                "please rename it to `image_sizes`")
-        if len(image_embeds[0].shape) == 2:
-            image_embeds = [image_embeds]
-            image_sizes = [image_sizes]
-        self.data["image_embeds"] = image_embeds
-        self.data["image_sizes"] = image_sizes
-
-    def get_image_size(self, index: int) -> ImageSize:
-        image_size = self.data["image_sizes"][index]
-        return ImageSize(width=image_size[0], height=image_size[1])
-
-
-class MiniCPMVVideoEmbeddingItems(MiniCPMVEmbeddingItems):
-
-    def __init__(self, data: Dict) -> None:
-        super().__init__(data, "video")
-        video_embeds = self.data.get("video_embeds", None)
-        image_sizes = self.data.get("image_sizes", None)
-        num_frames = self.data.get("num_frames", None)
-        if video_embeds is None:
-            raise ValueError("In correct type of video_embeds",
-                             "Got type: None")
-        if not isinstance(video_embeds[0], torch.Tensor):
-            raise ValueError("In correct type of video_embeds",
-                             f"Got type: {type(video_embeds[0])}")
-        if image_sizes is None:
-            raise ValueError(
-                "In correct type of image_sizes", "Got type: None."
-                "If you're using `image_size_list`, "
-                "please rename it to `image_sizes`")
-        if num_frames is None:
-            raise ValueError("In correct type of numframes", "Got type: None")
-        if len(video_embeds[0].shape) == 2:
-            video_embeds = [video_embeds]
-            image_sizes = [image_sizes]
-            num_frames = [num_frames]
-        self.data["video_embeds"] = video_embeds
-        self.data["image_sizes"] = image_sizes
-        self.data["num_frames"] = num_frames
-
-    def get_frame_size(self, index: int) -> ImageSize:
-        frame_size = self.data["image_sizes"][index]
-        return ImageSize(width=frame_size[0], height=frame_size[1])
-
-    def get_num_frames(self, index: int) -> int:
-        return self.data["num_frames"][index]
-
-
 DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
 
@@ -311,6 +226,71 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
     return tuple(int(x) for x in version_str.split("."))
 
 
+def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
+    image_num_slices = hf_inputs.get("image_num_slices", torch.empty(0))
+    video_num_slices = hf_inputs.get("video_num_slices", torch.empty(0))
+
+    return dict(
+        pixel_values=MultiModalFieldConfig.flat_from_sizes(
+            "image", image_num_slices),
+        image_sizes=MultiModalFieldConfig.batched("image"),
+        tgt_sizes=MultiModalFieldConfig.flat_from_sizes(
+            "image", image_num_slices),
+        image_num_slices=MultiModalFieldConfig.batched("image"),
+        image_embeds=MultiModalFieldConfig.flat_from_sizes(
+            "image", image_num_slices),
+        video_pixel_values=MultiModalFieldConfig.flat_from_sizes(
+            "video", video_num_slices),
+        video_image_sizes=MultiModalFieldConfig.batched("video"),
+        video_tgt_sizes=MultiModalFieldConfig.flat_from_sizes(
+            "video", video_num_slices),
+        video_embeds=MultiModalFieldConfig.flat_from_sizes(
+            "video", video_num_slices),
+        video_num_slices=MultiModalFieldConfig.batched("video"),
+    )
+
+
+class MiniCPMVImageEmbeddingItems(DictEmbeddingItems):
+
+    def __init__(
+        self,
+        data: Mapping[str, torch.Tensor],
+        fields_config: Mapping[str, MultiModalFieldConfig],
+    ) -> None:
+        super().__init__(
+            data,
+            modality="image",
+            fields_config=fields_config,
+            required_fields={"image_embeds", "image_sizes"},
+        )
+
+    def get_image_size(self, index: int) -> ImageSize:
+        image_size = self.get(index)["image_sizes"].tolist()
+        return ImageSize(width=image_size[0], height=image_size[1])
+
+
+class MiniCPMVVideoEmbeddingItems(DictEmbeddingItems):
+
+    def __init__(
+        self,
+        data: Mapping[str, torch.Tensor],
+        fields_config: Mapping[str, MultiModalFieldConfig],
+    ) -> None:
+        super().__init__(
+            data,
+            modality="video",
+            fields_config=fields_config,
+            required_fields={"video_embeds", "video_image_sizes"},
+        )
+
+    def get_frame_size(self, index: int) -> ImageSize:
+        frame_size = self.get(index)["video_image_sizes"].tolist()
+        return ImageSize(width=frame_size[0], height=frame_size[1])
+
+    def get_num_frames(self, index: int) -> int:
+        return len(self.get(index)["video_image_sizes"])
+
+
 class MiniCPMVMultiModalDataParser(MultiModalDataParser):
 
     def _parse_image_data(
@@ -318,7 +298,11 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
         data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
     ) -> ModalityDataItems[Any, Any]:
         if isinstance(data, dict):
-            return MiniCPMVImageEmbeddingItems(data)
+            return MiniCPMVImageEmbeddingItems(
+                data,
+                fields_config=_minicpmv_field_config(data),
+            )
 
         return super()._parse_image_data(data)
 
     def _parse_video_data(
@@ -326,7 +310,11 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
         data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
     ) -> ModalityDataItems[Any, Any]:
         if isinstance(data, dict):
-            return MiniCPMVVideoEmbeddingItems(data)
+            return MiniCPMVVideoEmbeddingItems(
+                data,
+                fields_config=_minicpmv_field_config(data),
+            )
 
         return super()._parse_video_data(data)
 
@@ -392,10 +380,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
         return self.get_max_video_frame_tokens(
         ) * self.get_num_frames_with_most_features(seq_len)
 
-    def get_max_audio_tokens(self) -> int:
-        return self.get_max_audio_tokens_per_chunk(
-        ) * self.get_max_audio_chunks_with_most_features()
-
     def get_slice_query_num(self) -> int:
         hf_config = self.get_hf_config()
         query_num = getattr(hf_config, "query_num", 64)
@@ -476,8 +460,12 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
         return ImageSize(width=image_size, height=image_size * num_slices)
 
 
-class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[MiniCPMVProcessingInfo]
-                                 ):
+_I = TypeVar("_I",
+             bound=MiniCPMVProcessingInfo,
+             default=MiniCPMVProcessingInfo)
+
+
+class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
 
     def get_dummy_processor_inputs(
         self,
@@ -514,8 +502,7 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[MiniCPMVProcessingInfo]
                               mm_data=mm_data)
 
 
-class MiniCPMVMultiModalProcessor(
-        BaseMultiModalProcessor[MiniCPMVProcessingInfo]):
+class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
 
     def _get_data_parser(self) -> MultiModalDataParser:
         return MiniCPMVMultiModalDataParser()
@@ -675,7 +662,7 @@ class MiniCPMVMultiModalProcessor(
                 self.info.get_video_max_slice_num()
             ) * inputs[modality]["num_frames"][index]
         else:
-            raise ValueError(f"UnExpected modality: {modality}")
+            raise ValueError(f"Unexpected modality: {modality}")
 
     def check_mm_inputs(self, inputs: Dict[str, object],
                         matches: List[str]) -> None:
@@ -700,7 +687,7 @@ class MiniCPMVMultiModalProcessor(
                 inputs["video"]["video_image_sizes"][index],
                 inputs["video"]["num_frames"][index])
         else:
-            raise ValueError(f"UnExpected modality: {modality}")
+            raise ValueError(f"Unexpected modality: {modality}")
 
     def call_base_hf_processor(
         self,
@@ -742,6 +729,14 @@ class MiniCPMVMultiModalProcessor(
             }
         }
 
+    def _hf_processor_applies_repl(
+        self,
+        prompt_text: str,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> bool:
+        return False
+
     def _get_prompt_replacements(
         self, mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, Any],
@@ -770,28 +765,10 @@ class MiniCPMVMultiModalProcessor(
 
     def _get_mm_fields_config(
         self,
-        hf_inputs,
+        hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        image_num_slices = hf_inputs.get("image_num_slices", torch.empty(0))
-        video_num_slices = hf_inputs.get("video_num_slices", torch.empty(0))
-
-        return dict(pixel_values=MultiModalFieldConfig.flat_from_sizes(
-            "image", image_num_slices),
-                    image_sizes=MultiModalFieldConfig.batched("image"),
-                    tgt_sizes=MultiModalFieldConfig.flat_from_sizes(
-                        "image", image_num_slices),
-                    image_num_slices=MultiModalFieldConfig.batched("image"),
-                    image_embeds=MultiModalFieldConfig.flat_from_sizes(
-                        "image", image_num_slices),
-                    video_pixel_values=MultiModalFieldConfig.flat_from_sizes(
-                        "video", video_num_slices),
-                    video_image_sizes=MultiModalFieldConfig.batched("video"),
-                    video_tgt_sizes=MultiModalFieldConfig.flat_from_sizes(
-                        "video", video_num_slices),
-                    video_embeds=MultiModalFieldConfig.flat_from_sizes(
-                        "video", video_num_slices),
-                    video_num_slices=MultiModalFieldConfig.batched("video"))
+        return _minicpmv_field_config(hf_inputs)
 
     def apply(
         self,
@@ -243,16 +243,6 @@ class Qwen2AudioMultiModalProcessor(
             )
         ]
 
-    def _always_apply_prompt_replacements(self) -> bool:
-        # Qwen2-Audio processor will start inserting placeholder tokens
-        # in an upcoming release:
-        # https://github.com/huggingface/transformers/pull/35534
-        # NOTE: `_find_placeholders_by_modality` may incorrectly think that HF
-        # has already performed processing for multi-audio input when the input
-        # audios are short (the corresponding placeholders may take up fewer
-        # tokens than the number of audio items)
-        return not hasattr(self.info.get_hf_processor(), "audio_token")
-
 
 @MULTIMODAL_REGISTRY.register_processor(
     Qwen2AudioMultiModalProcessor,
@@ -58,8 +58,9 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                     MultiModalFieldConfig, MultiModalKwargs,
                                     VideoItem)
-from vllm.multimodal.parse import (ImageSize, ModalityDataItems,
-                                   MultiModalDataItems, MultiModalDataParser)
+from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize,
+                                   ModalityDataItems, MultiModalDataItems,
+                                   MultiModalDataParser)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
@@ -657,49 +658,25 @@ class Qwen2VisionTransformer(nn.Module):
         return loaded_params
 
 
-class Qwen2VLEmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
-                                              dict[str, torch.Tensor]]):
+def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
+    image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
+    image_grid_sizes = image_grid_thw.prod(-1)
 
-    def __init__(self, data: dict, modality: str) -> None:
-        super().__init__(data, modality)
+    video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
+    video_grid_sizes = video_grid_thw.prod(-1)
 
-        grid_thw = data[f"{modality}_grid_thw"]
-        slice_idxs = [0] + grid_thw.prod(-1).cumsum_(0).tolist()
-        self._slices = [
-            slice(slice_idxs[i], slice_idxs[i + 1])
-            for i in range(len(grid_thw))
-        ]
-
-    def get_count(self) -> int:
-        return len(self.data[f"{self.modality}_grid_thw"])
-
-    def get(self, index: int) -> dict[str, torch.Tensor]:
-        out = {}
-        for k, v in self.data.items():
-            if v != f"{self.modality}_grid_thw":
-                v = v[self._slices[index]]
-
-            out[k] = v
-
-        return out
-
-    def get_processor_data(self) -> Mapping[str, object]:
-        return {}
-
-    def get_passthrough_data(self) -> Mapping[str, object]:
-        return self.data
-
-
-class Qwen2VLImageEmbeddingItems(Qwen2VLEmbeddingItems):
-
-    def __init__(self, data: dict) -> None:
-        super().__init__(data, "image")
-
-
-class Qwen2VLVideoEmbeddingItems(Qwen2VLEmbeddingItems):
-
-    def __init__(self, data: dict) -> None:
-        super().__init__(data, "video")
+    return dict(
+        pixel_values=MultiModalFieldConfig.flat_from_sizes(
+            "image", image_grid_sizes),
+        image_embeds=MultiModalFieldConfig.flat_from_sizes(
+            "image", image_grid_sizes),
+        image_grid_thw=MultiModalFieldConfig.batched("image"),
+        pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
+            "video", video_grid_sizes),
+        video_embeds=MultiModalFieldConfig.flat_from_sizes(
+            "video", video_grid_sizes),
+        video_grid_thw=MultiModalFieldConfig.batched("video"),
+    )
 
 
 class Qwen2VLMultiModalDataParser(MultiModalDataParser):
@@ -709,7 +686,12 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
         data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
     ) -> ModalityDataItems[Any, Any]:
         if isinstance(data, dict):
-            return Qwen2VLEmbeddingItems(data, modality="image")
+            return DictEmbeddingItems(
+                data,
+                modality="image",
+                fields_config=_qwen2vl_field_config(data),
+                required_fields={"image_embeds", "image_grid_thw"},
+            )
 
         return super()._parse_image_data(data)
 
@@ -718,7 +700,12 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
         data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
     ) -> ModalityDataItems[Any, Any]:
         if isinstance(data, dict):
-            return Qwen2VLEmbeddingItems(data, modality="video")
+            return DictEmbeddingItems(
+                data,
+                modality="video",
+                fields_config=_qwen2vl_field_config(data),
+                required_fields={"video_embeds", "video_grid_thw"},
+            )
 
         return super()._parse_video_data(data)
 
@@ -999,24 +986,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
-        image_grid_sizes = image_grid_thw.prod(-1)
-
-        video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
-        video_grid_sizes = video_grid_thw.prod(-1)
-
-        return dict(
-            pixel_values=MultiModalFieldConfig.flat_from_sizes(
-                "image", image_grid_sizes),
-            image_embeds=MultiModalFieldConfig.flat_from_sizes(
-                "image", image_grid_sizes),
-            image_grid_thw=MultiModalFieldConfig.batched("image"),
-            pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
-                "video", video_grid_sizes),
-            video_embeds=MultiModalFieldConfig.flat_from_sizes(
-                "video", video_grid_sizes),
-            video_grid_thw=MultiModalFieldConfig.batched("video"),
-        )
+        return _qwen2vl_field_config(hf_inputs)
 
 
 @MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor,
@@ -520,10 +520,7 @@ class QwenVLProcessingInfo(BaseProcessingInfo):
         return _get_tokenizer_without_image_pad(tokenizer)
 
     def get_hf_processor(self) -> QwenVLProcessor:
-        tokenizer = self.ctx.tokenizer
-        assert isinstance(tokenizer, PreTrainedTokenizer)
-
-        return QwenVLProcessor(self.get_hf_config(), tokenizer)
+        return QwenVLProcessor(self.get_hf_config(), self.get_tokenizer())
 
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": None}
@@ -605,6 +602,14 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
             mm_kwargs=mm_kwargs,
         )
 
+    def _hf_processor_applies_repl(
+        self,
+        prompt_text: str,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> bool:
+        return False
+
     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
@@ -9,13 +9,15 @@ from typing import (TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar,
 import numpy as np
 import torch
 from PIL.Image import Image
+from transformers import BatchFeature
 from typing_extensions import TypeAlias, TypeGuard, assert_never
 
 from vllm.utils import is_list_of
 
 from .audio import resample_audio
 from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem,
-                     ImageItem, ModalityData, MultiModalDataDict, VideoItem)
+                     ImageItem, ModalityData, MultiModalDataDict,
+                     MultiModalFieldConfig, MultiModalKwargs, VideoItem)
 
 _T = TypeVar("_T")
 _I = TypeVar("_I")
@@ -111,6 +113,60 @@ class EmbeddingItems(ModalityDataItems[Union[torch.Tensor, list[torch.Tensor]],
         return len(self.get(item_idx))
 
 
+class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor],
+                                           Mapping[str, torch.Tensor]]):
+    """
+    Base class for data items that are expressed as a dictionary of tensors.
+
+    Usually, the dictionary keys correspond to the outputs of HF processor.
+    """
+
+    def __init__(
+        self,
+        data: Mapping[str, torch.Tensor],
+        modality: str,
+        fields_config: Mapping[str, MultiModalFieldConfig],
+        required_fields: set[str],
+    ) -> None:
+        super().__init__(data, modality)
+
+        missing_required_fields = required_fields - fields_config.keys()
+        if missing_required_fields:
+            fields = set(fields_config.keys())
+            msg = f"{required_fields=} should be a subset of {fields=}"
+            raise ValueError(msg)
+
+        missing_required_data_keys = required_fields - data.keys()
+        if missing_required_data_keys:
+            data_keys = set(data.keys())
+            msg = (f"The data should contain the fields: {required_fields}, "
+                   f"but only found the following keys: {data_keys}")
+            raise ValueError(msg)
+
+        self.fields_config = fields_config
+        self.required_fields = required_fields
+
+        self._kwargs = MultiModalKwargs.from_hf_inputs(
+            BatchFeature(dict(data)),
+            fields_config,
+        )
+
+    def get_count(self) -> int:
+        return self._kwargs.get_item_count(self.modality)
+
+    def get(self, index: int) -> Mapping[str, torch.Tensor]:
+        return {
+            k: v.data
+            for k, v in self._kwargs.get_item(self.modality, index).items()
+        }
+
+    def get_processor_data(self) -> Mapping[str, object]:
+        return {}
+
+    def get_passthrough_data(self) -> Mapping[str, object]:
+        return self.data
+
+
 class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]):
 
     def __init__(self, data: Sequence[HfAudioItem]) -> None:
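The `DictEmbeddingItems` class added above reuses the same `MultiModalFieldConfig` machinery as the processors to view a flat dictionary of tensors as per-item groups (`flat_from_sizes` for tensors concatenated across items, `batched` for tensors with one entry per item). The helper below is a rough, hypothetical stand-in that illustrates only the slicing idea behind `flat_from_sizes` using plain torch; it is not the vLLM implementation.

import torch


def split_flat_by_sizes(flat: torch.Tensor,
                        sizes: torch.Tensor) -> list[torch.Tensor]:
    # Illustrative stand-in for MultiModalFieldConfig.flat_from_sizes:
    # slice a tensor that concatenates all slices of all items back into
    # one chunk per multi-modal item.
    ends = torch.cumsum(sizes, dim=0).tolist()
    starts = [0] + ends[:-1]
    return [flat[s:e] for s, e in zip(starts, ends)]


# e.g. two audio items contributing 2 and 3 slices of 4-dim features each
audio_num_slices = torch.tensor([2, 3])
audio_features = torch.randn(5, 4)   # 5 = 2 + 3 slices, flattened together

per_item = split_flat_by_sizes(audio_features, audio_num_slices)
assert [t.shape[0] for t in per_item] == [2, 3]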
@@ -23,7 +23,8 @@ from .hasher import MultiModalHasher
 from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                      MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs,
                      MultiModalKwargsItem, PlaceholderRange)
-from .parse import MultiModalDataItems, MultiModalDataParser
+from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
+                    MultiModalDataParser)
 
 if TYPE_CHECKING:
     from .profiling import BaseDummyInputsBuilder
@@ -830,15 +831,34 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             mm_kwargs,
         )
 
+    def _hf_processor_applies_repl(
+        self,
+        prompt_text: str,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> bool:
+        """
+        Return whether the HF processor applies prompt replacements.
+
+        For most HF processors, this should be :code:`True` when multi-modal
+        data items are passed, but :code:`False` when multi-modal embeddings
+        are passed.
+        """
+        return not any(
+            isinstance(items, (EmbeddingItems, DictEmbeddingItems))
+            for items in mm_items.values())
+
     def _apply_hf_processor_text_mm(
         self,
         prompt_text: str,
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> tuple[list[int], MultiModalKwargs]:
+    ) -> tuple[list[int], MultiModalKwargs, bool]:
        """
         Apply the HF processor on the prompt text and multi-modal data
         together.
+
+        In addition, return whether prompt replacements have been applied.
         """
         processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
 
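As the docstring above describes, the base implementation assumes the HF processor applied the replacements unless embedding items were passed. The model files in this commit hook into that rule in two ways: GLM-4V, Qwen-VL and MiniCPM-V override `_hf_processor_applies_repl` to always return `False`, while LLaVA-OneVision additionally returns `False` whenever videos are present, since its video path in `_call_hf_processor` skips prompt replacement. A condensed, self-contained sketch of those override patterns follows; the `Sketch*` classes are toy stand-ins, not vLLM classes.

# Condensed sketch of the override patterns used in this commit.
# "SketchBase" is a stand-in for BaseMultiModalProcessor, not the real class.

class SketchBase:
    def _hf_processor_applies_repl(self, prompt_text, mm_items,
                                   hf_processor_mm_kwargs) -> bool:
        # Base rule: assume HF applied replacements unless embeddings
        # were passed in directly (mm_items maps modality -> kind here).
        return not any(kind == "embedding" for kind in mm_items.values())


class AlwaysManualProcessor(SketchBase):
    # Pattern used by GLM-4V, Qwen-VL and MiniCPM-V in this diff:
    # their HF processors never insert placeholder tokens.
    def _hf_processor_applies_repl(self, prompt_text, mm_items,
                                   hf_processor_mm_kwargs) -> bool:
        return False


class VideoAwareProcessor(SketchBase):
    # Pattern used by LLaVA-OneVision: videos are processed item by item
    # in _call_hf_processor, so no replacement happens when videos exist.
    def _hf_processor_applies_repl(self, prompt_text, mm_items,
                                   hf_processor_mm_kwargs) -> bool:
        base_result = super()._hf_processor_applies_repl(
            prompt_text, mm_items, hf_processor_mm_kwargs)
        num_videos = sum(1 for kind in mm_items.values() if kind == "video")
        return base_result and num_videos == 0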
@@ -856,7 +876,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
         )
 
-        return prompt_ids, mm_kwargs
+        is_repl_applied = self._hf_processor_applies_repl(
+            prompt_text=prompt_text,
+            mm_items=mm_items,
+            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+        )
+
+        return prompt_ids, mm_kwargs, is_repl_applied
 
     def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]:
         """
@@ -866,7 +892,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         correspond to each other, we create dummy multi-modal items
         to go along with the text.
         """
-        prompt_ids, _ = self._apply_hf_processor_text_mm(
+        prompt_ids, _, _ = self._apply_hf_processor_text_mm(
             prompt_text=prompt_text,
             mm_items=MultiModalDataItems({}),
             hf_processor_mm_kwargs={},
@@ -908,7 +934,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             mm_counts,
         )
 
-        _, mm_kwargs = self._apply_hf_processor_text_mm(
+        _, mm_kwargs, _ = self._apply_hf_processor_text_mm(
             prompt_text=dummy_inputs.prompt_text,
             mm_items=mm_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
@@ -923,13 +949,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         hf_processor_mm_kwargs: Mapping[str, object],
         *,
         enable_hf_prompt_replacement: bool,
-    ) -> tuple[list[int], MultiModalKwargs]:
+    ) -> tuple[list[int], MultiModalKwargs, bool]:
         """
         Apply the HF processor on the prompt text and multi-modal data.
 
+        In addition, return whether prompt replacements have been applied
+        (for most HF processors, this should be :code:`True`).
+
         Note:
-            If :code:`enable_hf_prompt_replacement=False`, the prompt should
-            correspond to the multi-modal items.
+            If :code:`enable_hf_prompt_replacement=False`, we use HF processor
+            to perform prompt replacement if available; HF processor requires
+            that the prompt corresponds to multi-modal items.
         """
         if isinstance(prompt, str):
             if enable_hf_prompt_replacement:
@@ -943,19 +973,19 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         else:
             prompt_ids = self._apply_hf_processor_tokens_only(prompt)
 
-        mm_missing_kwargs = self._apply_hf_processor_mm_only(
+        mm_kwargs = self._apply_hf_processor_mm_only(
             mm_items=mm_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
         )
 
-        return prompt_ids, mm_missing_kwargs
+        return prompt_ids, mm_kwargs, False
 
     def _cached_apply_hf_processor(
         self,
         prompt: Union[str, list[int]],
         mm_data_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> tuple[list[int], MultiModalKwargs]:
+    ) -> tuple[list[int], MultiModalKwargs, bool]:
         """
         Apply the HF processor on the full prompt text,
         caching the results and reusing cached results.
@@ -992,8 +1022,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         mm_missing_data_items = self._to_mm_items(mm_missing_data)
 
         # NOTE: `prompt` does not correspond to `mm_missing_data_items`,
-        # so we need to pass `enable_hf_prompt_replacement=False`
-        prompt_ids, mm_missing_kwargs = self._apply_hf_processor_main(
+        # so we can't apply prompt replacements until the new multimodal
+        # items are combined with the cached multimodal items
+        (
+            prompt_ids,
+            mm_missing_kwargs,
+            is_repl_applied,
+        ) = self._apply_hf_processor_main(
             prompt=prompt,
             mm_items=mm_missing_data_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
@@ -1036,7 +1071,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
 
         mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)
 
-        return prompt_ids, mm_kwargs
+        return prompt_ids, mm_kwargs, is_repl_applied
 
     def _bind_and_group_repls(
         self,
@@ -1047,18 +1082,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         it = (prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls)
         return dict(full_groupby_modality(it))
 
-    def _always_apply_prompt_replacements(self) -> bool:
-        """
-        A flag which can be overridden so that
-        :meth:`_apply_prompt_replacements` is always called even if we
-        detect that HF has performed processing via
-        :meth:`_find_placeholders_by_modality`.
-
-        This is useful in cases where :meth:`_find_placeholders_by_modality`
-        cannot be reliably used to detect whether HF has performed processing.
-        """
-        return False
-
     def _apply_prompt_replacements(
         self,
         token_ids: list[int],
@@ -1155,29 +1178,21 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         self,
         mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
         mm_item_counts: Mapping[str, int],
-        *,
-        allow_missing: bool = False,
-    ) -> Mapping[str, int]:
-        missing_repl_counts = dict[str, int]()
-
+    ) -> None:
         for modality, item_count in mm_item_counts.items():
             placeholders = mm_placeholders.get(modality, [])
 
-            if len(placeholders) != item_count and not allow_missing:
+            if len(placeholders) != item_count:
                 raise RuntimeError(
                     f"Expected there to be {item_count} prompt replacements "
-                    f"corresponding to {item_count} {modality} items, but only "
-                    f"found {len(placeholders)} prompt replacements! Either "
-                    "the prompt text has missing/incorrect tokens for "
+                    f"corresponding to {item_count} {modality} items, but "
+                    f"instead found {len(placeholders)} prompt replacements! "
+                    "Either the prompt text has missing/incorrect tokens for "
                     "multi-modal inputs, or there is a problem with your "
                     "implementation of merged multi-modal processor for this "
                     "model (usually arising from an inconsistency between "
                     "`_call_hf_processor` and `_get_prompt_replacements`).")
 
-            missing_repl_counts[modality] = item_count - len(placeholders)
-
-        return missing_repl_counts
-
     def apply(
         self,
         prompt: Union[str, list[int]],
@@ -1217,7 +1232,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         else:
             mm_hashes = None
 
-        prompt_ids, mm_kwargs = self._cached_apply_hf_processor(
+        (
+            prompt_ids,
+            mm_kwargs,
+            is_repl_applied,
+        ) = self._cached_apply_hf_processor(
             prompt,
             mm_items,
             hf_processor_mm_kwargs,
@@ -1233,52 +1252,27 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         mm_item_counts = mm_items.get_all_counts()
         self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
 
-        hf_mm_placeholders = self._find_mm_placeholders(
-            mm_prompt_repls,
-            prompt_ids,
-            mm_item_counts,
-        )
-
-        if self._always_apply_prompt_replacements():
-            mm_missing_repl_counts = mm_item_counts
-            mm_missing_repls = dict(mm_prompt_repls)
-        else:
-            mm_missing_repl_counts = self._validate_mm_placeholders(
-                hf_mm_placeholders,
+        if is_repl_applied:
+            mm_placeholders = self._find_mm_placeholders(
+                mm_prompt_repls,
+                prompt_ids,
                 mm_item_counts,
-                allow_missing=True,
             )
+            self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
 
-            mm_missing_repls = dict[str, list[BoundPromptReplacement]]()
-            for modality, missing_repl_count in mm_missing_repl_counts.items():
-                if missing_repl_count == 0:
-                    mm_missing_repls[modality] = []
-                elif missing_repl_count == mm_item_counts.get(modality, 0):
-                    mm_missing_repls[modality] = mm_prompt_repls[modality]
-                else:
-                    raise ValueError("Partial prompt replacement within "
-                                     f"{modality=} is not supported")
-
-        # If HF processor already inserts placeholder tokens,
-        # there is no need for us to insert them
-        if all(len(repls) == 0 for repls in mm_missing_repls.values()):
             tokenizer = self.info.get_tokenizer()
             prompt = decode_tokens(tokenizer, prompt_ids)
-            mm_placeholders = hf_mm_placeholders
         else:
             (
                 prompt_ids,
                 prompt,
-                missing_mm_placeholders,
+                mm_placeholders,
             ) = self._apply_prompt_replacements(
                 prompt_ids,
-                mm_missing_repls,
-                mm_missing_repl_counts,
+                mm_prompt_repls,
+                mm_item_counts,
             )
-
-            mm_placeholders = {**hf_mm_placeholders, **missing_mm_placeholders}
-
-        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
+            self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
 
         mm_placeholder_ranges = {
             modality: [item.to_range() for item in placeholders]