diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index 3849bd37a8290..1cc2562759d47 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -169,7 +169,6 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
model=model_name,
max_model_len=2048,
max_num_seqs=2,
- # Default is False; setting it to True is not supported in V1 yet
mm_processor_kwargs={"do_pan_and_scan": True},
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
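For reference, a minimal standalone sketch of the engine arguments the updated example now passes unconditionally; the model name is an assumption meant to match the one used elsewhere in this example file:

    # Illustrative sketch only: pan-and-scan preprocessing is now requested
    # unconditionally via mm_processor_kwargs.
    from vllm import EngineArgs

    engine_args = EngineArgs(
        model="google/gemma-3-4b-it",  # assumed model name
        max_model_len=2048,
        max_num_seqs=2,
        mm_processor_kwargs={"do_pan_and_scan": True},
    )
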
diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py
index 3a17e5bab0931..98a739169d702 100644
--- a/examples/offline_inference/vision_language_multi_image.py
+++ b/examples/offline_inference/vision_language_multi_image.py
@@ -91,8 +91,6 @@ def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData:
model=model_name,
max_model_len=8192,
max_num_seqs=2,
- # Default is False; setting it to True is not supported in V1 yet
- mm_processor_kwargs={"do_pan_and_scan": True},
limit_mm_per_prompt={"image": len(image_urls)},
)
diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
index 62e55d64cf2ca..8db2bfb901bf3 100644
--- a/vllm/model_executor/models/gemma3_mm.py
+++ b/vllm/model_executor/models/gemma3_mm.py
@@ -183,7 +183,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
image_width: int,
image_height: int,
processor: Optional[Gemma3Processor],
- ) -> PromptUpdateDetails:
+ ) -> PromptUpdateDetails[str]:
if processor is None:
processor = self.get_hf_processor()
diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py
index e23765cc4fb5e..3b2ad695f83ef 100644
--- a/vllm/model_executor/models/h2ovl.py
+++ b/vllm/model_executor/models/h2ovl.py
@@ -249,20 +249,15 @@ class H2OVLProcessor(BaseInternVLProcessor):
def image_token_id(self) -> int:
return self.tokenizer.get_vocab()[IMG_CONTEXT]
- def get_image_repl_features(
+ def get_image_repl(
self,
feature_size: int,
num_patches: Optional[int],
- ) -> str:
- return IMG_CONTEXT * feature_size
+ ) -> PromptUpdateDetails[str]:
+ repl_features = IMG_CONTEXT * feature_size
+ repl_full = IMG_START + repl_features + IMG_END
- def get_image_repl_full(
- self,
- feature_size: int,
- num_patches: Optional[int],
- ) -> str:
- features = self.get_image_repl_features(feature_size, num_patches)
- return IMG_START + features + IMG_END
+ return PromptUpdateDetails(full=repl_full, features=repl_features)
def resolve_min_max_num(
self,
@@ -501,12 +496,7 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
if num_patches is not None:
assert isinstance(num_patches, int)
- return PromptUpdateDetails(
- full=hf_processor.get_image_repl_full(feature_size,
- num_patches),
- features=hf_processor.get_image_repl_features(
- feature_size, num_patches),
- )
+ return hf_processor.get_image_repl(feature_size, num_patches)
return [
PromptReplacement(
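To make the consolidated helper concrete, here is an illustrative sketch of the full/features split that get_image_repl now produces for a toy feature size; the token strings are assumptions meant to match the IMG_START/IMG_END/IMG_CONTEXT constants in internvl.py:

    # Illustrative only; constant values are assumed.
    IMG_START, IMG_END, IMG_CONTEXT = "<img>", "</img>", "<IMG_CONTEXT>"

    feature_size = 3
    repl_features = IMG_CONTEXT * feature_size
    repl_full = IMG_START + repl_features + IMG_END

    # PromptUpdateDetails(full=repl_full, features=repl_features) marks which
    # slice of the inserted text will be overwritten by vision-encoder outputs.
    print(repl_full)      # <img><IMG_CONTEXT><IMG_CONTEXT><IMG_CONTEXT></img>
    print(repl_features)  # <IMG_CONTEXT><IMG_CONTEXT><IMG_CONTEXT>
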
diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index d31b623b5bc71..e8ec91736d58f 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -9,14 +9,13 @@
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
-from typing import (List, Literal, Optional, Set, Tuple, TypedDict, TypeVar,
- Union)
+from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
-from transformers import BatchFeature, PretrainedConfig, TensorType
+from transformers import BatchEncoding, PretrainedConfig, TensorType
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -36,10 +35,12 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import flatten_2d_lists
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features
IMG_START = '<img>'
IMG_END = '</img>'
@@ -51,16 +52,26 @@ IMAGENET_STD = (0.229, 0.224, 0.225)
class InternVLImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
- data: torch.Tensor
+ pixel_values_flat: torch.Tensor
"""
Shape:
`(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
"""
- patches_per_image: List[int]
+
+ num_patches: torch.Tensor
+ """Shape: `(batch_size * num_images)`"""
+
+ embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
- List of number of total patches for each image in the batch.
+ A boolean mask indicating which image embeddings correspond
+ to patch tokens.
+
+ Shape: `(batch_size, num_images, num_embeds)`
"""
+ num_embeds: Union[torch.Tensor, list[torch.Tensor]]
+ """Shape: `(batch_size, num_images)`"""
+
class InternVLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
@@ -286,19 +297,11 @@ class BaseInternVLProcessor(ABC):
raise NotImplementedError
@abstractmethod
- def get_image_repl_features(
+ def get_image_repl(
self,
feature_size: int,
num_patches: Optional[int],
- ) -> str:
- raise NotImplementedError
-
- @abstractmethod
- def get_image_repl_full(
- self,
- feature_size: int,
- num_patches: Optional[int],
- ) -> str:
+ ) -> PromptUpdateDetails[str]:
raise NotImplementedError
def resolve_min_max_num(
@@ -394,7 +397,7 @@ class BaseInternVLProcessor(ABC):
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
- ) -> BatchFeature:
+ ) -> Mapping[str, NestedTensors]:
if text is None:
text = []
if not isinstance(text, list):
@@ -413,28 +416,41 @@ class BaseInternVLProcessor(ABC):
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
)
- image_inputs = {
- "pixel_values_flat": torch.cat(pixel_values_lst),
- "image_num_patches": list(map(len, pixel_values_lst)),
+ image_inputs: dict[str, NestedTensors] = {
+ "pixel_values_flat":
+ torch.cat(pixel_values_lst),
+ "image_num_patches":
+ torch.tensor([len(item) for item in pixel_values_lst]),
}
+ tokenizer = self.tokenizer
+ image_token_id = self.image_token_id
+
+ num_embeds = list[int]()
+ embed_is_patch = list[torch.Tensor]()
+
for pixel_values in pixel_values_lst:
num_patches = pixel_values.shape[0]
feature_size = num_patches * self.num_image_token
- image_repl = self.get_image_repl_full(feature_size,
- num_patches)
- text = [t.replace('<image>', image_repl, 1) for t in text]
+ image_repl = self.get_image_repl(feature_size, num_patches)
+ feature_tokens = tokenizer.encode(image_repl.features,
+ add_special_tokens=False)
+
+ text = [t.replace('<image>', image_repl.full, 1) for t in text]
+ num_embeds.append(len(feature_tokens))
+ embed_is_patch.append(
+ torch.tensor(feature_tokens) == image_token_id)
+
+ image_inputs["num_embeds"] = torch.tensor(num_embeds)
+ image_inputs["embed_is_patch"] = embed_is_patch
text_inputs = self.tokenizer(text)
- return BatchFeature(
- {
- **text_inputs,
- **image_inputs,
- },
- tensor_type=return_tensors,
- )
+ return {
+ **BatchEncoding(text_inputs, tensor_type=return_tensors),
+ **image_inputs,
+ }
class InternVLProcessor(BaseInternVLProcessor):
@@ -443,20 +459,15 @@ class InternVLProcessor(BaseInternVLProcessor):
def image_token_id(self) -> int:
return self.tokenizer.get_vocab()[IMG_CONTEXT]
- def get_image_repl_features(
+ def get_image_repl(
self,
feature_size: int,
num_patches: Optional[int],
- ) -> str:
- return IMG_CONTEXT * feature_size
+ ) -> PromptUpdateDetails[str]:
+ repl_features = IMG_CONTEXT * feature_size
+ repl_full = IMG_START + repl_features + IMG_END
- def get_image_repl_full(
- self,
- feature_size: int,
- num_patches: Optional[int],
- ) -> str:
- features = self.get_image_repl_features(feature_size, num_patches)
- return IMG_START + features + IMG_END
+ return PromptUpdateDetails(full=repl_full, features=repl_features)
class BaseInternVLProcessingInfo(BaseProcessingInfo):
@@ -566,16 +577,15 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
- ) -> BatchFeature:
+ ) -> Mapping[str, NestedTensors]:
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
- image_token_id = self.info.get_hf_processor(**mm_kwargs).image_token_id
- image_data = mm_data.get("images", [])
- assert isinstance(image_data, list)
+ hf_processor = self.info.get_hf_processor(**mm_kwargs)
+ image_token_id = hf_processor.image_token_id
# Since there may be extra tokens in the feature placeholders,
# we need to pass the image token ID to the model to select the
@@ -586,7 +596,7 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
def _get_mm_fields_config(
self,
- hf_inputs: BatchFeature,
+ hf_inputs: Mapping[str, NestedTensors],
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
@@ -596,6 +606,8 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_patches),
image_num_patches=MultiModalFieldConfig.batched("image"),
+ embed_is_patch=MultiModalFieldConfig.batched("image"),
+ num_embeds=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
image_token_id=MultiModalFieldConfig.shared("image", num_images),
)
@@ -637,12 +649,7 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
if num_patches is not None:
assert isinstance(num_patches, int)
- return PromptUpdateDetails(
- full=hf_processor.get_image_repl_full(feature_size,
- num_patches),
- features=hf_processor.get_image_repl_features(
- feature_size, num_patches),
- )
+ return hf_processor.get_image_repl(feature_size, num_patches)
return [
PromptReplacement(
@@ -832,6 +839,8 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self, **kwargs: object) -> Optional[InternVLImageInputs]:
pixel_values_flat = kwargs.pop("pixel_values_flat", None)
image_num_patches = kwargs.pop("image_num_patches", None)
+ embed_is_patch = kwargs.pop("embed_is_patch", None)
+ num_embeds = kwargs.pop("num_embeds", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values_flat is None and image_embeds is None:
@@ -858,35 +867,47 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
if not isinstance(image_num_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of image_num_patches. "
- f"Got type: {type(pixel_values_flat)}")
+ f"Got type: {type(image_num_patches)}")
+
+ if not isinstance(embed_is_patch, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of embed_is_patch. "
+ f"Got type: {type(embed_is_patch)}")
+
+ if not isinstance(num_embeds, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of num_embeds. "
+ f"Got type: {type(num_embeds)}")
+
+ pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
+ image_num_patches = flatten_bn(image_num_patches, concat=True)
return InternVLImagePixelInputs(
type="pixel_values",
- data=self._validate_pixel_values(
- flatten_bn(pixel_values_flat, concat=True)),
- patches_per_image=flatten_bn(image_num_patches,
- concat=True).tolist())
+ pixel_values_flat=self._validate_pixel_values(
+ pixel_values_flat),
+ num_patches=image_num_patches,
+ embed_is_patch=embed_is_patch,
+ num_embeds=num_embeds,
+ )
raise AssertionError("This line should be unreachable.")
def _process_image_input(
self,
image_input: InternVLImageInputs,
- ) -> tuple[torch.Tensor, ...]:
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_model is not None
- image_embeds = self.extract_feature(image_input["data"])
+ image_embeds = self.extract_feature(image_input["pixel_values_flat"])
- patches_per_image = image_input["patches_per_image"]
+ num_patches = image_input["num_patches"]
# Only one image in the current batch
- if len(patches_per_image) == 1:
- image_embeds = image_embeds.view(
+ if len(num_patches) == 1:
+ return image_embeds.view(
-1, self.config.text_config.hidden_size).unsqueeze(0)
- return image_embeds
# NOTE: Image embeddings are split into separate tensors for each image
# by the size of each embedding.
@@ -894,10 +915,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
image_embeds = image_embeds.view(-1,
self.config.text_config.hidden_size)
image_feature_sizes = [
- num_patches * feature_size for num_patches in patches_per_image
+ num_patches * feature_size for num_patches in num_patches
]
- image_embeds = image_embeds.split(image_feature_sizes)
- return image_embeds
+ return image_embeds.split(image_feature_sizes)
def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
if self.is_mono:
@@ -911,8 +931,19 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
- vision_embeddings = self._process_image_input(image_input)
- return vision_embeddings
+
+ image_features = self._process_image_input(image_input)
+
+ if (kwargs.get("v0_path", False)
+ or image_input["type"] != "pixel_values"):
+ return image_features
+
+ return flatten_2d_lists(
+ scatter_patch_features(*args) for args in zip(
+ image_features,
+ image_input["num_embeds"],
+ image_input["embed_is_patch"],
+ ))
def get_input_embeddings(
self,
@@ -924,8 +955,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
assert self.img_context_token_id is not None
self._set_visual_token_mask(input_ids)
inputs_embeds = merge_multimodal_embeddings(
- input_ids, inputs_embeds, multimodal_embeddings,
- self.img_context_token_id)
+ input_ids,
+ inputs_embeds,
+ select_patch_features(multimodal_embeddings),
+ self.img_context_token_id,
+ )
return inputs_embeds
def forward(
@@ -944,6 +978,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
+ kwargs.update({"v0_path": True})
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
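A hedged sketch of the embed_is_patch construction introduced above: the feature portion of the image replacement is tokenized on its own, and the positions holding the image context token are marked as patch slots. The helper name and its tokenizer argument are hypothetical; the logic mirrors the processor code in this diff:

    import torch

    def build_embed_is_patch(tokenizer, repl_features: str,
                             image_token_id: int) -> torch.Tensor:
        # Tokenize only the feature portion of the replacement text.
        feature_tokens = tokenizer.encode(repl_features,
                                          add_special_tokens=False)
        # True where a vision-encoder patch embedding will be scattered in;
        # any non-patch tokens inside the feature text stay False and are
        # skipped later by select_patch_features.
        return torch.tensor(feature_tokens) == image_token_id
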
diff --git a/vllm/model_executor/models/nvlm_d.py b/vllm/model_executor/models/nvlm_d.py
index 0f5cbf082d9d4..9d04f30c8f3fe 100644
--- a/vllm/model_executor/models/nvlm_d.py
+++ b/vllm/model_executor/models/nvlm_d.py
@@ -36,11 +36,11 @@ class NVLMProcessor(BaseInternVLProcessor):
def image_token_id(self) -> int:
return self.tokenizer.get_vocab()[IMG_PAD]
- def get_image_repl_features(
+ def get_image_repl(
self,
feature_size: int,
num_patches: Optional[int],
- ) -> str:
+ ) -> PromptUpdateDetails[str]:
if num_patches is None:
raise NotImplementedError("Embedding inputs are not supported")
@@ -55,14 +55,9 @@ class NVLMProcessor(BaseInternVLProcessor):
# We include the start and end as well because "<tile_x>" is
# tokenized as ["<", "tile", ...], resulting in assertion error
# when trying to find the features in "<Image>" + features + "</Image>"
+ repl = "<Image>" + features + "</Image>"
- def get_image_repl_full(
- self,
- feature_size: int,
- num_patches: Optional[int],
- ) -> str:
- return self.get_image_repl_features(feature_size, num_patches)
+ return PromptUpdateDetails(full=repl, features=repl)
class NVLMProcessingInfo(BaseInternVLProcessingInfo):
@@ -180,11 +175,11 @@ class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
if num_patches is not None:
assert isinstance(num_patches, int)
+ repl = hf_processor.get_image_repl(feature_size, num_patches)
+
return PromptUpdateDetails(
- full=hf_processor.get_image_repl_full(feature_size,
- num_patches) + "\n",
- features=hf_processor.get_image_repl_features(
- feature_size, num_patches) + "\n",
+ full=repl.full + "\n",
+ features=repl.features + "\n",
)
# See note in dummy data regarding why we have the extra newline
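For NVLM-D the feature text itself interleaves tile identifiers with the image pad token, which is why full and features are the same string and the embed_is_patch mask performs the filtering. A toy illustration; the identifier and pad-token strings below are assumptions, not taken from nvlm_d.py:

    # Toy values only; the real identifiers and IMG_PAD string come from the
    # NVLM-D processor.
    IMG_PAD = "<IMG_PAD>"
    tile_pos_identifiers = ["<tile_1>", "<tile_2>"]
    context_size = 2

    features = "".join(t + IMG_PAD * context_size
                       for t in tile_pos_identifiers)
    repl = "<Image>" + features + "</Image>"
    # PromptUpdateDetails(full=repl, features=repl): the tile identifiers stay
    # in the prompt, and only the IMG_PAD positions are treated as patch
    # embeddings via embed_is_patch.
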
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index db995957a7f80..fec77acc1d197 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -103,13 +103,13 @@ The token sequence or text to update.
@dataclass
-class PromptUpdateDetails:
+class PromptUpdateDetails(Generic[_S]):
"""Details about the token sequence or text that are part of the update."""
- full: PromptSeq
+ full: _S
"""The full content."""
- features: PromptSeq
+ features: _S
"""
The part of the content that corresponds to feature placeholders;
this will be replaced by the output of the vision encoder during model
@@ -117,7 +117,7 @@ class PromptUpdateDetails:
"""
@staticmethod
- def from_seq(seq: PromptSeq) -> "PromptUpdateDetails":
+ def from_seq(seq: _S) -> "PromptUpdateDetails[_S]":
return PromptUpdateDetails(full=seq, features=seq)
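With the dataclass now generic over the sequence type, both text- and token-based updates can be expressed explicitly. A brief sketch with illustrative values (the placeholder strings and token IDs are assumptions):

    from vllm.multimodal.processing import PromptUpdateDetails

    # Text-based update, as returned by the InternVL-family processors.
    text_repl: PromptUpdateDetails[str] = PromptUpdateDetails(
        full="<img>" + "<IMG_CONTEXT>" * 4 + "</img>",
        features="<IMG_CONTEXT>" * 4,
    )

    # Token-based update: the same container parametrized over token IDs.
    token_repl: PromptUpdateDetails[list[int]] = PromptUpdateDetails.from_seq(
        [151667] * 4,  # illustrative token IDs
    )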