[Bugfix] Fix embedding assignment for InternVL-based models (#15086)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Authored by Cyrus Leung on 2025-03-20 11:40:13 +08:00, committed by GitHub
parent 70e500cad9
commit ffa443afed
7 changed files with 123 additions and 106 deletions
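
InternVL-style feature placeholders can contain structural tokens (the <img>/</img> wrappers, NVLM-D's tile tags) in addition to the repeated IMG_CONTEXT patch tokens, so assigning vision embeddings to every placeholder position mis-assigns them. With this change the processor records, per image, the number of placeholder tokens (num_embeds) and a boolean mask of which of them are genuine patch slots (embed_is_patch); the model then scatters its patch features onto exactly those positions and selects them back out before merging with the text embeddings.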


@@ -169,7 +169,6 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
         model=model_name,
         max_model_len=2048,
         max_num_seqs=2,
-        # Default is False; setting it to True is not supported in V1 yet
         mm_processor_kwargs={"do_pan_and_scan": True},
         disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )


@@ -91,8 +91,6 @@ def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData:
         model=model_name,
         max_model_len=8192,
         max_num_seqs=2,
-        # Default is False; setting it to True is not supported in V1 yet
         mm_processor_kwargs={"do_pan_and_scan": True},
         limit_mm_per_prompt={"image": len(image_urls)},
     )


@@ -183,7 +183,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
         image_width: int,
         image_height: int,
         processor: Optional[Gemma3Processor],
-    ) -> PromptUpdateDetails:
+    ) -> PromptUpdateDetails[str]:
         if processor is None:
             processor = self.get_hf_processor()


@@ -249,20 +249,15 @@ class H2OVLProcessor(BaseInternVLProcessor):
     def image_token_id(self) -> int:
         return self.tokenizer.get_vocab()[IMG_CONTEXT]

-    def get_image_repl_features(
+    def get_image_repl(
         self,
         feature_size: int,
         num_patches: Optional[int],
-    ) -> str:
-        return IMG_CONTEXT * feature_size
+    ) -> PromptUpdateDetails[str]:
+        repl_features = IMG_CONTEXT * feature_size
+        repl_full = IMG_START + repl_features + IMG_END

-    def get_image_repl_full(
-        self,
-        feature_size: int,
-        num_patches: Optional[int],
-    ) -> str:
-        features = self.get_image_repl_features(feature_size, num_patches)
-        return IMG_START + features + IMG_END
+        return PromptUpdateDetails(full=repl_full, features=repl_features)

     def resolve_min_max_num(
         self,
@@ -501,12 +496,7 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
             if num_patches is not None:
                 assert isinstance(num_patches, int)

-            return PromptUpdateDetails(
-                full=hf_processor.get_image_repl_full(feature_size,
-                                                      num_patches),
-                features=hf_processor.get_image_repl_features(
-                    feature_size, num_patches),
-            )
+            return hf_processor.get_image_repl(feature_size, num_patches)

         return [
             PromptReplacement(

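For orientation, the consolidated get_image_repl bundles the full replacement text and its feature span into a single value. A sketch of what H2OVLProcessor.get_image_repl(feature_size=4, num_patches=None) would return, assuming the internvl.py constants IMG_START = '<img>', IMG_END = '</img>', and IMG_CONTEXT = '<IMG_CONTEXT>' (feature_size=4 is an arbitrary illustrative value):

from vllm.multimodal.processing import PromptUpdateDetails

repl = PromptUpdateDetails(
    # the text spliced into the prompt in place of '<image>'
    full='<img>' + '<IMG_CONTEXT>' * 4 + '</img>',
    # the span whose token positions receive vision-encoder outputs
    features='<IMG_CONTEXT>' * 4,
)
assert repl.features in repl.full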

@@ -9,14 +9,13 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property
-from typing import (List, Literal, Optional, Set, Tuple, TypedDict, TypeVar,
-                    Union)
+from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union

 import torch
 import torch.nn as nn
 import torchvision.transforms as T
 from PIL import Image
-from transformers import BatchFeature, PretrainedConfig, TensorType
+from transformers import BatchEncoding, PretrainedConfig, TensorType

 from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -36,10 +35,12 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import flatten_2d_lists

 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features

 IMG_START = '<img>'
 IMG_END = '</img>'
@@ -51,16 +52,26 @@ IMAGENET_STD = (0.229, 0.224, 0.225)

 class InternVLImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: torch.Tensor
+    pixel_values_flat: torch.Tensor
     """
     Shape:
     `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
     """

-    patches_per_image: List[int]
+    num_patches: torch.Tensor
+    """Shape: `(batch_size * num_images)`"""
+
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
-    List of number of total patches for each image in the batch.
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size, num_images, num_embeds)`
     """
+
+    num_embeds: Union[torch.Tensor, list[torch.Tensor]]
+    """Shape: `(batch_size, num_images)`"""


 class InternVLImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
@@ -286,19 +297,11 @@ class BaseInternVLProcessor(ABC):
         raise NotImplementedError

     @abstractmethod
-    def get_image_repl_features(
+    def get_image_repl(
         self,
         feature_size: int,
         num_patches: Optional[int],
-    ) -> str:
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_image_repl_full(
-        self,
-        feature_size: int,
-        num_patches: Optional[int],
-    ) -> str:
+    ) -> PromptUpdateDetails[str]:
         raise NotImplementedError

     def resolve_min_max_num(
@@ -394,7 +397,7 @@ class BaseInternVLProcessor(ABC):
         max_dynamic_patch: Optional[int] = None,
         dynamic_image_size: Optional[bool] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
-    ) -> BatchFeature:
+    ) -> Mapping[str, NestedTensors]:
         if text is None:
             text = []
         if not isinstance(text, list):
@@ -413,28 +416,41 @@
                 max_dynamic_patch=max_dynamic_patch,
                 dynamic_image_size=dynamic_image_size,
             )
-            image_inputs = {
-                "pixel_values_flat": torch.cat(pixel_values_lst),
-                "image_num_patches": list(map(len, pixel_values_lst)),
+            image_inputs: dict[str, NestedTensors] = {
+                "pixel_values_flat":
+                torch.cat(pixel_values_lst),
+                "image_num_patches":
+                torch.tensor([len(item) for item in pixel_values_lst]),
             }

+            tokenizer = self.tokenizer
+            image_token_id = self.image_token_id
+
+            num_embeds = list[int]()
+            embed_is_patch = list[torch.Tensor]()
+
             for pixel_values in pixel_values_lst:
                 num_patches = pixel_values.shape[0]
                 feature_size = num_patches * self.num_image_token

-                image_repl = self.get_image_repl_full(feature_size,
-                                                      num_patches)
-                text = [t.replace('<image>', image_repl, 1) for t in text]
+                image_repl = self.get_image_repl(feature_size, num_patches)
+                feature_tokens = tokenizer.encode(image_repl.features,
+                                                  add_special_tokens=False)
+
+                text = [t.replace('<image>', image_repl.full, 1) for t in text]
+                num_embeds.append(len(feature_tokens))
+                embed_is_patch.append(
+                    torch.tensor(feature_tokens) == image_token_id)
+
+            image_inputs["num_embeds"] = torch.tensor(num_embeds)
+            image_inputs["embed_is_patch"] = embed_is_patch

         text_inputs = self.tokenizer(text)

-        return BatchFeature(
-            {
-                **text_inputs,
-                **image_inputs,
-            },
-            tensor_type=return_tensors,
-        )
+        return {
+            **BatchEncoding(text_inputs, tensor_type=return_tensors),
+            **image_inputs,
+        }


 class InternVLProcessor(BaseInternVLProcessor):
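
The new bookkeeping above is the crux of the fix: the processor tokenizes only the features part of the replacement and records which of those positions hold the image context token. A minimal sketch of that computation (the token IDs are illustrative stand-ins for the real vocabulary):

import torch

IMG_CONTEXT_ID = 101  # stand-in for tokenizer.get_vocab()[IMG_CONTEXT]
TILE_TAG_ID = 102     # a structural token inside the placeholder, as in NVLM-D

# Pretend tokenizer.encode(image_repl.features, add_special_tokens=False)
# produced this mix of structural tokens and patch slots:
feature_tokens = [TILE_TAG_ID, IMG_CONTEXT_ID, IMG_CONTEXT_ID,
                  TILE_TAG_ID, IMG_CONTEXT_ID, IMG_CONTEXT_ID]

num_embeds = len(feature_tokens)  # 6 placeholder positions in the prompt
embed_is_patch = torch.tensor(feature_tokens) == IMG_CONTEXT_ID
# tensor([False, True, True, False, True, True]): only 4 of the 6
# positions should receive vision-encoder outputs.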
@@ -443,20 +459,15 @@ class InternVLProcessor(BaseInternVLProcessor):
     def image_token_id(self) -> int:
         return self.tokenizer.get_vocab()[IMG_CONTEXT]

-    def get_image_repl_features(
+    def get_image_repl(
         self,
         feature_size: int,
         num_patches: Optional[int],
-    ) -> str:
-        return IMG_CONTEXT * feature_size
+    ) -> PromptUpdateDetails[str]:
+        repl_features = IMG_CONTEXT * feature_size
+        repl_full = IMG_START + repl_features + IMG_END

-    def get_image_repl_full(
-        self,
-        feature_size: int,
-        num_patches: Optional[int],
-    ) -> str:
-        features = self.get_image_repl_features(feature_size, num_patches)
-        return IMG_START + features + IMG_END
+        return PromptUpdateDetails(full=repl_full, features=repl_features)


 class BaseInternVLProcessingInfo(BaseProcessingInfo):
@@ -566,16 +577,15 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
         prompt: str,
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
-    ) -> BatchFeature:
+    ) -> Mapping[str, NestedTensors]:
         processed_outputs = super()._call_hf_processor(
             prompt=prompt,
             mm_data=mm_data,
             mm_kwargs=mm_kwargs,
         )

-        image_token_id = self.info.get_hf_processor(**mm_kwargs).image_token_id
-        image_data = mm_data.get("images", [])
-        assert isinstance(image_data, list)
+        hf_processor = self.info.get_hf_processor(**mm_kwargs)
+        image_token_id = hf_processor.image_token_id

         # Since there may be extra tokens in the feature placeholders,
         # we need to pass the image token ID to the model to select the
@@ -586,7 +596,7 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
     def _get_mm_fields_config(
         self,
-        hf_inputs: BatchFeature,
+        hf_inputs: Mapping[str, NestedTensors],
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
@@ -596,6 +606,8 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
             pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
                 "image", image_num_patches),
             image_num_patches=MultiModalFieldConfig.batched("image"),
+            embed_is_patch=MultiModalFieldConfig.batched("image"),
+            num_embeds=MultiModalFieldConfig.batched("image"),
             image_embeds=MultiModalFieldConfig.batched("image"),
             image_token_id=MultiModalFieldConfig.shared("image", num_images),
         )
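
For context on the two field configs used here: batched keeps one entry per image, while flat_from_sizes carves per-image slices out of the flattened patch tensor, using image_num_patches as the split sizes. A toy sketch of that slicing (all shapes are illustrative):

import torch

image_num_patches = torch.tensor([3, 5])         # two images
pixel_values_flat = torch.randn(8, 3, 448, 448)  # 3 + 5 patches; 448 illustrative

# flat_from_sizes-style recovery of per-image patch tensors:
per_image = pixel_values_flat.split(image_num_patches.tolist())
assert [p.shape[0] for p in per_image] == [3, 5]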
@@ -637,12 +649,7 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
             if num_patches is not None:
                 assert isinstance(num_patches, int)

-            return PromptUpdateDetails(
-                full=hf_processor.get_image_repl_full(feature_size,
-                                                      num_patches),
-                features=hf_processor.get_image_repl_features(
-                    feature_size, num_patches),
-            )
+            return hf_processor.get_image_repl(feature_size, num_patches)

         return [
             PromptReplacement(
@@ -832,6 +839,8 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
             self, **kwargs: object) -> Optional[InternVLImageInputs]:
         pixel_values_flat = kwargs.pop("pixel_values_flat", None)
         image_num_patches = kwargs.pop("image_num_patches", None)
+        embed_is_patch = kwargs.pop("embed_is_patch", None)
+        num_embeds = kwargs.pop("num_embeds", None)
         image_embeds = kwargs.pop("image_embeds", None)

         if pixel_values_flat is None and image_embeds is None:
@@ -858,35 +867,47 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
             if not isinstance(image_num_patches, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of image_num_patches. "
-                                 f"Got type: {type(pixel_values_flat)}")
+                                 f"Got type: {type(image_num_patches)}")
+
+            if not isinstance(embed_is_patch, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of embed_is_patch. "
+                                 f"Got type: {type(embed_is_patch)}")
+
+            if not isinstance(num_embeds, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of num_embeds. "
+                                 f"Got type: {type(num_embeds)}")
+
+            pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
+            image_num_patches = flatten_bn(image_num_patches, concat=True)

             return InternVLImagePixelInputs(
                 type="pixel_values",
-                data=self._validate_pixel_values(
-                    flatten_bn(pixel_values_flat, concat=True)),
-                patches_per_image=flatten_bn(image_num_patches,
-                                             concat=True).tolist())
+                pixel_values_flat=self._validate_pixel_values(
+                    pixel_values_flat),
+                num_patches=image_num_patches,
+                embed_is_patch=embed_is_patch,
+                num_embeds=num_embeds,
+            )

         raise AssertionError("This line should be unreachable.")

     def _process_image_input(
         self,
         image_input: InternVLImageInputs,
-    ) -> tuple[torch.Tensor, ...]:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         if image_input["type"] == "image_embeds":
             return image_input["data"]

         assert self.vision_model is not None

-        image_embeds = self.extract_feature(image_input["data"])
+        image_embeds = self.extract_feature(image_input["pixel_values_flat"])

-        patches_per_image = image_input["patches_per_image"]
+        num_patches = image_input["num_patches"]

         # Only one image in the current batch
-        if len(patches_per_image) == 1:
-            image_embeds = image_embeds.view(
+        if len(num_patches) == 1:
+            return image_embeds.view(
                 -1, self.config.text_config.hidden_size).unsqueeze(0)
-            return image_embeds

         # NOTE: Image embeddings are split into separate tensors for each image
         # by the size of each embedding.
@@ -894,10 +915,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         image_embeds = image_embeds.view(-1,
                                          self.config.text_config.hidden_size)
         image_feature_sizes = [
-            num_patches * feature_size for num_patches in patches_per_image
+            num_patches * feature_size for num_patches in num_patches
         ]
-        image_embeds = image_embeds.split(image_feature_sizes)
-        return image_embeds
+        return image_embeds.split(image_feature_sizes)

     def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
         if self.is_mono:
@@ -911,8 +931,19 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
-        vision_embeddings = self._process_image_input(image_input)
-        return vision_embeddings
+
+        image_features = self._process_image_input(image_input)
+
+        if (kwargs.get("v0_path", False)
+                or image_input["type"] != "pixel_values"):
+            return image_features
+
+        return flatten_2d_lists(
+            scatter_patch_features(*args) for args in zip(
+                image_features,
+                image_input["num_embeds"],
+                image_input["embed_is_patch"],
+            ))
     def get_input_embeddings(
         self,
@@ -924,8 +955,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
             assert self.img_context_token_id is not None
             self._set_visual_token_mask(input_ids)
             inputs_embeds = merge_multimodal_embeddings(
-                input_ids, inputs_embeds, multimodal_embeddings,
-                self.img_context_token_id)
+                input_ids,
+                inputs_embeds,
+                select_patch_features(multimodal_embeddings),
+                self.img_context_token_id,
+            )
         return inputs_embeds
     def forward(
@@ -944,6 +978,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         # NOTE: In v1, inputs_embeds is always generated at model runner, this
         # condition is for v0 compatibility.
         elif inputs_embeds is None:
+            kwargs.update({"v0_path": True})
             vision_embeddings = self.get_multimodal_embeddings(**kwargs)
             inputs_embeds = self.get_input_embeddings(input_ids,
                                                       vision_embeddings)
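
Taken together, get_multimodal_embeddings now scatters each image's patch features across all of its placeholder positions, and get_input_embeddings selects only the patch positions back out before merge_multimodal_embeddings writes them over the IMG_CONTEXT slots. A self-contained sketch of that round trip, assuming the helpers use a NaN sentinel for non-patch rows (illustrative; the real implementations are scatter_patch_features/select_patch_features in .vision):

import torch

hidden = 8
embed_is_patch = torch.tensor([False, True, True, False, True, True])
patch_features = torch.randn(int(embed_is_patch.sum()), hidden)  # (4, hidden)

# Scatter: one row per placeholder position, NaN where no patch feature goes.
scattered = torch.full((embed_is_patch.numel(), hidden), torch.nan)
scattered[embed_is_patch] = patch_features

# Select: recover exactly the rows holding real patch features.
selected = scattered[~scattered.isnan().any(dim=-1)]
assert torch.equal(selected, patch_features)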


@@ -36,11 +36,11 @@ class NVLMProcessor(BaseInternVLProcessor):
     def image_token_id(self) -> int:
         return self.tokenizer.get_vocab()[IMG_PAD]

-    def get_image_repl_features(
+    def get_image_repl(
         self,
         feature_size: int,
         num_patches: Optional[int],
-    ) -> str:
+    ) -> PromptUpdateDetails[str]:
         if num_patches is None:
             raise NotImplementedError("Embedding inputs are not supported")
@@ -55,14 +55,9 @@ class NVLMProcessor(BaseInternVLProcessor):
         # We include the start and end as well because "<Image><tile" is
         # tokenized as ["<Image", "><", "tile"], resulting in assertion error
         # when trying to find "<tile" as a subsequence of "<Image><tile"
-        return "<Image>" + features + "</Image>"
+        repl = "<Image>" + features + "</Image>"

-    def get_image_repl_full(
-        self,
-        feature_size: int,
-        num_patches: Optional[int],
-    ) -> str:
-        return self.get_image_repl_features(feature_size, num_patches)
+        return PromptUpdateDetails(full=repl, features=repl)


 class NVLMProcessingInfo(BaseInternVLProcessingInfo):
@@ -180,11 +175,11 @@ class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
             if num_patches is not None:
                 assert isinstance(num_patches, int)

+            repl = hf_processor.get_image_repl(feature_size, num_patches)
+
             return PromptUpdateDetails(
-                full=hf_processor.get_image_repl_full(feature_size,
-                                                      num_patches) + "\n",
-                features=hf_processor.get_image_repl_features(
-                    feature_size, num_patches) + "\n",
+                full=repl.full + "\n",
+                features=repl.features + "\n",
             )

         # See note in dummy data regarding why we have the extra newline


@@ -103,13 +103,13 @@ The token sequence or text to update.

 @dataclass
-class PromptUpdateDetails:
+class PromptUpdateDetails(Generic[_S]):
     """Details about the token sequence or text that are part of the update."""

-    full: PromptSeq
+    full: _S
     """The full content."""

-    features: PromptSeq
+    features: _S
     """
     The part of the content that corresponds to feature placeholders;
     this will be replaced by the output of the vision encoder during model
@@ -117,7 +117,7 @@ class PromptUpdateDetails:
     """

     @staticmethod
-    def from_seq(seq: PromptSeq) -> "PromptUpdateDetails":
+    def from_seq(seq: _S) -> "PromptUpdateDetails[_S]":
         return PromptUpdateDetails(full=seq, features=seq)
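
With the Generic[_S] parameter, the element type of full/features now flows through construction, so PromptUpdateDetails[str] and PromptUpdateDetails[list[int]] are both valid instantiations. A usage sketch (the token IDs are illustrative):

from vllm.multimodal.processing import PromptUpdateDetails

# Text-based update: inferred as PromptUpdateDetails[str].
text_repl = PromptUpdateDetails.from_seq("<img><IMG_CONTEXT></img>")

# Token-based update: inferred as PromptUpdateDetails[list[int]].
token_repl = PromptUpdateDetails.from_seq([92544, 92546, 92545])

assert text_repl.full == text_repl.features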