[VLM] Initialize video input support for InternVL models (#18499)
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
parent 6ab681bcbe
commit 75f81750f3
@@ -527,7 +527,7 @@ Specified using `--task generate`.
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | ✅︎ | ✅︎\* | |
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3` etc. | ✅︎ | ✅︎ | |
| `InternVLChatModel` | InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> | `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | |
| `InternVLChatModel` | InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | |
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | ✅︎ | | |
| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | ✅︎ | ✅︎ | |
| `LlavaForConditionalGeneration` | LLaVA-1.5 | T + I<sup>E+</sup> | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), etc. | ✅︎ | ✅︎ | |
@@ -577,6 +577,9 @@ Specified using `--task generate`.

    This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.

!!! note
    Only `InternVLChatModel` with Qwen2.5 text backbone (`OpenGVLab/InternVL3-2B`, `OpenGVLab/InternVL2.5-1B` etc) has video inputs support currently.

!!! note
    `h2oai/h2ovl-mississippi-2b` will be available in V1 once we support head size 80.

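For reference, the new video path can be exercised offline much like image input. A minimal sketch follows; the checkpoint, frame count, and hand-written chat markup are illustrative assumptions, not part of this change (the example script below builds the prompt via the tokenizer's chat template instead):

import numpy as np
from vllm import LLM, SamplingParams

llm = LLM(
    model="OpenGVLab/InternVL3-2B",
    trust_remote_code=True,
    max_model_len=8192,
    limit_mm_per_prompt={"video": 1},
)

# Placeholder clip: a uint8 array of shape (num_frames, height, width, 3).
video = np.zeros((8, 448, 448, 3), dtype=np.uint8)

prompt = ("<|im_start|>user\n<video>\nDescribe the video.<|im_end|>\n"
          "<|im_start|>assistant\n")
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"video": video}},
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)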
@@ -330,22 +330,26 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData:


# InternVL
def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
    assert modality == "image"

    model_name = "OpenGVLab/InternVL2-2B"
    model_name = "OpenGVLab/InternVL3-2B"

    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        max_model_len=8192,
        limit_mm_per_prompt={modality: 1},
    )

    if modality == "image":
        placeholder = "<image>"
    elif modality == "video":
        placeholder = "<video>"

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    messages = [[{
        'role': 'user',
        'content': f"<image>\n{question}"
        'content': f"{placeholder}\n{question}"
    }] for question in questions]
    prompts = tokenizer.apply_chat_template(messages,
                                            tokenize=False,
@@ -357,6 +361,9 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
    # https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py
    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
    stop_token_ids = [
        token_id for token_id in stop_token_ids if token_id is not None
    ]

    return ModelRequestData(
        engine_args=engine_args,
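For context, the example script's driver consumes the returned request data roughly as sketched below. This assumes the ModelRequestData fields used elsewhere in this file (engine_args, prompts, stop_token_ids); the sampling settings and the placeholder clip are assumptions:

from dataclasses import asdict

import numpy as np
from vllm import LLM, SamplingParams

req_data = run_internvl(["Describe the video."], modality="video")
llm = LLM(**asdict(req_data.engine_args))

video = np.zeros((8, 448, 448, 3), dtype=np.uint8)  # placeholder frames
sampling_params = SamplingParams(temperature=0.2,
                                 max_tokens=64,
                                 stop_token_ids=req_data.stop_token_ids)
outputs = llm.generate(
    {"prompt": req_data.prompts[0], "multi_modal_data": {"video": video}},
    sampling_params)
print(outputs[0].outputs[0].text)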
@@ -349,6 +349,17 @@ VLM_TEST_SETTINGS = {
        use_tokenizer_eos=True,
        patch_hf_runner=model_utils.internvl_patch_hf_runner,
    ),
    "intern_vl-video": VLMTestInfo(
        models=[
            "OpenGVLab/InternVL3-1B",
        ],
        test_type=VLMTestType.VIDEO,
        prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501
        video_idx_to_prompt=lambda idx: "<video>",
        max_model_len=8192,
        use_tokenizer_eos=True,
        patch_hf_runner=model_utils.internvl_patch_hf_runner,
    ),
    "kimi_vl": VLMTestInfo(
        models=["moonshotai/Kimi-VL-A3B-Instruct"],
        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
@@ -7,6 +7,8 @@ import types
from pathlib import PosixPath
from typing import Optional, Union

import numpy as np
import numpy.typing as npt
import regex as re
import torch
from PIL.Image import Image
@@ -495,30 +497,74 @@ def internvl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
            self.max_num = self.config.max_dynamic_patch
            self.image_size = self.vision_config.image_size

        def __call__(self, text: str, images: Union[Image, list[Image]],
                     **kwargs):
        def __call__(
            self,
            text: str,
            images: Union[Image, list[Image]] = None,
            videos: Union[npt.NDArray, list[npt.NDArray]] = None,
            **kwargs,
        ):
            from vllm.model_executor.models.internvl import (
                IMG_CONTEXT, IMG_END, IMG_START,
                image_to_pixel_values_internvl)
                image_to_pixel_values_internvl, video_to_pixel_values_internvl)
            images = [images] if isinstance(images, Image) else images
            pixel_values = [
                image_to_pixel_values_internvl(
                    image,
                    input_size=self.image_size,
                    min_num=self.min_num,
                    max_num=self.max_num,
                    use_thumbnail=self.use_thumbnail,
                ) for image in images
            ]
            num_patches_list = [
                pixel_value.shape[0] for pixel_value in pixel_values
            ]
            videos = [videos] if isinstance(videos, np.ndarray) else videos
            if images is not None:
                pixel_values_images = [
                    image_to_pixel_values_internvl(
                        image,
                        input_size=self.image_size,
                        min_num=self.min_num,
                        max_num=self.max_num,
                        use_thumbnail=self.use_thumbnail,
                    ) for image in images
                ]
                num_patches_images = [
                    pixel_value.shape[0] for pixel_value in pixel_values_images
                ]
            else:
                pixel_values_images, num_patches_images = [], []

            if videos is not None:
                pixel_values_videos = [
                    video_to_pixel_values_internvl(
                        video,
                        input_size=self.image_size,
                        min_num=1,
                        max_num=1,
                        use_thumbnail=False,
                    ) for video in videos
                ]
                num_patches_videos = [
                    pixel_value.shape[0] for pixel_value in pixel_values_videos
                ]
            else:
                pixel_values_videos, num_patches_videos = [], []

            pixel_values = []
            while ("<image>" in text) or ("<video>" in text):
                image_index = text.find("<image>")
                video_index = text.find("<video>")
                if image_index == -1 or (video_index > -1
                                         and video_index < image_index):
                    num_patches = num_patches_videos.pop(0)
                    pixel_values.append(pixel_values_videos.pop(0))
                    context_tokens = IMG_START + \
                        IMG_CONTEXT * self.num_image_token + IMG_END
                    video_tokens = ''.join([
                        f'Frame{i+1}: {context_tokens}'
                        for i in range(num_patches)
                    ])
                    text = text.replace('<video>', video_tokens, 1)
                else:
                    num_patches = num_patches_images.pop(0)
                    pixel_values.append(pixel_values_images.pop(0))
                    context_tokens = IMG_CONTEXT * self.num_image_token \
                        * num_patches
                    image_tokens = IMG_START + context_tokens + IMG_END
                    text = text.replace('<image>', image_tokens, 1)
            pixel_values = torch.cat(pixel_values, dim=0)
            for num_patches in num_patches_list:
                context_tokens = IMG_CONTEXT * self.num_image_token \
                    * num_patches
                image_tokens = IMG_START + context_tokens + IMG_END
                text = text.replace('<image>', image_tokens, 1)

            prompt = self.tokenizer(text, return_tensors="pt")
            prompt.update({"pixel_values": pixel_values})
            return prompt

@@ -258,6 +258,7 @@ def _test_processing_correctness_mistral(
    "ibm-granite/granite-speech-3.3-8b",
    "h2oai/h2ovl-mississippi-800m",
    "OpenGVLab/InternVL2-1B",
    "OpenGVLab/InternVL3-1B",
    "HuggingFaceM4/Idefics3-8B-Llama3",
    "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
    "moonshotai/Kimi-VL-A3B-Instruct",
@@ -334,7 +334,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
                                  max_transformers_version="4.48", # noqa: E501
                                  transformers_version_reason="HF model is not compatible."), # noqa: E501
    "InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
                                         extras={"2B": "OpenGVLab/InternVL2-2B"}, # noqa: E501
                                         extras={"2B": "OpenGVLab/InternVL2-2B",
                                                 "3.0": "OpenGVLab/InternVL3-1B"}, # noqa: E501
                                         trust_remote_code=True),
    "Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501
                                                        {"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}), # noqa: E501
@@ -556,6 +556,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
                return "(<audio>./</audio>)"
            raise TypeError(f"Unknown model type: {model_type}")
        elif modality == "video":
            if model_type == "internvl_chat":
                return "<video>"
            if model_type in ("qwen2_vl", "qwen2_5_vl"):
                return "<|vision_start|><|video_pad|><|vision_end|>"
            if model_type == "qwen2_5_omni":
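Because the tracker above now maps internvl_chat video items to the `<video>` placeholder, the OpenAI-compatible server can accept a video part in a chat request. A hedged sketch, assuming a server was launched separately with an InternVL3 checkpoint and --trust-remote-code; the model name and URL are placeholders:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.chat.completions.create(
    model="OpenGVLab/InternVL3-2B",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe the video."},
            # The URL is illustrative; any reachable video file works.
            {"type": "video_url",
             "video_url": {"url": "https://example.com/clip.mp4"}},
        ],
    }],
)
print(completion.choices[0].message.content)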
@@ -25,9 +25,10 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer

from .intern_vit import InternVisionModel
from .internvl import (IMG_CONTEXT, IMG_END, IMG_START,
                       BaseInternVLDummyInputsBuilder,
                       BaseInternVLMultiModalProcessor,
                       BaseInternVLProcessingInfo, BaseInternVLProcessor,
                       InternVLChatModel, InternVLDummyInputsBuilder,
                       InternVLMultiModalProcessor, build_transform,
                       InternVLChatModel, build_transform,
                       find_closest_aspect_ratio, get_internvl_target_ratios)

@@ -430,8 +431,8 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
    )


class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
                               ):
class H2OVLMultiModalProcessor(
        BaseInternVLMultiModalProcessor[H2OVLProcessingInfo]):

    def _get_prompt_updates(
        self,
@@ -514,7 +515,7 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
@MULTIMODAL_REGISTRY.register_processor(
    H2OVLMultiModalProcessor,
    info=H2OVLProcessingInfo,
    dummy_inputs=InternVLDummyInputsBuilder)
    dummy_inputs=BaseInternVLDummyInputsBuilder)
class H2OVLChatModel(InternVLChatModel):

    def _init_vision_model(

@@ -8,8 +8,9 @@
# --------------------------------------------------------
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, TypedDict, TypeVar, Union
from typing import Any, Literal, Optional, TypedDict, TypeVar, Union

import numpy.typing as npt
import torch
import torch.nn as nn
import torchvision.transforms as T
@@ -74,6 +75,33 @@ InternVLImageInputs = Union[InternVLImagePixelInputs,
                            InternVLImageEmbeddingInputs]


class InternVLVideoPixelInputs(TypedDict):
    type: Literal["pixel_values_videos"]
    pixel_values_flat: torch.Tensor
    """
    Shape:
    `(batch_size * num_video * num_frames, num_channels, height, width)`
    """

    num_patches: torch.Tensor
    """Shape: `(batch_size * num_images)`"""


class InternVLVideoEmbeddingInputs(TypedDict):
    type: Literal["video_embeds"]
    data: Union[torch.Tensor, list[torch.Tensor]]
    """
    A tensor of shape `(num_videos, total_video_feature_size, hidden_size)`
    or a list of tensors of shape `(total_video_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


InternVLVideoInputs = Union[InternVLVideoPixelInputs,
                            InternVLVideoEmbeddingInputs]


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size: int):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
@@ -231,6 +259,33 @@ def image_to_pixel_values_internvl(
    return pixel_values


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def video_to_pixel_values_internvl(
    video: npt.NDArray,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
) -> torch.Tensor:
    target_ratios = get_internvl_target_ratios(min_num, max_num)

    transform = build_transform(input_size=input_size)
    frames_list = list[Image.Image]()
    for frame in video:
        pil_frame = dynamic_preprocess_internvl(
            Image.fromarray(frame, mode="RGB"),
            target_ratios=target_ratios,
            image_size=input_size,
            use_thumbnail=use_thumbnail,
        )
        assert len(pil_frame) == 1
        frames_list.extend(pil_frame)

    pixel_values = torch.stack([transform(image) for image in frames_list])
    return pixel_values

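A quick shape check for the helper above (a sketch; the 448-pixel input size and 8-frame clip are assumed values):

import numpy as np
from vllm.model_executor.models.internvl import video_to_pixel_values_internvl

# 8 RGB frames as a uint8 array of shape (num_frames, height, width, 3).
video = np.zeros((8, 448, 448, 3), dtype=np.uint8)
pixel_values = video_to_pixel_values_internvl(
    video, input_size=448, min_num=1, max_num=1, use_thumbnail=False)
# With min_num=max_num=1 each frame becomes exactly one 448x448 tile,
# so the stacked result is (num_frames, 3, input_size, input_size).
assert pixel_values.shape == (8, 3, 448, 448)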
class BaseInternVLProcessor(ABC):
    """
    This model doesn't define its own HF processor,
@@ -375,24 +430,14 @@ class BaseInternVLProcessor(ABC):
            ) for image in images
        ]

    def __call__(
    def _preprocess_image(
        self,
        text: Optional[Union[str, list[str]]] = None,
        images: Optional[Union[Image.Image, list[Image.Image]]] = None,
        text: list[str],
        images: list[Image.Image],
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> Mapping[str, NestedTensors]:
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if images is None:
            images = []
        if not isinstance(images, list):
            images = [images]

    ) -> tuple[list[str], dict[str, torch.Tensor]]:
        if len(images) == 0:
            image_inputs = {}
        else:
@@ -415,6 +460,34 @@ class BaseInternVLProcessor(ABC):

        image_repl = self.get_image_repl(feature_size, num_patches)
        text = [t.replace('<image>', image_repl.full, 1) for t in text]
        return text, image_inputs

    def _make_batch_input(self,
                          input_item: Optional[Union[Any, list[Any]]] = None):
        if input_item is None:
            input_item = []
        if not isinstance(input_item, list):
            input_item = [input_item]
        return input_item

    def __call__(
        self,
        text: Optional[Union[str, list[str]]] = None,
        images: Optional[Union[Image.Image, list[Image.Image]]] = None,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> Mapping[str, NestedTensors]:
        text, images = [self._make_batch_input(x) for x in (text, images)]

        text, image_inputs = self._preprocess_image(
            text=text,
            images=images,
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
        )

        text_inputs = self.tokenizer(text)

@@ -425,11 +498,133 @@ class BaseInternVLProcessor(ABC):


class InternVLProcessor(BaseInternVLProcessor):
    """
    HF Processor for InternVLChatModel with extended video processing logic.

    Code for video processing is adapted from video example:
    https://huggingface.co/OpenGVLab/InternVL3-1B#inference-with-transformers
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: AnyTokenizer,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        video_token: Optional[str] = None,
    ) -> None:
        super().__init__(
            config=config,
            tokenizer=tokenizer,
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
        )
        # add extra video token for video processing
        self.video_token = video_token

    @property
    def image_token_id(self) -> int:
        return self.tokenizer.get_vocab()[IMG_CONTEXT]

    @property
    def video_token_id(self) -> Optional[int]:
        if self.video_token is None:
            return None
        return self.tokenizer.get_vocab().get(self.video_token, None)

    @property
    def supports_video(self) -> bool:
        return self.video_token_id is not None

    def _videos_to_pixel_values_lst(
        self,
        videos: list[npt.NDArray],
        dynamic_image_size: Optional[bool] = None,
    ) -> list[torch.Tensor]:
        min_num, max_num = self.resolve_min_max_num(
            min_dynamic_patch=1,
            max_dynamic_patch=1,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=False,  # Applied in image_to_pixel_values
        )

        return [
            video_to_pixel_values_internvl(
                video,
                input_size=self.image_size,
                min_num=min_num,
                max_num=max_num,
                use_thumbnail=False,
            ) for video in videos
        ]

    def _preprocess_video(
        self,
        text: list[str],
        videos: list[npt.NDArray],
        dynamic_image_size: Optional[bool] = None,
    ):
        if len(videos) == 0 or not self.supports_video:
            video_inputs = {}
        else:
            pixel_values_lst_video = self._videos_to_pixel_values_lst(
                videos,
                dynamic_image_size=dynamic_image_size,
            )
            video_inputs: dict[str, NestedTensors] = {
                "pixel_values_flat_video":
                torch.cat(pixel_values_lst_video),
                "video_num_patches":
                torch.tensor([len(item) for item in pixel_values_lst_video]),
            }

            for pixel_values in pixel_values_lst_video:
                num_patches = pixel_values.shape[0]

                video_repl = self.get_video_repl(self.num_image_token,
                                                 num_patches, self.video_token)
                text = [t.replace('<video>', video_repl.full, 1) for t in text]
        return text, video_inputs

    def __call__(
        self,
        text: Optional[Union[str, list[str]]] = None,
        images: Optional[Union[Image.Image, list[Image.Image]]] = None,
        videos: Optional[Union[npt.NDArray, list[npt.NDArray]]] = None,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> Mapping[str, NestedTensors]:
        text, images, videos = [
            self._make_batch_input(x) for x in (text, images, videos)
        ]

        text, image_inputs = self._preprocess_image(
            text=text,
            images=images,
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
        )

        text, video_inputs = self._preprocess_video(
            text=text,
            videos=videos,
            dynamic_image_size=dynamic_image_size,
        )

        text_inputs = self.tokenizer(text)

        return {
            **BatchEncoding(text_inputs, tensor_type=return_tensors),
            **image_inputs,
            **video_inputs,
        }

    def get_image_repl(
        self,
        feature_size: int,
@@ -440,8 +635,24 @@ class InternVLProcessor(BaseInternVLProcessor):

        return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)

    def get_video_repl(
        self,
        feature_size: int,
        num_patches: Optional[int] = None,
        video_context_token: str = IMG_CONTEXT,
    ) -> PromptUpdateDetails[str]:
        repl_features = video_context_token * self.num_image_token
        repl_features_with_sep = IMG_START + repl_features + IMG_END
        # num_patches is equal to num_frames
        repl_full = ''.join([
            f'Frame{i+1}: {repl_features_with_sep}' for i in range(num_patches)
        ])

        return PromptUpdateDetails.select_text(repl_full, video_context_token)

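For illustration, what get_video_repl produces for a 2-frame video, assuming num_image_token = 256, the Qwen2 <|video_pad|> token wired up elsewhere in this change, and the module's <img>/</img> delimiters:

# Schematic expansion (values assumed for illustration only).
IMG_START, IMG_END = "<img>", "</img>"
video_context_token = "<|video_pad|>"
num_image_token, num_patches = 256, 2

per_frame = IMG_START + video_context_token * num_image_token + IMG_END
repl_full = ''.join(f'Frame{i + 1}: {per_frame}' for i in range(num_patches))
# repl_full == "Frame1: <img><|video_pad|>...(256x)...</img>"
#              "Frame2: <img><|video_pad|>...(256x)...</img>"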
class BaseInternVLProcessingInfo(BaseProcessingInfo):
    """Basic image-only ProcessingInfo for InternVL-style models."""

    @abstractmethod
    def get_hf_processor(
@@ -497,11 +708,22 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo):

        return largest_feature_pinpoint

    def get_max_image_tokens(self) -> int:
        processor = self.get_hf_processor()
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            processor=processor,
        )


_I = TypeVar("_I", bound=BaseInternVLProcessingInfo)


class InternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
class BaseInternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Basic image-only DummyInputsBuilder for InternVL-style models."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_images = mm_counts.get("image", 0)
@@ -525,7 +747,8 @@ class InternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
        }


class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """ Basic image-only MultiModalProcessor for InternVL-style models."""

    def _call_hf_processor(
        self,
@@ -614,6 +837,38 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):


class InternVLProcessingInfo(BaseInternVLProcessingInfo):
    """InternVL ProcessingInfo extended for video processing"""

    @property
    def supports_video(self):
        return self.get_hf_processor().supports_video

    def get_supported_mm_limits(self):
        video_limit = {"video": None} if self.supports_video else {}
        return {**super().get_supported_mm_limits(), **video_limit}

    def get_video_token(self) -> Optional[str]:
        text_model_type = self.get_hf_config().get_text_config().model_type
        if text_model_type == "qwen2":
            return "<|video_pad|>"
        return None

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        max_images = mm_counts.get("image", 0)
        max_videos = mm_counts.get("video", 0)

        processor = self.get_hf_processor()

        max_image_tokens = self.get_max_image_tokens() * max_images
        max_total_frames = (seq_len -
                            max_image_tokens) // processor.num_image_token
        max_frames_per_video = max_total_frames // max(max_videos, 1)

        return max(max_frames_per_video, 1)

    def get_hf_processor(
        self,
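A worked example of the frame budget computed by get_num_frames_with_most_features above; the numbers are illustrative:

# Assume seq_len = 8192, num_image_token = 256,
# mm_counts = {"image": 0, "video": 1}.
max_image_tokens = 0                                   # no dummy images
max_total_frames = (8192 - max_image_tokens) // 256    # -> 32
max_frames_per_video = max_total_frames // max(1, 1)   # -> 32
# so the profiling dummy video is built with 32 frames (and never fewer than 1).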
@@ -630,6 +885,8 @@ class InternVLProcessingInfo(BaseInternVLProcessingInfo):
        if dynamic_image_size is not None:
            kwargs["dynamic_image_size"] = dynamic_image_size

        kwargs["video_token"] = self.get_video_token()

        return self.ctx.init_processor(
            InternVLProcessor,
            config=self.get_hf_config(),
@@ -638,6 +895,121 @@ class InternVLProcessingInfo(BaseInternVLProcessingInfo):
        )


class InternVLDummyInputsBuilder(
        BaseInternVLDummyInputsBuilder[InternVLProcessingInfo]):
    """InternVL DummyInputsBuilder extended for video support"""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_videos = mm_counts.get("video", 0)

        return super().get_dummy_text(mm_counts) + "<video>" * num_videos

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        dummy_image = super().get_dummy_mm_data(seq_len=seq_len,
                                                mm_counts=mm_counts)
        if self.info.supports_video:
            config = self.info.get_hf_config()
            image_size: int = config.vision_config.image_size
            target_num_frames = \
                self.info.get_num_frames_with_most_features(seq_len, mm_counts)
            num_videos = mm_counts.get("video", 0)
            dummy_video = {
                "video":
                self._get_dummy_videos(width=image_size,
                                       height=image_size,
                                       num_frames=target_num_frames,
                                       num_videos=num_videos)
            }
        else:
            dummy_video = {}
        return {**dummy_image, **dummy_video}


class InternVLMultiModalProcessor(
        BaseInternVLMultiModalProcessor[InternVLProcessingInfo]):
    """InternVL MultiModalProcessor extended for video support"""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, NestedTensors]:
        processed_outputs = super()._call_hf_processor(prompt, mm_data,
                                                       mm_kwargs)

        hf_processor = self.info.get_hf_processor(**mm_kwargs)
        if self.info.supports_video and (
                video_token_id := hf_processor.video_token_id) is not None:
            processed_outputs["video_token_id"] = torch.tensor(video_token_id)
        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: Mapping[str, NestedTensors],
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        image_fields = super()._get_mm_fields_config(hf_inputs,
                                                     hf_processor_mm_kwargs)
        if self.info.supports_video:
            video_num_patches = hf_inputs.get("video_num_patches",
                                              torch.empty(0))
            num_videos = len(video_num_patches)
            video_fields = dict(
                pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes(
                    "video", video_num_patches),
                video_num_patches=MultiModalFieldConfig.batched("video"),
                video_token_id=MultiModalFieldConfig.shared(
                    "video", num_videos),
            )
        else:
            video_fields = {}

        return image_fields | video_fields

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        prompt_repl: list[PromptUpdate] = super()._get_prompt_updates(
            mm_items, hf_processor_mm_kwargs, out_mm_kwargs)

        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        if "video_num_patches" in out_mm_kwargs:
            video_num_patches = out_mm_kwargs["video_num_patches"]
            assert isinstance(video_num_patches, torch.Tensor)
            video_num_patches = video_num_patches.tolist()
        else:
            video_num_patches = []

        def get_video_replacement_internvl(item_idx: int):
            feature_size = hf_processor.num_image_token
            num_patches = video_num_patches[item_idx]
            if num_patches is not None:
                assert isinstance(num_patches, int)

            return hf_processor.get_video_repl(
                feature_size,
                num_patches,
                video_context_token=hf_processor.video_token)

        if self.info.supports_video:
            prompt_repl.append(
                PromptReplacement(
                    modality="video",
                    target="<video>",
                    replacement=get_video_replacement_internvl,
                ))
        return prompt_repl


@MULTIMODAL_REGISTRY.register_processor(
    InternVLMultiModalProcessor,
    info=InternVLProcessingInfo,
@@ -681,6 +1053,8 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
        self.mlp1 = self._init_mlp1(config)

        self.img_context_token_id = None
        self.video_context_token_id = None

        self.visual_token_mask = None
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)
@@ -825,10 +1199,55 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):

        raise AssertionError("This line should be unreachable.")

    def _parse_and_validate_video_input(
            self, **kwargs: object) -> Optional[InternVLVideoPixelInputs]:
        pixel_values_flat_video = kwargs.pop("pixel_values_flat_video", None)
        video_num_patches = kwargs.pop("video_num_patches", None)
        video_embeds = kwargs.pop("image_embeds", None)

        if pixel_values_flat_video is None and video_embeds is None:
            return None

        if video_embeds is not None:
            if not isinstance(video_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of video embeddings. "
                                 f"Got type: {type(video_embeds)}")

            return InternVLImageEmbeddingInputs(
                type="video_embeds",
                data=flatten_bn(video_embeds),
            )

        video_token_id = kwargs["video_token_id"]
        assert isinstance(video_token_id, torch.Tensor)
        self.video_context_token_id = video_token_id.flatten().unique().item()

        if pixel_values_flat_video is not None:
            if not isinstance(pixel_values_flat_video, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values_flat_video)}")

            if not isinstance(video_num_patches, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image_num_patches. "
                                 f"Got type: {type(video_num_patches)}")

            pixel_values_flat_video = flatten_bn(pixel_values_flat_video,
                                                 concat=True)
            video_num_patches = flatten_bn(video_num_patches, concat=True)

            return InternVLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_flat=self._validate_pixel_values(
                    pixel_values_flat_video),
                num_patches=video_num_patches,
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: InternVLImageInputs,
    ) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]:
        image_input: Union[InternVLImageInputs, InternVLVideoPixelInputs],
    ) -> tuple[torch.Tensor, ...]:
        if image_input["type"] == "image_embeds":
            return image_input["data"]

@@ -840,8 +1259,8 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):

        # Only one image in the current batch
        if len(num_patches) == 1:
            return image_embeds.view(
                -1, self.config.text_config.hidden_size).unsqueeze(0)
            return (image_embeds.view(-1,
                                      self.config.text_config.hidden_size), )

        # NOTE: Image embeddings are split into separate tensors for each image
        # by the size of each embedding.
@@ -853,8 +1272,26 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
        ]
        return image_embeds.split(image_feature_sizes)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if input_key in ("pixel_values_flat",
                             "image_embeds") and "images" not in modalities:
                modalities["images"] = self._parse_and_validate_image_input(
                    **kwargs)
            if input_key in ("pixel_values_flat_video",
                             ) and "videos" not in modalities:
                modalities["videos"] = self._parse_and_validate_video_input(
                    **kwargs)

        return modalities

    def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
        if self.is_mono:
            assert self.img_context_token_id is not None
            self.visual_token_mask = (
                input_ids == self.img_context_token_id).reshape(-1, 1)
        else:
@@ -865,11 +1302,28 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return None

        return self._process_image_input(image_input)
        # The result multimodal_embeddings is a tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in modalities:
            if modality == "images":
                image_input = modalities["images"]
                vision_embeddings = self._process_image_input(image_input)
                multimodal_embeddings += vision_embeddings
            if modality == "videos":
                video_input = modalities["videos"]
                video_embeddings = self._process_image_input(video_input)
                multimodal_embeddings += video_embeddings

        return multimodal_embeddings

    def get_input_embeddings(
        self,
@@ -878,13 +1332,18 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
    ) -> torch.Tensor:
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            assert self.img_context_token_id is not None
            context_token_ids = [
                token_id for token_id in (self.img_context_token_id,
                                          self.video_context_token_id)
                if token_id is not None
            ]
            assert len(context_token_ids) >= 1
            self._set_visual_token_mask(input_ids)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                multimodal_embeddings,
                self.img_context_token_id,
                context_token_ids,
            )
        return inputs_embeds

@@ -22,9 +22,10 @@ from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
                                        PromptUpdateDetails)

from .intern_vit import InternVisionModel
from .internvl import (BaseInternVLProcessingInfo, BaseInternVLProcessor,
                       InternVLChatModel, InternVLDummyInputsBuilder,
                       InternVLMultiModalProcessor)
from .internvl import (BaseInternVLDummyInputsBuilder,
                       BaseInternVLMultiModalProcessor,
                       BaseInternVLProcessingInfo, BaseInternVLProcessor,
                       InternVLChatModel)

IMG_PAD = "<|vision_pad|>"

@@ -84,7 +85,8 @@ class NVLMProcessingInfo(BaseInternVLProcessingInfo):
    )


class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]):
class NVLMDummyInputsBuilder(BaseInternVLDummyInputsBuilder[NVLMProcessingInfo]
                             ):

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_images = mm_counts.get("image", 0)
@@ -110,7 +112,8 @@ class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]):
        }


class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
class NVLMMultiModalProcessor(
        BaseInternVLMultiModalProcessor[NVLMProcessingInfo]):

    def _get_prompt_updates(
        self,