[VLM] Initialize video input support for InternVL models (#18499)

Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Isotr0py 2025-05-25 12:51:25 +08:00 committed by GitHub
parent 6ab681bcbe
commit 75f81750f3
10 changed files with 596 additions and 62 deletions

View File

@ -527,7 +527,7 @@ Specified using `--task generate`.
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | ✅︎ | ✅︎\* | |
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3` etc. | ✅︎ | ✅︎ | |
| `InternVLChatModel` | InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> | `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | |
| `InternVLChatModel` | InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | |
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | ✅︎ | | |
| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | ✅︎ | ✅︎ | |
| `LlavaForConditionalGeneration` | LLaVA-1.5 | T + I<sup>E+</sup> | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), etc. | ✅︎ | ✅︎ | |
@ -577,6 +577,9 @@ Specified using `--task generate`.
This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.
!!! note
Only `InternVLChatModel` with a Qwen2.5 text backbone (`OpenGVLab/InternVL3-2B`, `OpenGVLab/InternVL2_5-1B`, etc.) currently supports video inputs; a usage sketch follows these notes.
!!! note
`h2oai/h2ovl-mississippi-2b` will be available in V1 once we support head size 80.
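
The sketch below is not part of this commit; it shows what the newly documented video support looks like from the offline `LLM` API, mirroring the updated `run_internvl` example. The prompt template, frame count, and resolution are illustrative assumptions; in practice the template should come from `tokenizer.apply_chat_template`.

```python
# Illustrative sketch, not part of this commit: offline video inference with
# an InternVL3 checkpoint. Prompt template, frame count, and resolution are
# assumptions chosen for readability.
import numpy as np
from vllm import LLM, SamplingParams

llm = LLM(
    model="OpenGVLab/InternVL3-2B",
    trust_remote_code=True,
    max_model_len=8192,
    limit_mm_per_prompt={"video": 1},
)

# Eight dummy RGB frames; in practice, decode frames from a real video file.
video = np.random.randint(0, 255, size=(8, 480, 640, 3), dtype=np.uint8)

prompt = ("<|im_start|>user\n<video>\nDescribe this video.<|im_end|>\n"
          "<|im_start|>assistant\n")
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"video": video}},
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```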

View File

@ -330,22 +330,26 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData:
# InternVL
def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "OpenGVLab/InternVL2-2B"
model_name = "OpenGVLab/InternVL3-2B"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=4096,
max_model_len=8192,
limit_mm_per_prompt={modality: 1},
)
if modality == "image":
placeholder = "<image>"
elif modality == "video":
placeholder = "<video>"
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
messages = [[{
'role': 'user',
'content': f"<image>\n{question}"
'content': f"{placeholder}\n{question}"
}] for question in questions]
prompts = tokenizer.apply_chat_template(messages,
tokenize=False,
@ -357,6 +361,9 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
# https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py
stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
stop_token_ids = [
token_id for token_id in stop_token_ids if token_id is not None
]
return ModelRequestData(
engine_args=engine_args,

View File

@ -349,6 +349,17 @@ VLM_TEST_SETTINGS = {
use_tokenizer_eos=True,
patch_hf_runner=model_utils.internvl_patch_hf_runner,
),
"intern_vl-video": VLMTestInfo(
models=[
"OpenGVLab/InternVL3-1B",
],
test_type=VLMTestType.VIDEO,
prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501
video_idx_to_prompt=lambda idx: "<video>",
max_model_len=8192,
use_tokenizer_eos=True,
patch_hf_runner=model_utils.internvl_patch_hf_runner,
),
"kimi_vl": VLMTestInfo(
models=["moonshotai/Kimi-VL-A3B-Instruct"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),

View File

@ -7,6 +7,8 @@ import types
from pathlib import PosixPath
from typing import Optional, Union
import numpy as np
import numpy.typing as npt
import regex as re
import torch
from PIL.Image import Image
@ -495,30 +497,74 @@ def internvl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
self.max_num = self.config.max_dynamic_patch
self.image_size = self.vision_config.image_size
def __call__(self, text: str, images: Union[Image, list[Image]],
**kwargs):
def __call__(
self,
text: str,
images: Union[Image, list[Image]] = None,
videos: Union[npt.NDArray, list[npt.NDArray]] = None,
**kwargs,
):
from vllm.model_executor.models.internvl import (
IMG_CONTEXT, IMG_END, IMG_START,
image_to_pixel_values_internvl)
image_to_pixel_values_internvl, video_to_pixel_values_internvl)
images = [images] if isinstance(images, Image) else images
pixel_values = [
image_to_pixel_values_internvl(
image,
input_size=self.image_size,
min_num=self.min_num,
max_num=self.max_num,
use_thumbnail=self.use_thumbnail,
) for image in images
]
num_patches_list = [
pixel_value.shape[0] for pixel_value in pixel_values
]
videos = [videos] if isinstance(videos, np.ndarray) else videos
if images is not None:
pixel_values_images = [
image_to_pixel_values_internvl(
image,
input_size=self.image_size,
min_num=self.min_num,
max_num=self.max_num,
use_thumbnail=self.use_thumbnail,
) for image in images
]
num_patches_images = [
pixel_value.shape[0] for pixel_value in pixel_values_images
]
else:
pixel_values_images, num_patches_images = [], []
if videos is not None:
pixel_values_videos = [
video_to_pixel_values_internvl(
video,
input_size=self.image_size,
min_num=1,
max_num=1,
use_thumbnail=False,
) for video in videos
]
num_patches_videos = [
pixel_value.shape[0] for pixel_value in pixel_values_videos
]
else:
pixel_values_videos, num_patches_videos = [], []
pixel_values = []
while ("<image>" in text) or ("<video>" in text):
image_index = text.find("<image>")
video_index = text.find("<video>")
if image_index == -1 or (video_index > -1
and video_index < image_index):
num_patches = num_patches_videos.pop(0)
pixel_values.append(pixel_values_videos.pop(0))
context_tokens = IMG_START + \
IMG_CONTEXT * self.num_image_token + IMG_END
video_tokens = ''.join([
f'Frame{i+1}: {context_tokens}'
for i in range(num_patches)
])
text = text.replace('<video>', video_tokens, 1)
else:
num_patches = num_patches_images.pop(0)
pixel_values.append(pixel_values_images.pop(0))
context_tokens = IMG_CONTEXT * self.num_image_token \
* num_patches
image_tokens = IMG_START + context_tokens + IMG_END
text = text.replace('<image>', image_tokens, 1)
pixel_values = torch.cat(pixel_values, dim=0)
for num_patches in num_patches_list:
context_tokens = IMG_CONTEXT * self.num_image_token \
* num_patches
image_tokens = IMG_START + context_tokens + IMG_END
text = text.replace('<image>', image_tokens, 1)
prompt = self.tokenizer(text, return_tensors="pt")
prompt.update({"pixel_values": pixel_values})
return prompt
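
As a standalone toy illustration of the interleaved replacement loop above (marker strings and token counts are made up, not InternVL's real values), the expansion of a `<video>` placeholder looks like this:

```python
# Toy illustration of the per-frame expansion in the patched HF runner above;
# marker strings and num_image_token are made-up values, not InternVL's.
IMG_START, IMG_END, IMG_CONTEXT = "<img>", "</img>", "<ctx>"
num_image_token = 3
num_frames = 2

context_tokens = IMG_START + IMG_CONTEXT * num_image_token + IMG_END
video_tokens = "".join(f"Frame{i + 1}: {context_tokens}"
                       for i in range(num_frames))

text = "<video>\nWhat happens in this clip?"
print(text.replace("<video>", video_tokens, 1))
# Frame1: <img><ctx><ctx><ctx></img>Frame2: <img><ctx><ctx><ctx></img>
# What happens in this clip?
```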

View File

@ -258,6 +258,7 @@ def _test_processing_correctness_mistral(
"ibm-granite/granite-speech-3.3-8b",
"h2oai/h2ovl-mississippi-800m",
"OpenGVLab/InternVL2-1B",
"OpenGVLab/InternVL3-1B",
"HuggingFaceM4/Idefics3-8B-Llama3",
"HuggingFaceTB/SmolVLM2-2.2B-Instruct",
"moonshotai/Kimi-VL-A3B-Instruct",

View File

@ -334,7 +334,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
max_transformers_version="4.48", # noqa: E501
transformers_version_reason="HF model is not compatible."), # noqa: E501
"InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
extras={"2B": "OpenGVLab/InternVL2-2B"}, # noqa: E501
extras={"2B": "OpenGVLab/InternVL2-2B",
"3.0": "OpenGVLab/InternVL3-1B"}, # noqa: E501
trust_remote_code=True),
"Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501
{"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}), # noqa: E501

View File

@ -556,6 +556,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return "(<audio>./</audio>)"
raise TypeError(f"Unknown model type: {model_type}")
elif modality == "video":
if model_type == "internvl_chat":
return "<video>"
if model_type in ("qwen2_vl", "qwen2_5_vl"):
return "<|vision_start|><|video_pad|><|vision_end|>"
if model_type == "qwen2_5_omni":

View File

@ -25,9 +25,10 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
from .intern_vit import InternVisionModel
from .internvl import (IMG_CONTEXT, IMG_END, IMG_START,
BaseInternVLDummyInputsBuilder,
BaseInternVLMultiModalProcessor,
BaseInternVLProcessingInfo, BaseInternVLProcessor,
InternVLChatModel, InternVLDummyInputsBuilder,
InternVLMultiModalProcessor, build_transform,
InternVLChatModel, build_transform,
find_closest_aspect_ratio, get_internvl_target_ratios)
@ -430,8 +431,8 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
)
class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
):
class H2OVLMultiModalProcessor(
BaseInternVLMultiModalProcessor[H2OVLProcessingInfo]):
def _get_prompt_updates(
self,
@ -514,7 +515,7 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
@MULTIMODAL_REGISTRY.register_processor(
H2OVLMultiModalProcessor,
info=H2OVLProcessingInfo,
dummy_inputs=InternVLDummyInputsBuilder)
dummy_inputs=BaseInternVLDummyInputsBuilder)
class H2OVLChatModel(InternVLChatModel):
def _init_vision_model(

View File

@ -8,8 +8,9 @@
# --------------------------------------------------------
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, TypedDict, TypeVar, Union
from typing import Any, Literal, Optional, TypedDict, TypeVar, Union
import numpy.typing as npt
import torch
import torch.nn as nn
import torchvision.transforms as T
@ -74,6 +75,33 @@ InternVLImageInputs = Union[InternVLImagePixelInputs,
InternVLImageEmbeddingInputs]
class InternVLVideoPixelInputs(TypedDict):
type: Literal["pixel_values_videos"]
pixel_values_flat: torch.Tensor
"""
Shape:
`(batch_size * num_video * num_frames, num_channels, height, width)`
"""
num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
class InternVLVideoEmbeddingInputs(TypedDict):
type: Literal["video_embeds"]
data: Union[torch.Tensor, list[torch.Tensor]]
"""
A tensor of shape `(num_videos, total_video_feature_size, hidden_size)`
or a list of tensors of shape `(total_video_feature_size, hidden_size)`.
`hidden_size` must match the hidden size of the language model backbone.
"""
InternVLVideoInputs = Union[InternVLVideoPixelInputs,
InternVLVideoEmbeddingInputs]
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size: int):
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
@ -231,6 +259,33 @@ def image_to_pixel_values_internvl(
return pixel_values
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def video_to_pixel_values_internvl(
video: npt.NDArray,
*,
input_size: int,
min_num: int,
max_num: int,
use_thumbnail: bool,
) -> torch.Tensor:
target_ratios = get_internvl_target_ratios(min_num, max_num)
transform = build_transform(input_size=input_size)
frames_list = list[Image.Image]()
for frame in video:
pil_frame = dynamic_preprocess_internvl(
Image.fromarray(frame, mode="RGB"),
target_ratios=target_ratios,
image_size=input_size,
use_thumbnail=use_thumbnail,
)
assert len(pil_frame) == 1
frames_list.extend(pil_frame)
pixel_values = torch.stack([transform(image) for image in frames_list])
return pixel_values
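
Because every frame is tiled with `min_num=max_num=1` and no thumbnail, this helper yields exactly one tile per frame. A hypothetical usage sketch (the import path follows this diff, and 448 is the usual InternVL tile size; both are assumptions here):

```python
# Hypothetical usage sketch of video_to_pixel_values_internvl defined above.
import numpy as np
from vllm.model_executor.models.internvl import video_to_pixel_values_internvl

video = np.random.randint(0, 255, size=(8, 360, 640, 3), dtype=np.uint8)
pixel_values = video_to_pixel_values_internvl(
    video, input_size=448, min_num=1, max_num=1, use_thumbnail=False)
print(pixel_values.shape)  # expected: torch.Size([8, 3, 448, 448])
```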
class BaseInternVLProcessor(ABC):
"""
This model doesn't define its own HF processor,
@ -375,24 +430,14 @@ class BaseInternVLProcessor(ABC):
) for image in images
]
def __call__(
def _preprocess_image(
self,
text: Optional[Union[str, list[str]]] = None,
images: Optional[Union[Image.Image, list[Image.Image]]] = None,
text: list[str],
images: list[Image.Image],
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
) -> Mapping[str, NestedTensors]:
if text is None:
text = []
if not isinstance(text, list):
text = [text]
if images is None:
images = []
if not isinstance(images, list):
images = [images]
) -> tuple[list[str], dict[str, torch.Tensor]]:
if len(images) == 0:
image_inputs = {}
else:
@ -415,6 +460,34 @@ class BaseInternVLProcessor(ABC):
image_repl = self.get_image_repl(feature_size, num_patches)
text = [t.replace('<image>', image_repl.full, 1) for t in text]
return text, image_inputs
def _make_batch_input(self,
input_item: Optional[Union[Any, list[Any]]] = None):
if input_item is None:
input_item = []
if not isinstance(input_item, list):
input_item = [input_item]
return input_item
def __call__(
self,
text: Optional[Union[str, list[str]]] = None,
images: Optional[Union[Image.Image, list[Image.Image]]] = None,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
) -> Mapping[str, NestedTensors]:
text, images = [self._make_batch_input(x) for x in (text, images)]
text, image_inputs = self._preprocess_image(
text=text,
images=images,
min_dynamic_patch=min_dynamic_patch,
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
)
text_inputs = self.tokenizer(text)
@ -425,11 +498,133 @@ class BaseInternVLProcessor(ABC):
class InternVLProcessor(BaseInternVLProcessor):
"""
HF Processor for InternVLChatModel with extended video processing logic.
Code for video processing is adapted from video example:
https://huggingface.co/OpenGVLab/InternVL3-1B#inference-with-transformers
"""
def __init__(
self,
config: PretrainedConfig,
tokenizer: AnyTokenizer,
*,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
video_token: Optional[str] = None,
) -> None:
super().__init__(
config=config,
tokenizer=tokenizer,
min_dynamic_patch=min_dynamic_patch,
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
)
# add extra video token for video processing
self.video_token = video_token
@property
def image_token_id(self) -> int:
return self.tokenizer.get_vocab()[IMG_CONTEXT]
@property
def video_token_id(self) -> Optional[int]:
if self.video_token is None:
return None
return self.tokenizer.get_vocab().get(self.video_token, None)
@property
def supports_video(self) -> bool:
return self.video_token_id is not None
def _videos_to_pixel_values_lst(
self,
videos: list[npt.NDArray],
dynamic_image_size: Optional[bool] = None,
) -> list[torch.Tensor]:
min_num, max_num = self.resolve_min_max_num(
min_dynamic_patch=1,
max_dynamic_patch=1,
dynamic_image_size=dynamic_image_size,
use_thumbnail=False,  # No thumbnail for video frames
)
return [
video_to_pixel_values_internvl(
video,
input_size=self.image_size,
min_num=min_num,
max_num=max_num,
use_thumbnail=False,
) for video in videos
]
def _preprocess_video(
self,
text: list[str],
videos: list[npt.NDArray],
dynamic_image_size: Optional[bool] = None,
):
if len(videos) == 0 or not self.supports_video:
video_inputs = {}
else:
pixel_values_lst_video = self._videos_to_pixel_values_lst(
videos,
dynamic_image_size=dynamic_image_size,
)
video_inputs: dict[str, NestedTensors] = {
"pixel_values_flat_video":
torch.cat(pixel_values_lst_video),
"video_num_patches":
torch.tensor([len(item) for item in pixel_values_lst_video]),
}
for pixel_values in pixel_values_lst_video:
num_patches = pixel_values.shape[0]
video_repl = self.get_video_repl(self.num_image_token,
num_patches, self.video_token)
text = [t.replace('<video>', video_repl.full, 1) for t in text]
return text, video_inputs
def __call__(
self,
text: Optional[Union[str, list[str]]] = None,
images: Optional[Union[Image.Image, list[Image.Image]]] = None,
videos: Optional[Union[npt.NDArray, list[npt.NDArray]]] = None,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
) -> Mapping[str, NestedTensors]:
text, images, videos = [
self._make_batch_input(x) for x in (text, images, videos)
]
text, image_inputs = self._preprocess_image(
text=text,
images=images,
min_dynamic_patch=min_dynamic_patch,
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
)
text, video_inputs = self._preprocess_video(
text=text,
videos=videos,
dynamic_image_size=dynamic_image_size,
)
text_inputs = self.tokenizer(text)
return {
**BatchEncoding(text_inputs, tensor_type=return_tensors),
**image_inputs,
**video_inputs,
}
def get_image_repl(
self,
feature_size: int,
@ -440,8 +635,24 @@ class InternVLProcessor(BaseInternVLProcessor):
return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
def get_video_repl(
self,
feature_size: int,
num_patches: Optional[int] = None,
video_context_token: str = IMG_CONTEXT,
) -> PromptUpdateDetails[str]:
repl_features = video_context_token * self.num_image_token
repl_features_with_sep = IMG_START + repl_features + IMG_END
# num_patches is equal to num_frames
repl_full = ''.join([
f'Frame{i+1}: {repl_features_with_sep}' for i in range(num_patches)
])
return PromptUpdateDetails.select_text(repl_full, video_context_token)
class BaseInternVLProcessingInfo(BaseProcessingInfo):
"""Basic image-only ProcessingInfo for InternVL-style models."""
@abstractmethod
def get_hf_processor(
@ -497,11 +708,22 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo):
return largest_feature_pinpoint
def get_max_image_tokens(self) -> int:
processor = self.get_hf_processor()
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=processor,
)
_I = TypeVar("_I", bound=BaseInternVLProcessingInfo)
class InternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
class BaseInternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
"""Basic image-only DummyInputsBuilder for InternVL-style models."""
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
@ -525,7 +747,8 @@ class InternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
}
class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
""" Basic image-only MultiModalProcessor for InternVL-style models."""
def _call_hf_processor(
self,
@ -614,6 +837,38 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
class InternVLProcessingInfo(BaseInternVLProcessingInfo):
"""InternVL ProcessingInfo extended for video processing"""
@property
def supports_video(self):
return self.get_hf_processor().supports_video
def get_supported_mm_limits(self):
video_limit = {"video": None} if self.supports_video else {}
return {**super().get_supported_mm_limits(), **video_limit}
def get_video_token(self) -> Optional[str]:
text_model_type = self.get_hf_config().get_text_config().model_type
if text_model_type == "qwen2":
return "<|video_pad|>"
return None
def get_num_frames_with_most_features(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> int:
max_images = mm_counts.get("image", 0)
max_videos = mm_counts.get("video", 0)
processor = self.get_hf_processor()
max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = (seq_len -
max_image_tokens) // processor.num_image_token
max_frames_per_video = max_total_frames // max(max_videos, 1)
return max(max_frames_per_video, 1)
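
As a worked example of this budget (the numbers are hypothetical but typical; 256 tokens per frame corresponds to a 448-pixel tile with 0.5x pixel-shuffle downsampling):

```python
# Hypothetical numbers illustrating the frame budget computed above.
seq_len = 8192          # model context length
max_image_tokens = 0    # no images in the dummy prompt
num_image_token = 256   # tokens per frame tile (assumed typical value)
max_videos = 1

max_total_frames = (seq_len - max_image_tokens) // num_image_token  # 32
max_frames_per_video = max_total_frames // max(max_videos, 1)       # 32
print(max(max_frames_per_video, 1))
```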
def get_hf_processor(
self,
@ -630,6 +885,8 @@ class InternVLProcessingInfo(BaseInternVLProcessingInfo):
if dynamic_image_size is not None:
kwargs["dynamic_image_size"] = dynamic_image_size
kwargs["video_token"] = self.get_video_token()
return self.ctx.init_processor(
InternVLProcessor,
config=self.get_hf_config(),
@ -638,6 +895,121 @@ class InternVLProcessingInfo(BaseInternVLProcessingInfo):
)
class InternVLDummyInputsBuilder(
BaseInternVLDummyInputsBuilder[InternVLProcessingInfo]):
"""InternVL DummyInputsBuilder extended for video support"""
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_videos = mm_counts.get("video", 0)
return super().get_dummy_text(mm_counts) + "<video>" * num_videos
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
dummy_image = super().get_dummy_mm_data(seq_len=seq_len,
mm_counts=mm_counts)
if self.info.supports_video:
config = self.info.get_hf_config()
image_size: int = config.vision_config.image_size
target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len, mm_counts)
num_videos = mm_counts.get("video", 0)
dummy_video = {
"video":
self._get_dummy_videos(width=image_size,
height=image_size,
num_frames=target_num_frames,
num_videos=num_videos)
}
else:
dummy_video = {}
return {**dummy_image, **dummy_video}
class InternVLMultiModalProcessor(
BaseInternVLMultiModalProcessor[InternVLProcessingInfo]):
"""InternVL MultiModalProcessor extended for video support"""
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> Mapping[str, NestedTensors]:
processed_outputs = super()._call_hf_processor(prompt, mm_data,
mm_kwargs)
hf_processor = self.info.get_hf_processor(**mm_kwargs)
if self.info.supports_video and (
video_token_id := hf_processor.video_token_id) is not None:
processed_outputs["video_token_id"] = torch.tensor(video_token_id)
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: Mapping[str, NestedTensors],
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
image_fields = super()._get_mm_fields_config(hf_inputs,
hf_processor_mm_kwargs)
if self.info.supports_video:
video_num_patches = hf_inputs.get("video_num_patches",
torch.empty(0))
num_videos = len(video_num_patches)
video_fields = dict(
pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_patches),
video_num_patches=MultiModalFieldConfig.batched("video"),
video_token_id=MultiModalFieldConfig.shared(
"video", num_videos),
)
else:
video_fields = {}
return image_fields | video_fields
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
prompt_repl: list[PromptUpdate] = super()._get_prompt_updates(
mm_items, hf_processor_mm_kwargs, out_mm_kwargs)
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
if "video_num_patches" in out_mm_kwargs:
video_num_patches = out_mm_kwargs["video_num_patches"]
assert isinstance(video_num_patches, torch.Tensor)
video_num_patches = video_num_patches.tolist()
else:
video_num_patches = []
def get_video_replacement_internvl(item_idx: int):
feature_size = hf_processor.num_image_token
num_patches = video_num_patches[item_idx]
if num_patches is not None:
assert isinstance(num_patches, int)
return hf_processor.get_video_repl(
feature_size,
num_patches,
video_context_token=hf_processor.video_token)
if self.info.supports_video:
prompt_repl.append(
PromptReplacement(
modality="video",
target="<video>",
replacement=get_video_replacement_internvl,
))
return prompt_repl
@MULTIMODAL_REGISTRY.register_processor(
InternVLMultiModalProcessor,
info=InternVLProcessingInfo,
@ -681,6 +1053,8 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self.mlp1 = self._init_mlp1(config)
self.img_context_token_id = None
self.video_context_token_id = None
self.visual_token_mask = None
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@ -825,10 +1199,55 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
raise AssertionError("This line should be unreachable.")
def _parse_and_validate_video_input(
self, **kwargs: object) -> Optional[InternVLVideoPixelInputs]:
pixel_values_flat_video = kwargs.pop("pixel_values_flat_video", None)
video_num_patches = kwargs.pop("video_num_patches", None)
video_embeds = kwargs.pop("image_embeds", None)
if pixel_values_flat_video is None and video_embeds is None:
return None
if video_embeds is not None:
if not isinstance(video_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of video embeddings. "
f"Got type: {type(video_embeds)}")
return InternVLImageEmbeddingInputs(
type="video_embeds",
data=flatten_bn(video_embeds),
)
video_token_id = kwargs["video_token_id"]
assert isinstance(video_token_id, torch.Tensor)
self.video_context_token_id = video_token_id.flatten().unique().item()
if pixel_values_flat_video is not None:
if not isinstance(pixel_values_flat_video, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values_flat_video)}")
if not isinstance(video_num_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of image_num_patches. "
f"Got type: {type(video_num_patches)}")
pixel_values_flat_video = flatten_bn(pixel_values_flat_video,
concat=True)
video_num_patches = flatten_bn(video_num_patches, concat=True)
return InternVLVideoPixelInputs(
type="pixel_values_videos",
pixel_values_flat=self._validate_pixel_values(
pixel_values_flat_video),
num_patches=video_num_patches,
)
raise AssertionError("This line should be unreachable.")
def _process_image_input(
self,
image_input: InternVLImageInputs,
) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]:
image_input: Union[InternVLImageInputs, InternVLVideoPixelInputs],
) -> tuple[torch.Tensor, ...]:
if image_input["type"] == "image_embeds":
return image_input["data"]
@ -840,8 +1259,8 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
# Only one image in the current batch
if len(num_patches) == 1:
return image_embeds.view(
-1, self.config.text_config.hidden_size).unsqueeze(0)
return (image_embeds.view(-1,
self.config.text_config.hidden_size), )
# NOTE: Image embeddings are split into separate tensors for each image
# by the size of each embedding.
@ -853,8 +1272,26 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
]
return image_embeds.split(image_feature_sizes)
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
modalities = {}
# Preserve the order of modalities if there are multiple of them
# from the order of kwargs.
for input_key in kwargs:
if input_key in ("pixel_values_flat",
"image_embeds") and "images" not in modalities:
modalities["images"] = self._parse_and_validate_image_input(
**kwargs)
if input_key in ("pixel_values_flat_video",
) and "videos" not in modalities:
modalities["videos"] = self._parse_and_validate_video_input(
**kwargs)
return modalities
def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
if self.is_mono:
assert self.img_context_token_id is not None
self.visual_token_mask = (
input_ids == self.img_context_token_id).reshape(-1, 1)
else:
@ -865,11 +1302,28 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return None
return self._process_image_input(image_input)
# The resulting multimodal_embeddings is a tuple of tensors, with each
# tensor corresponding to a multimodal data item (image or video).
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
# NOTE: It is important to iterate over the keys in this dictionary
# to preserve the order of the modalities.
for modality in modalities:
if modality == "images":
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
multimodal_embeddings += vision_embeddings
if modality == "videos":
video_input = modalities["videos"]
video_embeddings = self._process_image_input(video_input)
multimodal_embeddings += video_embeddings
return multimodal_embeddings
def get_input_embeddings(
self,
@ -878,13 +1332,18 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
assert self.img_context_token_id is not None
context_token_ids = [
token_id for token_id in (self.img_context_token_id,
self.video_context_token_id)
if token_id is not None
]
assert len(context_token_ids) >= 1
self._set_visual_token_mask(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.img_context_token_id,
context_token_ids,
)
return inputs_embeds
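
With both context token ids passed in, image and video embeddings are merged in a single call. The snippet below is a toy illustration of that idea in plain PyTorch (it is not vLLM's `merge_multimodal_embeddings`): mask every position whose token id is in the placeholder set and scatter the multimodal embeddings into those positions in order.

```python
# Toy illustration of merging at multiple placeholder token ids; this is not
# vLLM's merge_multimodal_embeddings, just the underlying idea.
import torch

hidden_size = 4
img_context_token_id, video_context_token_id = 100, 200

input_ids = torch.tensor([1, 100, 100, 2, 200, 200, 200, 3])
inputs_embeds = torch.zeros(len(input_ids), hidden_size)
# Two image positions + three video positions, concatenated in prompt order.
mm_embeds = torch.arange(5 * hidden_size, dtype=torch.float).view(5, hidden_size)

mask = torch.isin(
    input_ids, torch.tensor([img_context_token_id, video_context_token_id]))
inputs_embeds[mask] = mm_embeds
print(inputs_embeds)
```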

View File

@ -22,9 +22,10 @@ from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from .intern_vit import InternVisionModel
from .internvl import (BaseInternVLProcessingInfo, BaseInternVLProcessor,
InternVLChatModel, InternVLDummyInputsBuilder,
InternVLMultiModalProcessor)
from .internvl import (BaseInternVLDummyInputsBuilder,
BaseInternVLMultiModalProcessor,
BaseInternVLProcessingInfo, BaseInternVLProcessor,
InternVLChatModel)
IMG_PAD = "<|vision_pad|>"
@ -84,7 +85,8 @@ class NVLMProcessingInfo(BaseInternVLProcessingInfo):
)
class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]):
class NVLMDummyInputsBuilder(BaseInternVLDummyInputsBuilder[NVLMProcessingInfo]
):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
@ -110,7 +112,8 @@ class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]):
}
class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
class NVLMMultiModalProcessor(
BaseInternVLMultiModalProcessor[NVLMProcessingInfo]):
def _get_prompt_updates(
self,