Support Intern-S1 (#21628)

Signed-off-by: Roger Wang <hey@rogerw.me>
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: Roger Wang <hey@rogerw.me>
Co-authored-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Author: Lyu Han
Date: 2025-07-26 19:14:04 +08:00 (committed by GitHub)
Parent: 7728dd77bb
Commit: 875af38e01
7 changed files with 1196 additions and 0 deletions


@@ -593,6 +593,7 @@ Specified using `--task generate`.
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ |
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ |
| `InternS1ForConditionalGeneration` | Intern-S1 | T + I<sup>E+</sup> | `internlm/Intern-S1`, etc. | | ✅︎ | ✅︎ |
| `InternVLChatModel` | InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | | | ✅︎ |
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | | ✅︎ |
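Below is a minimal offline-inference sketch for the new Intern-S1 entry above. The engine options mirror the example scripts added later in this commit; the image path, question, and sampling settings are placeholders, not part of the change.

from PIL import Image
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

model_name = "internlm/Intern-S1"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# <IMG_CONTEXT> is the image placeholder used by the HF InternVL processor.
messages = [{"role": "user", "content": "<IMG_CONTEXT>\nDescribe this image."}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True)

llm = LLM(
    model=model_name,
    trust_remote_code=True,
    max_model_len=8192,
    limit_mm_per_prompt={"image": 1},
)
outputs = llm.generate(
    {"prompt": prompt,
     "multi_modal_data": {"image": Image.open("example.jpg")}},
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)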


@@ -468,6 +468,37 @@ def run_tarsier(questions: list[str], modality: str) -> ModelRequestData:
)
# Intern-S1
def run_interns1(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "internlm/Intern-S1"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=8192,
max_num_seqs=2,
limit_mm_per_prompt={modality: 1},
enforce_eager=True,
)
placeholder = "<IMG_CONTEXT>"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
messages = [
[{"role": "user", "content": f"{placeholder}\n{question}"}]
for question in questions
]
prompts = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# InternVL
def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
model_name = "OpenGVLab/InternVL3-2B"
@@ -1303,6 +1334,7 @@ model_example_map = {
"h2ovl_chat": run_h2ovl,
"hyperclovax_seed_vision": run_hyperclovax_seed_vision,
"idefics3": run_idefics3,
"interns1": run_interns1,
"internvl_chat": run_internvl,
"nemotron_vl": run_nemotron_vl,
"keye_vl": run_keye_vl,


@@ -253,6 +253,33 @@ def load_smolvlm(question: str, image_urls: list[str]) -> ModelRequestData:
)
def load_interns1(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "internlm/Intern-S1"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=4096,
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = "\n".join(
f"Image-{i}: <IMG_CONTEXT>\n" for i, _ in enumerate(image_urls, start=1)
)
messages = [{"role": "user", "content": f"{placeholders}\n{question}"}]
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image_data=[fetch_image(url) for url in image_urls],
)
def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "OpenGVLab/InternVL2-2B"
@@ -946,6 +973,7 @@ model_example_map = {
"gemma3": load_gemma3,
"h2ovl_chat": load_h2ovl,
"idefics3": load_idefics3,
"interns1": load_interns1,
"internvl_chat": load_internvl,
"hyperclovax_seed_vision": load_hyperclovax_seed_vision,
"keye_vl": load_keye_vl,


@@ -381,6 +381,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
extras={"2B": "OpenGVLab/InternVL2-2B",
"3.0": "OpenGVLab/InternVL3-1B"}, # noqa: E501
trust_remote_code=True),
"InternS1ForConditionalGeneration": _HfExamplesInfo("internlm/Intern-S1",
trust_remote_code=True),
"Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501
{"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}), # noqa: E501
"KeyeForConditionalGeneration": _HfExamplesInfo("Kwai-Keye/Keye-VL-8B-Preview", # noqa: E501


@@ -0,0 +1,711 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# --------------------------------------------------------
# InternS1
# Copyright (c) 2025 Shanghai AI Lab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, TypedDict, Union
import torch
import torch.nn as nn
from transformers import InternVLProcessor, PretrainedConfig
from transformers.activations import ACT2FN
from transformers.models.got_ocr2.image_processing_got_ocr2_fast import (
GotOcr2ImageProcessorFast)
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.interns1_vit import InternS1VisionModel
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs, NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
class InternS1MultiModalProjector(nn.Module):
def __init__(self, config):
super().__init__()
self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size *
int(1 / config.downsample_ratio)**2)
self.linear_1 = nn.Linear(
config.vision_config.hidden_size *
int(1 / config.downsample_ratio)**2,
config.text_config.hidden_size)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = nn.Linear(config.text_config.hidden_size,
config.text_config.hidden_size)
def forward(self, image_features):
hidden_states = self.layer_norm(image_features)
hidden_states = self.linear_1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class InternS1ImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
pixel_values: torch.Tensor
"""
Shape:
`(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
"""
class InternS1ImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: Union[torch.Tensor, list[torch.Tensor]]
"""
A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
or a list of tensors of shape `(total_image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
InternS1ImageInputs = Union[InternS1ImagePixelInputs,
InternS1ImageEmbeddingInputs]
class InternS1VideoPixelInputs(TypedDict):
type: Literal["pixel_values_videos"]
pixel_values: torch.Tensor
"""
Shape:
`(batch_size * num_video * num_frames, num_channels, height, width)`
"""
num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
class InternS1VideoEmbeddingInputs(TypedDict):
type: Literal["video_embeds"]
data: Union[torch.Tensor, list[torch.Tensor]]
"""
A tensor of shape `(num_videos, total_video_feature_size, hidden_size)`
or a list of tensors of shape `(total_video_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
InternS1VideoInputs = Union[InternS1VideoPixelInputs,
InternS1VideoEmbeddingInputs]
def resolve_interns1_min_max_num(
min_dynamic_patch: int,
max_dynamic_patch: int,
dynamic_image_size: bool,
use_thumbnail: bool,
) -> tuple[int, int]:
min_dynamic_patch = min_dynamic_patch if dynamic_image_size else 1
max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
if use_thumbnail and max_dynamic_patch != 1:
max_dynamic_patch += 1
return min_dynamic_patch, max_dynamic_patch
def get_interns1_target_ratios(
min_num: int,
max_num: int,
) -> list[tuple[int, int]]:
target_ratios = {(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1) if min_num <= i * j <= max_num}
return sorted(target_ratios, key=lambda x: x[0] * x[1])
class InternS1ProcessingInfo(BaseProcessingInfo):
"""Basic image-only ProcessingInfo for InternS1-style models."""
def get_hf_processor(self, **kwargs: object) -> InternVLProcessor:
return self.ctx.get_hf_processor(InternVLProcessor, **kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
processor: Optional['GotOcr2ImageProcessorFast'] = None,
) -> int:
if processor is None:
processor = self.get_hf_processor().image_processor
if not isinstance(processor, GotOcr2ImageProcessorFast):
raise ValueError(f'GotOcr2ImageProcessorFast is expected but got '
f'{type(processor)}')
num_image_patches = processor.get_number_of_image_tokens(
image_height, image_width, images_kwargs=dict())
num_image_tokens = self.get_hf_processor(
).image_seq_length * num_image_patches
return num_image_tokens
def resolve_target_ratios(self, use_thumbnail: Optional[bool] = None):
image_processor = self.get_hf_processor().image_processor
min_dynamic_patch = image_processor.min_patches
max_dynamic_patch = image_processor.max_patches
# HF format's InternVL processor uses `crop_to_patches` which is
# equivalent to `use_thumbnail` in original format.
use_thumbnail = image_processor.crop_to_patches
dynamic_image_size = True
min_num, max_num = resolve_interns1_min_max_num(
min_dynamic_patch,
max_dynamic_patch,
dynamic_image_size,
use_thumbnail=use_thumbnail)
return get_interns1_target_ratios(min_num, max_num)
def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor()
hf_config = self.ctx.get_hf_config()
base_height, base_width = hf_config.vision_config.image_size
target_ratios = self.resolve_target_ratios()
largest_feature_size, largest_feature_pinpoint = 0, None
for wr, hr in target_ratios:
width, height = base_width * wr, base_height * hr
feat_size = self.get_num_image_tokens(
image_width=width,
image_height=height,
processor=processor.image_processor,
)
if feat_size > largest_feature_size:
largest_feature_size = feat_size
largest_feature_pinpoint = ImageSize(width=width,
height=height)
assert not (largest_feature_size == 0 or largest_feature_pinpoint
is None), ("Cannot have a largest feature size of 0!")
return largest_feature_pinpoint
def get_max_image_tokens(self) -> int:
processor = self.get_hf_processor()
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=processor.image_processor,
)
class InternS1DummyInputsBuilder(BaseDummyInputsBuilder[InternS1ProcessingInfo]
):
"""Basic image-only DummyInputsBuilder for InternS1-style models."""
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
image_token = self.info.get_hf_processor().image_token
return image_token * num_images
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
target_width, target_height = \
self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0)
return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
class InternS1MultiModalProcessor(
BaseMultiModalProcessor[InternS1ProcessingInfo]):
""" Basic image-only MultiModalProcessor for InternS1-style models."""
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> Mapping[str, NestedTensors]:
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
hf_processor = self.info.get_hf_processor(**mm_kwargs)
image_token_id = hf_processor.image_token_id
# Since there may be extra tokens in the feature placeholders,
# we need to pass the image token ID to the model to select the
# tokens to merge from the vision encoder outputs
processed_outputs["image_token_id"] = torch.tensor(image_token_id)
images = mm_data.get('images', None)
image_processor = self.info.get_hf_processor().image_processor
if images is not None:
image_inputs = image_processor(images=images)
image_num_patches = image_inputs.pop("num_patches")
if not isinstance(image_num_patches, list):
raise ValueError(
f'num_patches is supposed to be list, but got '
f'{type(image_num_patches)}')
image_num_patches = torch.tensor(image_num_patches)
processed_outputs['image_num_patches'] = image_num_patches
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: Mapping[str, NestedTensors],
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
num_images = len(image_num_patches)
return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_patches),
image_num_patches=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
image_token_id=MultiModalFieldConfig.shared("image", num_images),
)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
img_context_token = hf_processor.image_token
start_image_token = hf_processor.start_image_token
end_image_token = hf_processor.end_image_token
def get_replacement(item_idx: int):
images = mm_items.get_items(
"image", (ImageEmbeddingItems, ImageProcessorItems))
if isinstance(images, ImageEmbeddingItems):
feature_size = images.get_feature_size(item_idx)
else:
image_size = images.get_image_size(item_idx)
feature_size = self.info.get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
processor=hf_processor.image_processor,
)
repl_features = img_context_token * feature_size
repl_full = start_image_token + repl_features + end_image_token
return PromptUpdateDetails.select_text(repl_full,
img_context_token)
return [
PromptReplacement(
modality="image",
target=img_context_token,
replacement=get_replacement,
)
]
@MULTIMODAL_REGISTRY.register_processor(
InternS1MultiModalProcessor,
info=InternS1ProcessingInfo,
dummy_inputs=InternS1DummyInputsBuilder)
class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP, SupportsLoRA):
# To ensure correct weight loading and mapping.
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"lm_head.": "language_model.lm_head.",
"model.language_model.": "language_model.model.",
"model.vision_tower.": "vision_tower.",
"model.multi_modal_projector.": "multi_modal_projector.",
})
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        # transformers InternVLProcessor uses <IMG_CONTEXT> as the separator
# refer to https://github.com/huggingface/transformers/blob/f90de364c2484c7c325bbe05befdcf487bd75b63/src/transformers/models/internvl/processing_internvl.py#L116
if modality.startswith("image"):
return '<IMG_CONTEXT>'
if modality.startswith("video"):
return "<video>"
raise ValueError("Only image or video modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
image_size = config.vision_config.image_size[0]
patch_size = config.vision_config.patch_size[0]
self.patch_size = patch_size
self.num_image_token = int(
(image_size // patch_size)**2 * (config.downsample_ratio**2))
self.downsample_ratio = config.downsample_ratio
self.llm_arch_name = config.text_config.architectures[0]
self.vision_tower = self._init_vision_model(
config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "vision_tower"),
)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.multi_modal_projector = self._init_mlp1(config)
self.img_context_token_id = None
self.video_context_token_id = None
self.visual_token_mask = None
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
def _init_vision_model(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
*,
prefix: str,
):
num_hidden_layers = config.vision_config.num_hidden_layers
return InternS1VisionModel(
config.vision_config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers,
prefix=prefix,
)
def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
return InternS1MultiModalProjector(config)
def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
# N, W, H, C --> N, W, H * scale, C // scale
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
x = x.permute(0, 2, 1, 3).contiguous()
x = x.view(n, int(h * scale_factor), int(w * scale_factor),
int(c / (scale_factor * scale_factor)))
x = x.permute(0, 2, 1, 3).contiguous()
return x
def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
vit_embeds = self.vision_tower(pixel_values=pixel_values)
vit_embeds = vit_embeds[:, 1:, :]
h = w = int(vit_embeds.shape[1]**0.5)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
vit_embeds = self.pixel_shuffle(vit_embeds,
scale_factor=self.downsample_ratio)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
vit_embeds.shape[-1])
vit_embeds = self.multi_modal_projector(vit_embeds)
return vit_embeds
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h, w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape)
if actual_dims != expected_dims:
expected_expr = str(expected_dims)
                raise ValueError(
                    "The expected shape of pixel values per image patch "
                    f"is {expected_expr}. "
                    f"You supplied {actual_dims}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[InternS1ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_num_patches = kwargs.pop("image_num_patches", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None and image_embeds is None:
return None
if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return InternS1ImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds),
)
image_token_id = kwargs["image_token_id"]
assert isinstance(image_token_id, torch.Tensor)
self.img_context_token_id = image_token_id.flatten().unique().item()
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if not isinstance(image_num_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of image_num_patches. "
f"Got type: {type(image_num_patches)}")
pixel_values = flatten_bn(pixel_values, concat=True)
image_num_patches = flatten_bn(image_num_patches, concat=True)
return InternS1ImagePixelInputs(
type="pixel_values",
pixel_values=self._validate_pixel_values(pixel_values),
num_patches=image_num_patches,
)
raise AssertionError("This line should be unreachable.")
def _parse_and_validate_video_input(
self, **kwargs: object) -> Optional[InternS1VideoPixelInputs]:
pixel_values_flat_video = kwargs.pop("pixel_values_flat_video", None)
video_num_patches = kwargs.pop("video_num_patches", None)
video_embeds = kwargs.pop("video_embeds", None)
if pixel_values_flat_video is None and video_embeds is None:
return None
if video_embeds is not None:
if not isinstance(video_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of video embeddings. "
f"Got type: {type(video_embeds)}")
            return InternS1VideoEmbeddingInputs(
type="video_embeds",
data=flatten_bn(video_embeds),
)
video_token_id = kwargs["video_token_id"]
assert isinstance(video_token_id, torch.Tensor)
self.video_context_token_id = video_token_id.flatten().unique().item()
if pixel_values_flat_video is not None:
if not isinstance(pixel_values_flat_video, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values_flat_video)}")
if not isinstance(video_num_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of image_num_patches. "
f"Got type: {type(video_num_patches)}")
pixel_values_flat_video = flatten_bn(pixel_values_flat_video,
concat=True)
video_num_patches = flatten_bn(video_num_patches, concat=True)
return InternS1VideoPixelInputs(
type="pixel_values_videos",
pixel_values=self._validate_pixel_values(
pixel_values_flat_video),
num_patches=video_num_patches,
)
raise AssertionError("This line should be unreachable.")
def _process_image_input(
self,
image_input: Union[InternS1ImageInputs, InternS1VideoPixelInputs],
) -> tuple[torch.Tensor, ...]:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_tower is not None
image_embeds = self.extract_feature(image_input["pixel_values"])
num_patches = image_input["num_patches"]
# Only one image in the current batch
if len(num_patches) == 1:
return (image_embeds.view(-1,
self.config.text_config.hidden_size), )
# NOTE: Image embeddings are split into separate tensors for each image
# by the size of each embedding.
feature_size = image_embeds.shape[1]
image_embeds = image_embeds.view(-1,
self.config.text_config.hidden_size)
image_feature_sizes = [
num_patches * feature_size for num_patches in num_patches
]
return image_embeds.split(image_feature_sizes)
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
modalities = {}
# Preserve the order of modalities if there are multiple of them
# from the order of kwargs.
for input_key in kwargs:
if input_key in ("pixel_values",
"image_embeds") and "images" not in modalities:
modalities["images"] = self._parse_and_validate_image_input(
**kwargs)
if input_key in ("pixel_values_flat_video",
) and "videos" not in modalities:
modalities["videos"] = self._parse_and_validate_video_input(
**kwargs)
return modalities
def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
self.visual_token_mask = None
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return []
        # The result multimodal_embeddings is a tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
# NOTE: It is important to iterate over the keys in this dictionary
# to preserve the order of the modalities.
for modality in modalities:
if modality == "images":
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
multimodal_embeddings += vision_embeddings
if modality == "videos":
video_input = modalities["videos"]
video_embeddings = self._process_image_input(video_input)
multimodal_embeddings += video_embeddings
return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
context_token_ids = [
token_id for token_id in (self.img_context_token_id,
self.video_context_token_id)
if token_id is not None
]
assert len(context_token_ids) >= 1
self._set_visual_token_mask(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
context_token_ids,
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> IntermediateTensors:
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
        # NOTE: In v1, inputs_embeds is always generated at the model runner;
        # this condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None
forward_kwargs = {
"input_ids": input_ids,
"positions": positions,
"intermediate_tensors": intermediate_tensors,
"inputs_embeds": inputs_embeds,
}
hidden_states = self.language_model.model(**forward_kwargs)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="multi_modal_projector",
tower_model="vision_tower")


@@ -0,0 +1,421 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from collections.abc import Iterable
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from transformers.utils import torch_int
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
NORM2FN = {
'rms_norm': RMSNorm,
'layer_norm': nn.LayerNorm,
}
class InternS1VisionPatchEmbeddings(nn.Module):
def __init__(self, config):
super().__init__()
image_size, patch_size = config.image_size, config.patch_size
num_channels, hidden_size = config.num_channels, config.hidden_size
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] //
patch_size[0])
patch_shape = (image_size[0] // patch_size[0],
image_size[1] // patch_size[1])
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches
self.patch_shape = patch_shape
self.projection = nn.Conv2d(num_channels,
hidden_size,
kernel_size=patch_size,
stride=patch_size)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values "
"match with the one set in the configuration.")
embeddings = self.projection(
pixel_values.to(self.projection.weight.dtype))
patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
embeddings = embeddings.flatten(2).transpose(1, 2)
return embeddings, (patch_height, patch_width)
class InternS1VisionEmbeddings(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.config = config
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
if config.use_mask_token:
self.mask_token = nn.Parameter(
torch.zeros(1, 1, config.hidden_size))
else:
self.mask_token = None
self.patch_embeddings = InternS1VisionPatchEmbeddings(config)
self.patch_size = config.patch_size
self.image_size = (config.image_size if isinstance(
config.image_size, Iterable) else
(config.image_size, config.image_size))
num_patches = self.patch_embeddings.num_patches
if config.use_absolute_position_embeddings:
self.position_embeddings = nn.Parameter(
torch.zeros(1, num_patches + 1, config.hidden_size))
else:
self.position_embeddings = None
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int,
width: int) -> torch.Tensor:
"""
        This method interpolates the pre-trained position encodings so the model can be used on higher-resolution
        images. It is also adapted to support torch.jit tracing.
Adapted from:
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
""" # noqa: E501
num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
# always interpolate when tracing to ensure the exported model
# works for dynamic input shapes
if not torch.jit.is_tracing(
) and num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, :1]
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]
new_height = height // self.patch_size[0]
new_width = width // self.patch_size[1]
sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions,
sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
size=(new_height, new_width),
mode="bicubic",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
def forward(
self,
pixel_values: torch.Tensor,
bool_masked_pos: Optional[torch.BoolTensor] = None,
) -> torch.Tensor:
_, _, height, width = pixel_values.shape
embeddings, (patch_height,
patch_width) = self.patch_embeddings(pixel_values)
batch_size, seq_len, _ = embeddings.size()
if bool_masked_pos is not None:
mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
# replace the masked visual tokens by mask_tokens
w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
embeddings = embeddings * (1 - w) + mask_tokens * w
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
if self.position_embeddings is not None:
embeddings = embeddings + self.interpolate_pos_encoding(
embeddings, height, width)
return embeddings, (patch_height, patch_width)
class InternSdpaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
config: PretrainedConfig,
*,
num_dummy_heads: int = 0,
) -> None:
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f'embed_dim must be divisible by num_heads '
f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
f' {self.num_heads}).')
# Additional dummy heads are used to enable TP for common GPU counts.
self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
self.scale = self.head_dim**-0.5
self.q_proj = nn.Linear(self.embed_dim,
self.num_heads * self.head_dim,
bias=config.attention_bias)
self.k_proj = nn.Linear(self.embed_dim,
self.num_heads * self.head_dim,
bias=config.attention_bias)
self.v_proj = nn.Linear(self.embed_dim,
self.num_heads * self.head_dim,
bias=config.attention_bias)
self.qk_normalization = config.use_qk_norm
if self.qk_normalization:
self.q_norm = RMSNorm(self.dummy_dim,
eps=config.layer_norm_eps,
var_hidden_size=self.embed_dim)
self.k_norm = RMSNorm(self.dummy_dim,
eps=config.layer_norm_eps,
var_hidden_size=self.embed_dim)
self.projection_layer = nn.Linear(self.dummy_dim, self.embed_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
q = q.view(B, N, self.num_heads, self.head_dim)
k = k.view(B, N, self.num_heads, self.head_dim)
v = v.view(B, N, self.num_heads, self.head_dim)
if self.qk_normalization:
B_, N_, H_, D_ = q.shape
q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_)
k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
x = x.transpose(1, 2).reshape(B, N, -1)
x = self.projection_layer(x)
return x
class InternS1VisionMLP(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear(config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc1")
self.fc2 = RowParallelLinear(config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc2")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
return hidden_states
class InternS1VisionLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_dummy_heads: int = 0,
prefix: str = "",
) -> None:
super().__init__()
self.attention = self._init_attn(config,
quant_config,
num_dummy_heads=num_dummy_heads,
prefix=f"{prefix}.attention")
self.mlp = InternS1VisionMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.layernorm_before = NORM2FN[config.norm_type](
config.hidden_size, eps=config.layer_norm_eps)
self.layernorm_after = NORM2FN[config.norm_type](
config.hidden_size, eps=config.layer_norm_eps)
init_values = config.layer_scale_init_value
self.lambda_1 = nn.Parameter(init_values *
torch.ones(config.hidden_size),
requires_grad=True)
self.lambda_2 = nn.Parameter(init_values *
torch.ones(config.hidden_size),
requires_grad=True)
def _init_attn(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
*,
num_dummy_heads: int,
prefix: str = "",
):
return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads)
def forward(
self,
hidden_states: torch.Tensor,
):
hidden_states = hidden_states + self.attention(
self.layernorm_before(hidden_states)) * self.lambda_1
hidden_states = hidden_states + self.mlp(
self.layernorm_after(hidden_states)) * self.lambda_2
return hidden_states
class InternS1VisionEncoder(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
num_dummy_heads: int = 0,
prefix: str = "",
):
super().__init__()
self.config = config
if num_hidden_layers_override is None:
num_hidden_layers = config.num_hidden_layers
else:
num_hidden_layers = num_hidden_layers_override
self.layer = nn.ModuleList([
InternS1VisionLayer(config,
quant_config,
num_dummy_heads=num_dummy_heads,
prefix=f"{prefix}.layer.{layer_idx}")
for layer_idx in range(num_hidden_layers)
])
def forward(self, inputs_embeds: torch.Tensor):
hidden_states = inputs_embeds
for encoder_layer in self.layer:
hidden_states = encoder_layer(hidden_states)
return hidden_states
class InternS1VisionModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
num_dummy_heads: int = 0,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.embeddings = InternS1VisionEmbeddings(config)
self.encoder = InternS1VisionEncoder(
config=config,
num_hidden_layers_override=num_hidden_layers_override,
num_dummy_heads=num_dummy_heads,
prefix=f"{prefix}.encoder",
)
self.layernorm = (nn.Identity() if config.use_mean_pooling else
nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps))
def get_input_embeddings(self):
return self.embeddings.patch_embeddings
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
pixel_embeds: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
if pixel_values is None and pixel_embeds is None:
raise ValueError(
'You have to specify pixel_values or pixel_embeds')
if pixel_embeds is not None:
hidden_states = pixel_embeds
elif pixel_values is not None:
if pixel_values.ndim == 4:
hidden_states, _ = self.embeddings(pixel_values)
else:
raise ValueError(
f'wrong pixel_values size: {pixel_values.shape}')
encoder_outputs = self.encoder(inputs_embeds=hidden_states)
encoder_outputs = self.layernorm(encoder_outputs)
return encoder_outputs
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
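A standalone shape walk-through of how extract_feature() in interns1.py consumes this vision tower's output. The pixel-shuffle steps mirror that method; the tensor sizes (two tiles, 32x32 patch tokens, hidden size 1024) are assumed example values.

import torch

# Vision tower output for a batch of 2 tiles: [CLS] + 32x32 patch tokens.
vit_embeds = torch.randn(2, 1 + 32 * 32, 1024)
vit_embeds = vit_embeds[:, 1:, :]          # drop the CLS token
h = w = int(vit_embeds.shape[1] ** 0.5)    # 32
x = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)

# pixel_shuffle(scale_factor=0.5): trade 2x2 spatial blocks for 4x channels.
scale = 0.5
n, wd, ht, c = x.size()
x = x.view(n, wd, int(ht * scale), int(c / scale))
x = x.permute(0, 2, 1, 3).contiguous()
x = x.view(n, int(ht * scale), int(wd * scale), int(c / (scale * scale)))
x = x.permute(0, 2, 1, 3).contiguous()

x = x.reshape(n, -1, x.shape[-1])
print(x.shape)  # torch.Size([2, 256, 4096]); the projector then maps 4096 -> text hidden size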


@@ -203,6 +203,7 @@ _MULTIMODAL_MODELS = {
"GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"), # noqa: E501
"H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
"InternVLChatModel": ("internvl", "InternVLChatModel"),
"InternS1ForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"), # noqa: E501
"Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
"SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"), # noqa: E501
"KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),