diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index 3849bd37a8290..1cc2562759d47 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -169,7 +169,6 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
         model=model_name,
         max_model_len=2048,
         max_num_seqs=2,
-        # Default is False; setting it to True is not supported in V1 yet
         mm_processor_kwargs={"do_pan_and_scan": True},
         disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )
diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py
index 3a17e5bab0931..98a739169d702 100644
--- a/examples/offline_inference/vision_language_multi_image.py
+++ b/examples/offline_inference/vision_language_multi_image.py
@@ -91,8 +91,6 @@ def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData:
         model=model_name,
         max_model_len=8192,
         max_num_seqs=2,
-        # Default is False; setting it to True is not supported in V1 yet
-        mm_processor_kwargs={"do_pan_and_scan": True},
         limit_mm_per_prompt={"image": len(image_urls)},
     )
 
diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
index 62e55d64cf2ca..8db2bfb901bf3 100644
--- a/vllm/model_executor/models/gemma3_mm.py
+++ b/vllm/model_executor/models/gemma3_mm.py
@@ -183,7 +183,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
         image_width: int,
         image_height: int,
         processor: Optional[Gemma3Processor],
-    ) -> PromptUpdateDetails:
+    ) -> PromptUpdateDetails[str]:
         if processor is None:
             processor = self.get_hf_processor()
 
diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py
index e23765cc4fb5e..3b2ad695f83ef 100644
--- a/vllm/model_executor/models/h2ovl.py
+++ b/vllm/model_executor/models/h2ovl.py
@@ -249,20 +249,15 @@ class H2OVLProcessor(BaseInternVLProcessor):
     def image_token_id(self) -> int:
         return self.tokenizer.get_vocab()[IMG_CONTEXT]
 
-    def get_image_repl_features(
+    def get_image_repl(
         self,
         feature_size: int,
         num_patches: Optional[int],
-    ) -> str:
-        return IMG_CONTEXT * feature_size
+    ) -> PromptUpdateDetails[str]:
+        repl_features = IMG_CONTEXT * feature_size
+        repl_full = IMG_START + repl_features + IMG_END
 
-    def get_image_repl_full(
-        self,
-        feature_size: int,
-        num_patches: Optional[int],
-    ) -> str:
-        features = self.get_image_repl_features(feature_size, num_patches)
-        return IMG_START + features + IMG_END
+        return PromptUpdateDetails(full=repl_full, features=repl_features)
 
     def resolve_min_max_num(
         self,
@@ -501,12 +496,7 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
             if num_patches is not None:
                 assert isinstance(num_patches, int)
 
-            return PromptUpdateDetails(
-                full=hf_processor.get_image_repl_full(feature_size,
-                                                      num_patches),
-                features=hf_processor.get_image_repl_features(
-                    feature_size, num_patches),
-            )
+            return hf_processor.get_image_repl(feature_size, num_patches)
 
         return [
             PromptReplacement(
diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index d31b623b5bc71..e8ec91736d58f 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -9,14 +9,13 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property
-from typing import (List, Literal, Optional, Set, Tuple, TypedDict, TypeVar,
-                    Union)
+from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union
 
 import torch
 import torch.nn as nn
 import torchvision.transforms as T
 from PIL import Image
-from transformers import BatchFeature, PretrainedConfig, TensorType
+from transformers import BatchEncoding, PretrainedConfig, TensorType
 
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -36,10 +35,12 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import flatten_2d_lists
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features
 
 IMG_START = '<img>'
 IMG_END = '</img>'
@@ -51,16 +52,26 @@ IMAGENET_STD = (0.229, 0.224, 0.225)
 
 class InternVLImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: torch.Tensor
+    pixel_values_flat: torch.Tensor
     """
    Shape:
     `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
     """
-    patches_per_image: List[int]
+
+    num_patches: torch.Tensor
+    """Shape: `(batch_size * num_images)`"""
+
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
-    List of number of total patches for each image in the batch.
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size, num_images, num_embeds)`
     """
 
+    num_embeds: Union[torch.Tensor, list[torch.Tensor]]
+    """Shape: `(batch_size, num_images)`"""
+
 
 class InternVLImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
@@ -286,19 +297,11 @@ class BaseInternVLProcessor(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def get_image_repl_features(
+    def get_image_repl(
         self,
         feature_size: int,
         num_patches: Optional[int],
-    ) -> str:
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_image_repl_full(
-        self,
-        feature_size: int,
-        num_patches: Optional[int],
-    ) -> str:
+    ) -> PromptUpdateDetails[str]:
         raise NotImplementedError
 
     def resolve_min_max_num(
@@ -394,7 +397,7 @@
         max_dynamic_patch: Optional[int] = None,
         dynamic_image_size: Optional[bool] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
-    ) -> BatchFeature:
+    ) -> Mapping[str, NestedTensors]:
         if text is None:
             text = []
         if not isinstance(text, list):
@@ -413,28 +416,41 @@
                 max_dynamic_patch=max_dynamic_patch,
                 dynamic_image_size=dynamic_image_size,
             )
-            image_inputs = {
-                "pixel_values_flat": torch.cat(pixel_values_lst),
-                "image_num_patches": list(map(len, pixel_values_lst)),
+            image_inputs: dict[str, NestedTensors] = {
+                "pixel_values_flat":
+                torch.cat(pixel_values_lst),
+                "image_num_patches":
+                torch.tensor([len(item) for item in pixel_values_lst]),
             }
 
+            tokenizer = self.tokenizer
+            image_token_id = self.image_token_id
+
+            num_embeds = list[int]()
+            embed_is_patch = list[torch.Tensor]()
+
             for pixel_values in pixel_values_lst:
                 num_patches = pixel_values.shape[0]
                 feature_size = num_patches * self.num_image_token
 
-                image_repl = self.get_image_repl_full(feature_size,
-                                                      num_patches)
-                text = [t.replace('<image>', image_repl, 1) for t in text]
+                image_repl = self.get_image_repl(feature_size, num_patches)
+                feature_tokens = tokenizer.encode(image_repl.features,
+                                                  add_special_tokens=False)
+
+                text = [t.replace('<image>', image_repl.full, 1) for t in text]
+                num_embeds.append(len(feature_tokens))
+                embed_is_patch.append(
+                    torch.tensor(feature_tokens) == image_token_id)
+
+            image_inputs["num_embeds"] = torch.tensor(num_embeds)
+            image_inputs["embed_is_patch"] = embed_is_patch
 
         text_inputs = self.tokenizer(text)
 
-        return BatchFeature(
-            {
-                **text_inputs,
-                **image_inputs,
-            },
-            tensor_type=return_tensors,
-        )
+        return {
+            **BatchEncoding(text_inputs, tensor_type=return_tensors),
+            **image_inputs,
+        }
 
 
 class InternVLProcessor(BaseInternVLProcessor):
@@ -443,20 +459,15 @@ class InternVLProcessor(BaseInternVLProcessor):
     def image_token_id(self) -> int:
         return self.tokenizer.get_vocab()[IMG_CONTEXT]
 
-    def get_image_repl_features(
+    def get_image_repl(
         self,
         feature_size: int,
         num_patches: Optional[int],
-    ) -> str:
-        return IMG_CONTEXT * feature_size
+    ) -> PromptUpdateDetails[str]:
+        repl_features = IMG_CONTEXT * feature_size
+        repl_full = IMG_START + repl_features + IMG_END
 
-    def get_image_repl_full(
-        self,
-        feature_size: int,
-        num_patches: Optional[int],
-    ) -> str:
-        features = self.get_image_repl_features(feature_size, num_patches)
-        return IMG_START + features + IMG_END
+        return PromptUpdateDetails(full=repl_full, features=repl_features)
 
 
 class BaseInternVLProcessingInfo(BaseProcessingInfo):
@@ -566,16 +577,15 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
         prompt: str,
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
-    ) -> BatchFeature:
+    ) -> Mapping[str, NestedTensors]:
         processed_outputs = super()._call_hf_processor(
             prompt=prompt,
             mm_data=mm_data,
             mm_kwargs=mm_kwargs,
         )
 
-        image_token_id = self.info.get_hf_processor(**mm_kwargs).image_token_id
-        image_data = mm_data.get("images", [])
-        assert isinstance(image_data, list)
+        hf_processor = self.info.get_hf_processor(**mm_kwargs)
+        image_token_id = hf_processor.image_token_id
 
         # Since there may be extra tokens in the feature placeholders,
         # we need to pass the image token ID to the model to select the
@@ -586,7 +596,7 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
 
     def _get_mm_fields_config(
         self,
-        hf_inputs: BatchFeature,
+        hf_inputs: Mapping[str, NestedTensors],
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
@@ -596,6 +606,8 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
             pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
                 "image", image_num_patches),
             image_num_patches=MultiModalFieldConfig.batched("image"),
+            embed_is_patch=MultiModalFieldConfig.batched("image"),
+            num_embeds=MultiModalFieldConfig.batched("image"),
             image_embeds=MultiModalFieldConfig.batched("image"),
             image_token_id=MultiModalFieldConfig.shared("image", num_images),
         )
@@ -637,12 +649,7 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
             if num_patches is not None:
                 assert isinstance(num_patches, int)
 
-            return PromptUpdateDetails(
-                full=hf_processor.get_image_repl_full(feature_size,
-                                                      num_patches),
-                features=hf_processor.get_image_repl_features(
-                    feature_size, num_patches),
-            )
+            return hf_processor.get_image_repl(feature_size, num_patches)
 
         return [
             PromptReplacement(
@@ -832,6 +839,8 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
             self, **kwargs: object) -> Optional[InternVLImageInputs]:
         pixel_values_flat = kwargs.pop("pixel_values_flat", None)
         image_num_patches = kwargs.pop("image_num_patches", None)
+        embed_is_patch = kwargs.pop("embed_is_patch", None)
kwargs.pop("embed_is_patch", None) + num_embeds = kwargs.pop("num_embeds", None) image_embeds = kwargs.pop("image_embeds", None) if pixel_values_flat is None and image_embeds is None: @@ -858,35 +867,47 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): if not isinstance(image_num_patches, (torch.Tensor, list)): raise ValueError("Incorrect type of image_num_patches. " - f"Got type: {type(pixel_values_flat)}") + f"Got type: {type(image_num_patches)}") + + if not isinstance(embed_is_patch, (torch.Tensor, list)): + raise ValueError("Incorrect type of embed_is_patch. " + f"Got type: {type(embed_is_patch)}") + + if not isinstance(num_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of num_embeds. " + f"Got type: {type(num_embeds)}") + + pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) + image_num_patches = flatten_bn(image_num_patches, concat=True) return InternVLImagePixelInputs( type="pixel_values", - data=self._validate_pixel_values( - flatten_bn(pixel_values_flat, concat=True)), - patches_per_image=flatten_bn(image_num_patches, - concat=True).tolist()) + pixel_values_flat=self._validate_pixel_values( + pixel_values_flat), + num_patches=image_num_patches, + embed_is_patch=embed_is_patch, + num_embeds=num_embeds, + ) raise AssertionError("This line should be unreachable.") def _process_image_input( self, image_input: InternVLImageInputs, - ) -> tuple[torch.Tensor, ...]: + ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: if image_input["type"] == "image_embeds": return image_input["data"] assert self.vision_model is not None - image_embeds = self.extract_feature(image_input["data"]) + image_embeds = self.extract_feature(image_input["pixel_values_flat"]) - patches_per_image = image_input["patches_per_image"] + num_patches = image_input["num_patches"] # Only one image in the current batch - if len(patches_per_image) == 1: - image_embeds = image_embeds.view( + if len(num_patches) == 1: + return image_embeds.view( -1, self.config.text_config.hidden_size).unsqueeze(0) - return image_embeds # NOTE: Image embeddings are split into separate tensors for each image # by the size of each embedding. 
@@ -894,10 +915,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         image_embeds = image_embeds.view(-1,
                                          self.config.text_config.hidden_size)
         image_feature_sizes = [
-            num_patches * feature_size for num_patches in patches_per_image
+            num_patches * feature_size for num_patches in num_patches
         ]
-        image_embeds = image_embeds.split(image_feature_sizes)
-        return image_embeds
+        return image_embeds.split(image_feature_sizes)
 
     def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
         if self.is_mono:
@@ -911,8 +931,19 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
-        vision_embeddings = self._process_image_input(image_input)
-        return vision_embeddings
+
+        image_features = self._process_image_input(image_input)
+
+        if (kwargs.get("v0_path", False)
+                or image_input["type"] != "pixel_values"):
+            return image_features
+
+        return flatten_2d_lists(
+            scatter_patch_features(*args) for args in zip(
+                image_features,
+                image_input["num_embeds"],
+                image_input["embed_is_patch"],
+            ))
 
     def get_input_embeddings(
         self,
@@ -924,8 +955,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
             assert self.img_context_token_id is not None
             self._set_visual_token_mask(input_ids)
             inputs_embeds = merge_multimodal_embeddings(
-                input_ids, inputs_embeds, multimodal_embeddings,
-                self.img_context_token_id)
+                input_ids,
+                inputs_embeds,
+                select_patch_features(multimodal_embeddings),
+                self.img_context_token_id,
+            )
         return inputs_embeds
 
     def forward(
@@ -944,6 +978,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         # NOTE: In v1, inputs_embeds is always generated at model runner, this
         # condition is for v0 compatibility.
         elif inputs_embeds is None:
+            kwargs.update({"v0_path": True})
             vision_embeddings = self.get_multimodal_embeddings(**kwargs)
             inputs_embeds = self.get_input_embeddings(input_ids,
                                                       vision_embeddings)
diff --git a/vllm/model_executor/models/nvlm_d.py b/vllm/model_executor/models/nvlm_d.py
index 0f5cbf082d9d4..9d04f30c8f3fe 100644
--- a/vllm/model_executor/models/nvlm_d.py
+++ b/vllm/model_executor/models/nvlm_d.py
@@ -36,11 +36,11 @@ class NVLMProcessor(BaseInternVLProcessor):
     def image_token_id(self) -> int:
         return self.tokenizer.get_vocab()[IMG_PAD]
 
-    def get_image_repl_features(
+    def get_image_repl(
         self,
         feature_size: int,
         num_patches: Optional[int],
-    ) -> str:
+    ) -> PromptUpdateDetails[str]:
         if num_patches is None:
             raise NotImplementedError("Embedding inputs are not supported")
 
@@ -55,14 +55,9 @@ class NVLMProcessor(BaseInternVLProcessor):
         # We include the start and end as well because "<Image><tile" is
         # tokenized as ["<", "Image", "><", "tile"], resulting in assertion error
        # when trying to find "<Image><tile_1>" in the prompt
-        return "<Image>" + features + "</Image>"
+        repl = "<Image>" + features + "</Image>"
 
-    def get_image_repl_full(
-        self,
-        feature_size: int,
-        num_patches: Optional[int],
-    ) -> str:
-        return self.get_image_repl_features(feature_size, num_patches)
+        return PromptUpdateDetails(full=repl, features=repl)
 
 
 class NVLMProcessingInfo(BaseInternVLProcessingInfo):
@@ -180,11 +175,11 @@ class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
             if num_patches is not None:
                 assert isinstance(num_patches, int)
 
+            repl = hf_processor.get_image_repl(feature_size, num_patches)
+
             return PromptUpdateDetails(
-                full=hf_processor.get_image_repl_full(feature_size,
-                                                      num_patches) + "\n",
-                features=hf_processor.get_image_repl_features(
-                    feature_size, num_patches) + "\n",
+                full=repl.full + "\n",
+                features=repl.features + "\n",
             )
 
         # See note in dummy data regarding why we have the extra newline
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index db995957a7f80..fec77acc1d197 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -103,13 +103,13 @@
 The token sequence or text to update.
 """
 
 
 @dataclass
-class PromptUpdateDetails:
+class PromptUpdateDetails(Generic[_S]):
     """Details about the token sequence or text that are part of the update."""
 
-    full: PromptSeq
+    full: _S
     """The full content."""
 
-    features: PromptSeq
+    features: _S
     """
     The part of the content that corresponds to feature placeholders;
     this will be replaced by the output of the vision encoder during model
@@ -117,7 +117,7 @@ class PromptUpdateDetails:
     """
 
     @staticmethod
-    def from_seq(seq: PromptSeq) -> "PromptUpdateDetails":
+    def from_seq(seq: _S) -> "PromptUpdateDetails[_S]":
         return PromptUpdateDetails(full=seq, features=seq)
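
Two illustrative sketches follow; they are not part of the patch itself.

First, the processor-side refactor collapses the old get_image_repl_features/get_image_repl_full pair into a single get_image_repl returning the now-generic PromptUpdateDetails[str]. A minimal, self-contained sketch of that shape (the dataclass is a simplified stand-in for the one in vllm/multimodal/processing.py, and the marker strings follow the <img>/<IMG_CONTEXT>/</img> convention InternVL uses in this diff):

from dataclasses import dataclass
from typing import Generic, TypeVar

_S = TypeVar("_S")


@dataclass
class PromptUpdateDetails(Generic[_S]):
    # Simplified stand-in for vllm.multimodal.processing.PromptUpdateDetails
    full: _S       # the complete replacement inserted into the prompt
    features: _S   # the sub-span the vision encoder's output will replace


def get_image_repl(feature_size: int) -> PromptUpdateDetails[str]:
    # Mirrors the shape of InternVLProcessor.get_image_repl: the feature run
    # is wrapped by start/end markers in `full` but excluded from `features`.
    features = "<IMG_CONTEXT>" * feature_size
    return PromptUpdateDetails(full="<img>" + features + "</img>",
                               features=features)

Keeping full and features on one object is what lets NVLMMultiModalProcessor append a trailing "\n" to both fields without re-deriving the feature span.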
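
Second, why embed_is_patch and num_embeds are threaded from the processor to the model: as the comment in _call_hf_processor notes, there may be extra non-patch tokens inside the feature placeholders (NVLM's tile markers, for example), so the model scatters the encoder's patch embeddings over the full placeholder span and keeps only the patch slots when merging. A toy sketch with hypothetical token IDs and a hand-rolled scatter; the real helpers are scatter_patch_features/select_patch_features imported from models/vision.py:

import torch

IMG_CONTEXT_ID, TILE_MARKER_ID = 101, 102  # hypothetical token IDs

# Feature-placeholder tokens for one image, as produced by
# tokenizer.encode(image_repl.features) in BaseInternVLProcessor.__call__;
# the tile marker is a non-patch token mixed into the span.
feature_tokens = [TILE_MARKER_ID, IMG_CONTEXT_ID, IMG_CONTEXT_ID,
                  IMG_CONTEXT_ID]

num_embeds = len(feature_tokens)  # becomes image_inputs["num_embeds"]
embed_is_patch = torch.tensor(feature_tokens) == IMG_CONTEXT_ID
# tensor([False,  True,  True,  True])

# The vision encoder yields one embedding per *patch* slot (3 here, dim 8).
patch_embeds = torch.randn(3, 8)

# Scatter patch embeddings over the full placeholder span; non-patch slots
# keep a NaN filler and are dropped again (cf. select_patch_features)
# before merge_multimodal_embeddings writes them into the sequence.
full_embeds = torch.full((num_embeds, 8), float("nan"))
full_embeds[embed_is_patch] = patch_embeds

assert torch.equal(full_embeds[embed_is_patch], patch_embeds)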