[Bugfix] Fix embedding assignment for InternVL-based models (#15086)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-03-20 11:40:13 +08:00 committed by GitHub
parent 70e500cad9
commit ffa443afed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 123 additions and 106 deletions

View File

@ -169,7 +169,6 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
model=model_name, model=model_name,
max_model_len=2048, max_model_len=2048,
max_num_seqs=2, max_num_seqs=2,
# Default is False; setting it to True is not supported in V1 yet
mm_processor_kwargs={"do_pan_and_scan": True}, mm_processor_kwargs={"do_pan_and_scan": True},
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
) )

View File

@ -91,8 +91,6 @@ def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData:
model=model_name, model=model_name,
max_model_len=8192, max_model_len=8192,
max_num_seqs=2, max_num_seqs=2,
# Default is False; setting it to True is not supported in V1 yet
mm_processor_kwargs={"do_pan_and_scan": True},
limit_mm_per_prompt={"image": len(image_urls)}, limit_mm_per_prompt={"image": len(image_urls)},
) )

View File

@ -183,7 +183,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
image_width: int, image_width: int,
image_height: int, image_height: int,
processor: Optional[Gemma3Processor], processor: Optional[Gemma3Processor],
) -> PromptUpdateDetails: ) -> PromptUpdateDetails[str]:
if processor is None: if processor is None:
processor = self.get_hf_processor() processor = self.get_hf_processor()

View File

@ -249,20 +249,15 @@ class H2OVLProcessor(BaseInternVLProcessor):
def image_token_id(self) -> int: def image_token_id(self) -> int:
return self.tokenizer.get_vocab()[IMG_CONTEXT] return self.tokenizer.get_vocab()[IMG_CONTEXT]
def get_image_repl_features( def get_image_repl(
self, self,
feature_size: int, feature_size: int,
num_patches: Optional[int], num_patches: Optional[int],
) -> str: ) -> PromptUpdateDetails[str]:
return IMG_CONTEXT * feature_size repl_features = IMG_CONTEXT * feature_size
repl_full = IMG_START + repl_features + IMG_END
def get_image_repl_full( return PromptUpdateDetails(full=repl_full, features=repl_features)
self,
feature_size: int,
num_patches: Optional[int],
) -> str:
features = self.get_image_repl_features(feature_size, num_patches)
return IMG_START + features + IMG_END
def resolve_min_max_num( def resolve_min_max_num(
self, self,
@ -501,12 +496,7 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
if num_patches is not None: if num_patches is not None:
assert isinstance(num_patches, int) assert isinstance(num_patches, int)
return PromptUpdateDetails( return hf_processor.get_image_repl(feature_size, num_patches)
full=hf_processor.get_image_repl_full(feature_size,
num_patches),
features=hf_processor.get_image_repl_features(
feature_size, num_patches),
)
return [ return [
PromptReplacement( PromptReplacement(

View File

@ -9,14 +9,13 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property from functools import cached_property
from typing import (List, Literal, Optional, Set, Tuple, TypedDict, TypeVar, from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union
Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
import torchvision.transforms as T import torchvision.transforms as T
from PIL import Image from PIL import Image
from transformers import BatchFeature, PretrainedConfig, TensorType from transformers import BatchEncoding, PretrainedConfig, TensorType
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
@ -36,10 +35,12 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import flatten_2d_lists
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
IMG_START = '<img>' IMG_START = '<img>'
IMG_END = '</img>' IMG_END = '</img>'
@ -51,16 +52,26 @@ IMAGENET_STD = (0.229, 0.224, 0.225)
class InternVLImagePixelInputs(TypedDict): class InternVLImagePixelInputs(TypedDict):
type: Literal["pixel_values"] type: Literal["pixel_values"]
data: torch.Tensor pixel_values_flat: torch.Tensor
""" """
Shape: Shape:
`(batch_size * num_images * (1 + num_patches), num_channels, height, width)` `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
""" """
patches_per_image: List[int]
num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
""" """
List of number of total patches for each image in the batch. A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size, num_images, num_embeds)`
""" """
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`"""
class InternVLImageEmbeddingInputs(TypedDict): class InternVLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"] type: Literal["image_embeds"]
@ -286,19 +297,11 @@ class BaseInternVLProcessor(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def get_image_repl_features( def get_image_repl(
self, self,
feature_size: int, feature_size: int,
num_patches: Optional[int], num_patches: Optional[int],
) -> str: ) -> PromptUpdateDetails[str]:
raise NotImplementedError
@abstractmethod
def get_image_repl_full(
self,
feature_size: int,
num_patches: Optional[int],
) -> str:
raise NotImplementedError raise NotImplementedError
def resolve_min_max_num( def resolve_min_max_num(
@ -394,7 +397,7 @@ class BaseInternVLProcessor(ABC):
max_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None, dynamic_image_size: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None, return_tensors: Optional[Union[str, TensorType]] = None,
) -> BatchFeature: ) -> Mapping[str, NestedTensors]:
if text is None: if text is None:
text = [] text = []
if not isinstance(text, list): if not isinstance(text, list):
@ -413,28 +416,41 @@ class BaseInternVLProcessor(ABC):
max_dynamic_patch=max_dynamic_patch, max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size, dynamic_image_size=dynamic_image_size,
) )
image_inputs = { image_inputs: dict[str, NestedTensors] = {
"pixel_values_flat": torch.cat(pixel_values_lst), "pixel_values_flat":
"image_num_patches": list(map(len, pixel_values_lst)), torch.cat(pixel_values_lst),
"image_num_patches":
torch.tensor([len(item) for item in pixel_values_lst]),
} }
tokenizer = self.tokenizer
image_token_id = self.image_token_id
num_embeds = list[int]()
embed_is_patch = list[torch.Tensor]()
for pixel_values in pixel_values_lst: for pixel_values in pixel_values_lst:
num_patches = pixel_values.shape[0] num_patches = pixel_values.shape[0]
feature_size = num_patches * self.num_image_token feature_size = num_patches * self.num_image_token
image_repl = self.get_image_repl_full(feature_size, image_repl = self.get_image_repl(feature_size, num_patches)
num_patches) feature_tokens = tokenizer.encode(image_repl.features,
text = [t.replace('<image>', image_repl, 1) for t in text] add_special_tokens=False)
text = [t.replace('<image>', image_repl.full, 1) for t in text]
num_embeds.append(len(feature_tokens))
embed_is_patch.append(
torch.tensor(feature_tokens) == image_token_id)
image_inputs["num_embeds"] = torch.tensor(num_embeds)
image_inputs["embed_is_patch"] = embed_is_patch
text_inputs = self.tokenizer(text) text_inputs = self.tokenizer(text)
return BatchFeature( return {
{ **BatchEncoding(text_inputs, tensor_type=return_tensors),
**text_inputs, **image_inputs,
**image_inputs, }
},
tensor_type=return_tensors,
)
class InternVLProcessor(BaseInternVLProcessor): class InternVLProcessor(BaseInternVLProcessor):
@ -443,20 +459,15 @@ class InternVLProcessor(BaseInternVLProcessor):
def image_token_id(self) -> int: def image_token_id(self) -> int:
return self.tokenizer.get_vocab()[IMG_CONTEXT] return self.tokenizer.get_vocab()[IMG_CONTEXT]
def get_image_repl_features( def get_image_repl(
self, self,
feature_size: int, feature_size: int,
num_patches: Optional[int], num_patches: Optional[int],
) -> str: ) -> PromptUpdateDetails[str]:
return IMG_CONTEXT * feature_size repl_features = IMG_CONTEXT * feature_size
repl_full = IMG_START + repl_features + IMG_END
def get_image_repl_full( return PromptUpdateDetails(full=repl_full, features=repl_features)
self,
feature_size: int,
num_patches: Optional[int],
) -> str:
features = self.get_image_repl_features(feature_size, num_patches)
return IMG_START + features + IMG_END
class BaseInternVLProcessingInfo(BaseProcessingInfo): class BaseInternVLProcessingInfo(BaseProcessingInfo):
@ -566,16 +577,15 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
prompt: str, prompt: str,
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
) -> BatchFeature: ) -> Mapping[str, NestedTensors]:
processed_outputs = super()._call_hf_processor( processed_outputs = super()._call_hf_processor(
prompt=prompt, prompt=prompt,
mm_data=mm_data, mm_data=mm_data,
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
) )
image_token_id = self.info.get_hf_processor(**mm_kwargs).image_token_id hf_processor = self.info.get_hf_processor(**mm_kwargs)
image_data = mm_data.get("images", []) image_token_id = hf_processor.image_token_id
assert isinstance(image_data, list)
# Since there may be extra tokens in the feature placeholders, # Since there may be extra tokens in the feature placeholders,
# we need to pass the image token ID to the model to select the # we need to pass the image token ID to the model to select the
@ -586,7 +596,7 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
def _get_mm_fields_config( def _get_mm_fields_config(
self, self,
hf_inputs: BatchFeature, hf_inputs: Mapping[str, NestedTensors],
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0)) image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
@ -596,6 +606,8 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes( pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_patches), "image", image_num_patches),
image_num_patches=MultiModalFieldConfig.batched("image"), image_num_patches=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
num_embeds=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
image_token_id=MultiModalFieldConfig.shared("image", num_images), image_token_id=MultiModalFieldConfig.shared("image", num_images),
) )
@ -637,12 +649,7 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
if num_patches is not None: if num_patches is not None:
assert isinstance(num_patches, int) assert isinstance(num_patches, int)
return PromptUpdateDetails( return hf_processor.get_image_repl(feature_size, num_patches)
full=hf_processor.get_image_repl_full(feature_size,
num_patches),
features=hf_processor.get_image_repl_features(
feature_size, num_patches),
)
return [ return [
PromptReplacement( PromptReplacement(
@ -832,6 +839,8 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self, **kwargs: object) -> Optional[InternVLImageInputs]: self, **kwargs: object) -> Optional[InternVLImageInputs]:
pixel_values_flat = kwargs.pop("pixel_values_flat", None) pixel_values_flat = kwargs.pop("pixel_values_flat", None)
image_num_patches = kwargs.pop("image_num_patches", None) image_num_patches = kwargs.pop("image_num_patches", None)
embed_is_patch = kwargs.pop("embed_is_patch", None)
num_embeds = kwargs.pop("num_embeds", None)
image_embeds = kwargs.pop("image_embeds", None) image_embeds = kwargs.pop("image_embeds", None)
if pixel_values_flat is None and image_embeds is None: if pixel_values_flat is None and image_embeds is None:
@ -858,35 +867,47 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
if not isinstance(image_num_patches, (torch.Tensor, list)): if not isinstance(image_num_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of image_num_patches. " raise ValueError("Incorrect type of image_num_patches. "
f"Got type: {type(pixel_values_flat)}") f"Got type: {type(image_num_patches)}")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
if not isinstance(num_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_embeds. "
f"Got type: {type(num_embeds)}")
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
image_num_patches = flatten_bn(image_num_patches, concat=True)
return InternVLImagePixelInputs( return InternVLImagePixelInputs(
type="pixel_values", type="pixel_values",
data=self._validate_pixel_values( pixel_values_flat=self._validate_pixel_values(
flatten_bn(pixel_values_flat, concat=True)), pixel_values_flat),
patches_per_image=flatten_bn(image_num_patches, num_patches=image_num_patches,
concat=True).tolist()) embed_is_patch=embed_is_patch,
num_embeds=num_embeds,
)
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
def _process_image_input( def _process_image_input(
self, self,
image_input: InternVLImageInputs, image_input: InternVLImageInputs,
) -> tuple[torch.Tensor, ...]: ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
if image_input["type"] == "image_embeds": if image_input["type"] == "image_embeds":
return image_input["data"] return image_input["data"]
assert self.vision_model is not None assert self.vision_model is not None
image_embeds = self.extract_feature(image_input["data"]) image_embeds = self.extract_feature(image_input["pixel_values_flat"])
patches_per_image = image_input["patches_per_image"] num_patches = image_input["num_patches"]
# Only one image in the current batch # Only one image in the current batch
if len(patches_per_image) == 1: if len(num_patches) == 1:
image_embeds = image_embeds.view( return image_embeds.view(
-1, self.config.text_config.hidden_size).unsqueeze(0) -1, self.config.text_config.hidden_size).unsqueeze(0)
return image_embeds
# NOTE: Image embeddings are split into separate tensors for each image # NOTE: Image embeddings are split into separate tensors for each image
# by the size of each embedding. # by the size of each embedding.
@ -894,10 +915,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
image_embeds = image_embeds.view(-1, image_embeds = image_embeds.view(-1,
self.config.text_config.hidden_size) self.config.text_config.hidden_size)
image_feature_sizes = [ image_feature_sizes = [
num_patches * feature_size for num_patches in patches_per_image num_patches * feature_size for num_patches in num_patches
] ]
image_embeds = image_embeds.split(image_feature_sizes) return image_embeds.split(image_feature_sizes)
return image_embeds
def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None: def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
if self.is_mono: if self.is_mono:
@ -911,8 +931,19 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings image_features = self._process_image_input(image_input)
if (kwargs.get("v0_path", False)
or image_input["type"] != "pixel_values"):
return image_features
return flatten_2d_lists(
scatter_patch_features(*args) for args in zip(
image_features,
image_input["num_embeds"],
image_input["embed_is_patch"],
))
def get_input_embeddings( def get_input_embeddings(
self, self,
@ -924,8 +955,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
assert self.img_context_token_id is not None assert self.img_context_token_id is not None
self._set_visual_token_mask(input_ids) self._set_visual_token_mask(input_ids)
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, input_ids,
self.img_context_token_id) inputs_embeds,
select_patch_features(multimodal_embeddings),
self.img_context_token_id,
)
return inputs_embeds return inputs_embeds
def forward( def forward(
@ -944,6 +978,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
# NOTE: In v1, inputs_embeds is always generated at model runner, this # NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility. # condition is for v0 compatibility.
elif inputs_embeds is None: elif inputs_embeds is None:
kwargs.update({"v0_path": True})
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings) vision_embeddings)

View File

@ -36,11 +36,11 @@ class NVLMProcessor(BaseInternVLProcessor):
def image_token_id(self) -> int: def image_token_id(self) -> int:
return self.tokenizer.get_vocab()[IMG_PAD] return self.tokenizer.get_vocab()[IMG_PAD]
def get_image_repl_features( def get_image_repl(
self, self,
feature_size: int, feature_size: int,
num_patches: Optional[int], num_patches: Optional[int],
) -> str: ) -> PromptUpdateDetails[str]:
if num_patches is None: if num_patches is None:
raise NotImplementedError("Embedding inputs are not supported") raise NotImplementedError("Embedding inputs are not supported")
@ -55,14 +55,9 @@ class NVLMProcessor(BaseInternVLProcessor):
# We include the start and end as well because "<Image><tile" is # We include the start and end as well because "<Image><tile" is
# tokenized as ["<Image", "><", "tile"], resulting in assertion error # tokenized as ["<Image", "><", "tile"], resulting in assertion error
# when trying to find "<tile" as a subsequence of "<Image><tile" # when trying to find "<tile" as a subsequence of "<Image><tile"
return "<Image>" + features + "</Image>" repl = "<Image>" + features + "</Image>"
def get_image_repl_full( return PromptUpdateDetails(full=repl, features=repl)
self,
feature_size: int,
num_patches: Optional[int],
) -> str:
return self.get_image_repl_features(feature_size, num_patches)
class NVLMProcessingInfo(BaseInternVLProcessingInfo): class NVLMProcessingInfo(BaseInternVLProcessingInfo):
@ -180,11 +175,11 @@ class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
if num_patches is not None: if num_patches is not None:
assert isinstance(num_patches, int) assert isinstance(num_patches, int)
repl = hf_processor.get_image_repl(feature_size, num_patches)
return PromptUpdateDetails( return PromptUpdateDetails(
full=hf_processor.get_image_repl_full(feature_size, full=repl.full + "\n",
num_patches) + "\n", features=repl.features + "\n",
features=hf_processor.get_image_repl_features(
feature_size, num_patches) + "\n",
) )
# See note in dummy data regarding why we have the extra newline # See note in dummy data regarding why we have the extra newline

View File

@ -103,13 +103,13 @@ The token sequence or text to update.
@dataclass @dataclass
class PromptUpdateDetails: class PromptUpdateDetails(Generic[_S]):
"""Details about the token sequence or text that are part of the update.""" """Details about the token sequence or text that are part of the update."""
full: PromptSeq full: _S
"""The full content.""" """The full content."""
features: PromptSeq features: _S
""" """
The part of the content that corresponds to feature placeholders; The part of the content that corresponds to feature placeholders;
this will be replaced by the output of the vision encoder during model this will be replaced by the output of the vision encoder during model
@ -117,7 +117,7 @@ class PromptUpdateDetails:
""" """
@staticmethod @staticmethod
def from_seq(seq: PromptSeq) -> "PromptUpdateDetails": def from_seq(seq: _S) -> "PromptUpdateDetails[_S]":
return PromptUpdateDetails(full=seq, features=seq) return PromptUpdateDetails(full=seq, features=seq)