mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-26 13:34:32 +08:00
Support Intern-S1 (#21628)
Signed-off-by: Roger Wang <hey@rogerw.me> Signed-off-by: Isotr0py <2037008807@qq.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: Your Name <you@example.com> Co-authored-by: Roger Wang <hey@rogerw.me> Co-authored-by: Isotr0py <2037008807@qq.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
parent
7728dd77bb
commit
875af38e01
@ -593,6 +593,7 @@ Specified using `--task generate`.
|
||||
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ |
|
||||
| `InternS1ForConditionalGeneration` | Intern-S1 | T + I<sup>E+</sup> | `internlm/Intern-S1`, etc. | | ✅︎ | ✅︎ |
|
||||
| `InternVLChatModel` | InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | | | ✅︎ |
|
||||
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | | ✅︎ |
|
||||
|
||||
@ -468,6 +468,37 @@ def run_tarsier(questions: list[str], modality: str) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
# Intern-S1
|
||||
def run_interns1(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
|
||||
model_name = "internlm/Intern-S1"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
max_model_len=8192,
|
||||
max_num_seqs=2,
|
||||
limit_mm_per_prompt={modality: 1},
|
||||
enforce_eager=True,
|
||||
)
|
||||
|
||||
placeholder = "<IMG_CONTEXT>"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||
messages = [
|
||||
[{"role": "user", "content": f"{placeholder}\n{question}"}]
|
||||
for question in questions
|
||||
]
|
||||
prompts = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# InternVL
|
||||
def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
model_name = "OpenGVLab/InternVL3-2B"
|
||||
@ -1303,6 +1334,7 @@ model_example_map = {
|
||||
"h2ovl_chat": run_h2ovl,
|
||||
"hyperclovax_seed_vision": run_hyperclovax_seed_vision,
|
||||
"idefics3": run_idefics3,
|
||||
"interns1": run_interns1,
|
||||
"internvl_chat": run_internvl,
|
||||
"nemotron_vl": run_nemotron_vl,
|
||||
"keye_vl": run_keye_vl,
|
||||
|
||||
@ -253,6 +253,33 @@ def load_smolvlm(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
def load_interns1(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "internlm/Intern-S1"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
max_model_len=4096,
|
||||
limit_mm_per_prompt={"image": len(image_urls)},
|
||||
)
|
||||
|
||||
placeholders = "\n".join(
|
||||
f"Image-{i}: <IMG_CONTEXT>\n" for i, _ in enumerate(image_urls, start=1)
|
||||
)
|
||||
messages = [{"role": "user", "content": f"{placeholders}\n{question}"}]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
image_data=[fetch_image(url) for url in image_urls],
|
||||
)
|
||||
|
||||
|
||||
def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "OpenGVLab/InternVL2-2B"
|
||||
|
||||
@ -946,6 +973,7 @@ model_example_map = {
|
||||
"gemma3": load_gemma3,
|
||||
"h2ovl_chat": load_h2ovl,
|
||||
"idefics3": load_idefics3,
|
||||
"interns1": load_interns1,
|
||||
"internvl_chat": load_internvl,
|
||||
"hyperclovax_seed_vision": load_hyperclovax_seed_vision,
|
||||
"keye_vl": load_keye_vl,
|
||||
|
||||
@ -381,6 +381,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
extras={"2B": "OpenGVLab/InternVL2-2B",
|
||||
"3.0": "OpenGVLab/InternVL3-1B"}, # noqa: E501
|
||||
trust_remote_code=True),
|
||||
"InternS1ForConditionalGeneration": _HfExamplesInfo("internlm/Intern-S1",
|
||||
trust_remote_code=True),
|
||||
"Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501
|
||||
{"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}), # noqa: E501
|
||||
"KeyeForConditionalGeneration": _HfExamplesInfo("Kwai-Keye/Keye-VL-8B-Preview", # noqa: E501
|
||||
|
||||
711
vllm/model_executor/models/interns1.py
Normal file
711
vllm/model_executor/models/interns1.py
Normal file
@ -0,0 +1,711 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# --------------------------------------------------------
|
||||
# InternS1
|
||||
# Copyright (c) 2025 Shanghai AI Lab
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Literal, Optional, TypedDict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import InternVLProcessor, PretrainedConfig
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.models.got_ocr2.image_processing_got_ocr2_fast import (
|
||||
GotOcr2ImageProcessorFast)
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.models.interns1_vit import InternS1VisionModel
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalKwargs, NestedTensors)
|
||||
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
||||
ImageSize, MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptUpdate, PromptUpdateDetails)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
|
||||
|
||||
class InternS1MultiModalProjector(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size *
|
||||
int(1 / config.downsample_ratio)**2)
|
||||
self.linear_1 = nn.Linear(
|
||||
config.vision_config.hidden_size *
|
||||
int(1 / config.downsample_ratio)**2,
|
||||
config.text_config.hidden_size)
|
||||
self.act = ACT2FN[config.projector_hidden_act]
|
||||
self.linear_2 = nn.Linear(config.text_config.hidden_size,
|
||||
config.text_config.hidden_size)
|
||||
|
||||
def forward(self, image_features):
|
||||
hidden_states = self.layer_norm(image_features)
|
||||
hidden_states = self.linear_1(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class InternS1ImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
pixel_values: torch.Tensor
|
||||
"""
|
||||
Shape:
|
||||
`(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
|
||||
"""
|
||||
|
||||
|
||||
class InternS1ImageEmbeddingInputs(TypedDict):
|
||||
type: Literal["image_embeds"]
|
||||
data: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
|
||||
or a list of tensors of shape `(total_image_feature_size, hidden_size)`
|
||||
|
||||
`hidden_size` must match the hidden size of language model backbone.
|
||||
"""
|
||||
|
||||
|
||||
InternS1ImageInputs = Union[InternS1ImagePixelInputs,
|
||||
InternS1ImageEmbeddingInputs]
|
||||
|
||||
|
||||
class InternS1VideoPixelInputs(TypedDict):
|
||||
type: Literal["pixel_values_videos"]
|
||||
pixel_values: torch.Tensor
|
||||
"""
|
||||
Shape:
|
||||
`(batch_size * num_video * num_frames, num_channels, height, width)`
|
||||
"""
|
||||
|
||||
num_patches: torch.Tensor
|
||||
"""Shape: `(batch_size * num_images)`"""
|
||||
|
||||
|
||||
class InternS1VideoEmbeddingInputs(TypedDict):
|
||||
type: Literal["video_embeds"]
|
||||
data: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A tensor of shape `(num_videos, total_video_feature_size, hidden_size)`
|
||||
or a list of tensors of shape `(total_video_feature_size, hidden_size)`
|
||||
|
||||
`hidden_size` must match the hidden size of language model backbone.
|
||||
"""
|
||||
|
||||
|
||||
InternS1VideoInputs = Union[InternS1VideoPixelInputs,
|
||||
InternS1VideoEmbeddingInputs]
|
||||
|
||||
|
||||
def resolve_interns1_min_max_num(
|
||||
min_dynamic_patch: int,
|
||||
max_dynamic_patch: int,
|
||||
dynamic_image_size: bool,
|
||||
use_thumbnail: bool,
|
||||
) -> tuple[int, int]:
|
||||
min_dynamic_patch = min_dynamic_patch if dynamic_image_size else 1
|
||||
max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
|
||||
|
||||
if use_thumbnail and max_dynamic_patch != 1:
|
||||
max_dynamic_patch += 1
|
||||
|
||||
return min_dynamic_patch, max_dynamic_patch
|
||||
|
||||
|
||||
def get_interns1_target_ratios(
|
||||
min_num: int,
|
||||
max_num: int,
|
||||
) -> list[tuple[int, int]]:
|
||||
target_ratios = {(i, j)
|
||||
for n in range(min_num, max_num + 1)
|
||||
for i in range(1, n + 1)
|
||||
for j in range(1, n + 1) if min_num <= i * j <= max_num}
|
||||
return sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
|
||||
class InternS1ProcessingInfo(BaseProcessingInfo):
|
||||
"""Basic image-only ProcessingInfo for InternS1-style models."""
|
||||
|
||||
def get_hf_processor(self, **kwargs: object) -> InternVLProcessor:
|
||||
return self.ctx.get_hf_processor(InternVLProcessor, **kwargs)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": None}
|
||||
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
processor: Optional['GotOcr2ImageProcessorFast'] = None,
|
||||
) -> int:
|
||||
if processor is None:
|
||||
processor = self.get_hf_processor().image_processor
|
||||
|
||||
if not isinstance(processor, GotOcr2ImageProcessorFast):
|
||||
raise ValueError(f'GotOcr2ImageProcessorFast is expected but got '
|
||||
f'{type(processor)}')
|
||||
num_image_patches = processor.get_number_of_image_tokens(
|
||||
image_height, image_width, images_kwargs=dict())
|
||||
num_image_tokens = self.get_hf_processor(
|
||||
).image_seq_length * num_image_patches
|
||||
return num_image_tokens
|
||||
|
||||
def resolve_target_ratios(self, use_thumbnail: Optional[bool] = None):
|
||||
image_processor = self.get_hf_processor().image_processor
|
||||
min_dynamic_patch = image_processor.min_patches
|
||||
max_dynamic_patch = image_processor.max_patches
|
||||
# HF format's InternVL processor uses `crop_to_patches` which is
|
||||
# equivalent to `use_thumbnail` in original format.
|
||||
use_thumbnail = image_processor.crop_to_patches
|
||||
dynamic_image_size = True
|
||||
min_num, max_num = resolve_interns1_min_max_num(
|
||||
min_dynamic_patch,
|
||||
max_dynamic_patch,
|
||||
dynamic_image_size,
|
||||
use_thumbnail=use_thumbnail)
|
||||
|
||||
return get_interns1_target_ratios(min_num, max_num)
|
||||
|
||||
def get_image_size_with_most_features(self) -> ImageSize:
|
||||
processor = self.get_hf_processor()
|
||||
|
||||
hf_config = self.ctx.get_hf_config()
|
||||
base_height, base_width = hf_config.vision_config.image_size
|
||||
target_ratios = self.resolve_target_ratios()
|
||||
|
||||
largest_feature_size, largest_feature_pinpoint = 0, None
|
||||
for wr, hr in target_ratios:
|
||||
width, height = base_width * wr, base_height * hr
|
||||
|
||||
feat_size = self.get_num_image_tokens(
|
||||
image_width=width,
|
||||
image_height=height,
|
||||
processor=processor.image_processor,
|
||||
)
|
||||
if feat_size > largest_feature_size:
|
||||
largest_feature_size = feat_size
|
||||
largest_feature_pinpoint = ImageSize(width=width,
|
||||
height=height)
|
||||
|
||||
assert not (largest_feature_size == 0 or largest_feature_pinpoint
|
||||
is None), ("Cannot have a largest feature size of 0!")
|
||||
|
||||
return largest_feature_pinpoint
|
||||
|
||||
def get_max_image_tokens(self) -> int:
|
||||
processor = self.get_hf_processor()
|
||||
target_width, target_height = self.get_image_size_with_most_features()
|
||||
|
||||
return self.get_num_image_tokens(
|
||||
image_width=target_width,
|
||||
image_height=target_height,
|
||||
processor=processor.image_processor,
|
||||
)
|
||||
|
||||
|
||||
class InternS1DummyInputsBuilder(BaseDummyInputsBuilder[InternS1ProcessingInfo]
|
||||
):
|
||||
"""Basic image-only DummyInputsBuilder for InternS1-style models."""
|
||||
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
image_token = self.info.get_hf_processor().image_token
|
||||
|
||||
return image_token * num_images
|
||||
|
||||
def get_dummy_mm_data(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> MultiModalDataDict:
|
||||
target_width, target_height = \
|
||||
self.info.get_image_size_with_most_features()
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images)
|
||||
}
|
||||
|
||||
|
||||
class InternS1MultiModalProcessor(
|
||||
BaseMultiModalProcessor[InternS1ProcessingInfo]):
|
||||
""" Basic image-only MultiModalProcessor for InternS1-style models."""
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
tok_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, NestedTensors]:
|
||||
processed_outputs = super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
mm_kwargs=mm_kwargs,
|
||||
tok_kwargs=tok_kwargs,
|
||||
)
|
||||
|
||||
hf_processor = self.info.get_hf_processor(**mm_kwargs)
|
||||
image_token_id = hf_processor.image_token_id
|
||||
|
||||
# Since there may be extra tokens in the feature placeholders,
|
||||
# we need to pass the image token ID to the model to select the
|
||||
# tokens to merge from the vision encoder outputs
|
||||
processed_outputs["image_token_id"] = torch.tensor(image_token_id)
|
||||
images = mm_data.get('images', None)
|
||||
image_processor = self.info.get_hf_processor().image_processor
|
||||
if images is not None:
|
||||
image_inputs = image_processor(images=images)
|
||||
image_num_patches = image_inputs.pop("num_patches")
|
||||
if not isinstance(image_num_patches, list):
|
||||
raise ValueError(
|
||||
f'num_patches is supposed to be list, but got '
|
||||
f'{type(image_num_patches)}')
|
||||
image_num_patches = torch.tensor(image_num_patches)
|
||||
processed_outputs['image_num_patches'] = image_num_patches
|
||||
|
||||
return processed_outputs
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: Mapping[str, NestedTensors],
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
|
||||
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
|
||||
num_images = len(image_num_patches)
|
||||
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", image_num_patches),
|
||||
image_num_patches=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
image_token_id=MultiModalFieldConfig.shared("image", num_images),
|
||||
)
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> Sequence[PromptUpdate]:
|
||||
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
img_context_token = hf_processor.image_token
|
||||
start_image_token = hf_processor.start_image_token
|
||||
end_image_token = hf_processor.end_image_token
|
||||
|
||||
def get_replacement(item_idx: int):
|
||||
images = mm_items.get_items(
|
||||
"image", (ImageEmbeddingItems, ImageProcessorItems))
|
||||
|
||||
if isinstance(images, ImageEmbeddingItems):
|
||||
feature_size = images.get_feature_size(item_idx)
|
||||
else:
|
||||
image_size = images.get_image_size(item_idx)
|
||||
feature_size = self.info.get_num_image_tokens(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
processor=hf_processor.image_processor,
|
||||
)
|
||||
|
||||
repl_features = img_context_token * feature_size
|
||||
repl_full = start_image_token + repl_features + end_image_token
|
||||
return PromptUpdateDetails.select_text(repl_full,
|
||||
img_context_token)
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=img_context_token,
|
||||
replacement=get_replacement,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
InternS1MultiModalProcessor,
|
||||
info=InternS1ProcessingInfo,
|
||||
dummy_inputs=InternS1DummyInputsBuilder)
|
||||
class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsPP, SupportsLoRA):
|
||||
|
||||
# To ensure correct weight loading and mapping.
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
"model.language_model.": "language_model.model.",
|
||||
"model.vision_tower.": "vision_tower.",
|
||||
"model.multi_modal_projector.": "multi_modal_projector.",
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
# transformers InternVLProcessor uses <IMG_CONTEXT> as the seperator
|
||||
# refer to https://github.com/huggingface/transformers/blob/f90de364c2484c7c325bbe05befdcf487bd75b63/src/transformers/models/internvl/processing_internvl.py#L116
|
||||
if modality.startswith("image"):
|
||||
return '<IMG_CONTEXT>'
|
||||
if modality.startswith("video"):
|
||||
return "<video>"
|
||||
|
||||
raise ValueError("Only image or video modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
image_size = config.vision_config.image_size[0]
|
||||
patch_size = config.vision_config.patch_size[0]
|
||||
self.patch_size = patch_size
|
||||
self.num_image_token = int(
|
||||
(image_size // patch_size)**2 * (config.downsample_ratio**2))
|
||||
self.downsample_ratio = config.downsample_ratio
|
||||
|
||||
self.llm_arch_name = config.text_config.architectures[0]
|
||||
self.vision_tower = self._init_vision_model(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "vision_tower"),
|
||||
)
|
||||
|
||||
self.language_model = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
|
||||
self.multi_modal_projector = self._init_mlp1(config)
|
||||
|
||||
self.img_context_token_id = None
|
||||
self.video_context_token_id = None
|
||||
|
||||
self.visual_token_mask = None
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors)
|
||||
|
||||
def _init_vision_model(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
*,
|
||||
prefix: str,
|
||||
):
|
||||
num_hidden_layers = config.vision_config.num_hidden_layers
|
||||
return InternS1VisionModel(
|
||||
config.vision_config,
|
||||
quant_config=quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers,
|
||||
prefix=prefix,
|
||||
)
|
||||
|
||||
def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
|
||||
return InternS1MultiModalProjector(config)
|
||||
|
||||
def pixel_shuffle(self, x, scale_factor=0.5):
|
||||
n, w, h, c = x.size()
|
||||
# N, W, H, C --> N, W, H * scale, C // scale
|
||||
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
|
||||
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
|
||||
x = x.permute(0, 2, 1, 3).contiguous()
|
||||
x = x.view(n, int(h * scale_factor), int(w * scale_factor),
|
||||
int(c / (scale_factor * scale_factor)))
|
||||
x = x.permute(0, 2, 1, 3).contiguous()
|
||||
return x
|
||||
|
||||
def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
vit_embeds = self.vision_tower(pixel_values=pixel_values)
|
||||
vit_embeds = vit_embeds[:, 1:, :]
|
||||
|
||||
h = w = int(vit_embeds.shape[1]**0.5)
|
||||
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
|
||||
vit_embeds = self.pixel_shuffle(vit_embeds,
|
||||
scale_factor=self.downsample_ratio)
|
||||
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
|
||||
vit_embeds.shape[-1])
|
||||
|
||||
vit_embeds = self.multi_modal_projector(vit_embeds)
|
||||
return vit_embeds
|
||||
|
||||
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
h, w = self.config.vision_config.image_size
|
||||
expected_dims = (3, h, w)
|
||||
|
||||
def _validate_shape(d: torch.Tensor):
|
||||
actual_dims = tuple(d.shape)
|
||||
|
||||
if actual_dims != expected_dims:
|
||||
expected_expr = str(expected_dims)
|
||||
raise ValueError(
|
||||
"The expected shape of pixel values per image per batch "
|
||||
f" per patch is {expected_expr}. "
|
||||
f"You supplied {tuple(d.shape)}.")
|
||||
|
||||
for d in data:
|
||||
_validate_shape(d)
|
||||
|
||||
return data
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[InternS1ImageInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
image_num_patches = kwargs.pop("image_num_patches", None)
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
|
||||
if pixel_values is None and image_embeds is None:
|
||||
return None
|
||||
|
||||
if image_embeds is not None:
|
||||
if not isinstance(image_embeds, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image embeddings. "
|
||||
f"Got type: {type(image_embeds)}")
|
||||
|
||||
return InternS1ImageEmbeddingInputs(
|
||||
type="image_embeds",
|
||||
data=flatten_bn(image_embeds),
|
||||
)
|
||||
|
||||
image_token_id = kwargs["image_token_id"]
|
||||
assert isinstance(image_token_id, torch.Tensor)
|
||||
self.img_context_token_id = image_token_id.flatten().unique().item()
|
||||
|
||||
if pixel_values is not None:
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
if not isinstance(image_num_patches, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image_num_patches. "
|
||||
f"Got type: {type(image_num_patches)}")
|
||||
|
||||
pixel_values = flatten_bn(pixel_values, concat=True)
|
||||
image_num_patches = flatten_bn(image_num_patches, concat=True)
|
||||
|
||||
return InternS1ImagePixelInputs(
|
||||
type="pixel_values",
|
||||
pixel_values=self._validate_pixel_values(pixel_values),
|
||||
num_patches=image_num_patches,
|
||||
)
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
|
||||
def _parse_and_validate_video_input(
|
||||
self, **kwargs: object) -> Optional[InternS1VideoPixelInputs]:
|
||||
pixel_values_flat_video = kwargs.pop("pixel_values_flat_video", None)
|
||||
video_num_patches = kwargs.pop("video_num_patches", None)
|
||||
video_embeds = kwargs.pop("video_embeds", None)
|
||||
|
||||
if pixel_values_flat_video is None and video_embeds is None:
|
||||
return None
|
||||
|
||||
if video_embeds is not None:
|
||||
if not isinstance(video_embeds, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of video embeddings. "
|
||||
f"Got type: {type(video_embeds)}")
|
||||
|
||||
return InternS1ImageEmbeddingInputs(
|
||||
type="video_embeds",
|
||||
data=flatten_bn(video_embeds),
|
||||
)
|
||||
|
||||
video_token_id = kwargs["video_token_id"]
|
||||
assert isinstance(video_token_id, torch.Tensor)
|
||||
self.video_context_token_id = video_token_id.flatten().unique().item()
|
||||
|
||||
if pixel_values_flat_video is not None:
|
||||
if not isinstance(pixel_values_flat_video, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values_flat_video)}")
|
||||
|
||||
if not isinstance(video_num_patches, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image_num_patches. "
|
||||
f"Got type: {type(video_num_patches)}")
|
||||
|
||||
pixel_values_flat_video = flatten_bn(pixel_values_flat_video,
|
||||
concat=True)
|
||||
video_num_patches = flatten_bn(video_num_patches, concat=True)
|
||||
|
||||
return InternS1VideoPixelInputs(
|
||||
type="pixel_values_videos",
|
||||
pixel_values=self._validate_pixel_values(
|
||||
pixel_values_flat_video),
|
||||
num_patches=video_num_patches,
|
||||
)
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
|
||||
def _process_image_input(
|
||||
self,
|
||||
image_input: Union[InternS1ImageInputs, InternS1VideoPixelInputs],
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
if image_input["type"] == "image_embeds":
|
||||
return image_input["data"]
|
||||
|
||||
assert self.vision_tower is not None
|
||||
|
||||
image_embeds = self.extract_feature(image_input["pixel_values"])
|
||||
|
||||
num_patches = image_input["num_patches"]
|
||||
|
||||
# Only one image in the current batch
|
||||
if len(num_patches) == 1:
|
||||
return (image_embeds.view(-1,
|
||||
self.config.text_config.hidden_size), )
|
||||
|
||||
# NOTE: Image embeddings are split into separate tensors for each image
|
||||
# by the size of each embedding.
|
||||
feature_size = image_embeds.shape[1]
|
||||
image_embeds = image_embeds.view(-1,
|
||||
self.config.text_config.hidden_size)
|
||||
image_feature_sizes = [
|
||||
num_patches * feature_size for num_patches in num_patches
|
||||
]
|
||||
return image_embeds.split(image_feature_sizes)
|
||||
|
||||
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
||||
modalities = {}
|
||||
|
||||
# Preserve the order of modalities if there are multiple of them
|
||||
# from the order of kwargs.
|
||||
for input_key in kwargs:
|
||||
if input_key in ("pixel_values",
|
||||
"image_embeds") and "images" not in modalities:
|
||||
modalities["images"] = self._parse_and_validate_image_input(
|
||||
**kwargs)
|
||||
if input_key in ("pixel_values_flat_video",
|
||||
) and "videos" not in modalities:
|
||||
modalities["videos"] = self._parse_and_validate_video_input(
|
||||
**kwargs)
|
||||
|
||||
return modalities
|
||||
|
||||
def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
|
||||
self.visual_token_mask = None
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(self,
|
||||
**kwargs: object) -> MultiModalEmbeddings:
|
||||
|
||||
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
|
||||
if not modalities:
|
||||
return []
|
||||
return None
|
||||
|
||||
# The result multimodal_embeddings is tuple of tensors, with each
|
||||
# tensor correspoending to a multimodal data item (image or video).
|
||||
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
|
||||
|
||||
# NOTE: It is important to iterate over the keys in this dictionary
|
||||
# to preserve the order of the modalities.
|
||||
for modality in modalities:
|
||||
if modality == "images":
|
||||
image_input = modalities["images"]
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
multimodal_embeddings += vision_embeddings
|
||||
if modality == "videos":
|
||||
video_input = modalities["videos"]
|
||||
video_embeddings = self._process_image_input(video_input)
|
||||
multimodal_embeddings += video_embeddings
|
||||
|
||||
return multimodal_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
context_token_ids = [
|
||||
token_id for token_id in (self.img_context_token_id,
|
||||
self.video_context_token_id)
|
||||
if token_id is not None
|
||||
]
|
||||
assert len(context_token_ids) >= 1
|
||||
self._set_visual_token_mask(input_ids)
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
context_token_ids,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
) -> IntermediateTensors:
|
||||
|
||||
if intermediate_tensors is not None:
|
||||
input_ids = None
|
||||
inputs_embeds = None
|
||||
|
||||
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
input_ids = None
|
||||
|
||||
forward_kwargs = {
|
||||
"input_ids": input_ids,
|
||||
"positions": positions,
|
||||
"intermediate_tensors": intermediate_tensors,
|
||||
"inputs_embeds": inputs_embeds,
|
||||
}
|
||||
|
||||
hidden_states = self.language_model.model(**forward_kwargs)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
return self.language_model.compute_logits(hidden_states,
|
||||
sampling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
def get_mm_mapping(self) -> MultiModelKeys:
|
||||
"""
|
||||
Get the module prefix in multimodal models
|
||||
"""
|
||||
return MultiModelKeys.from_string_field(
|
||||
language_model="language_model",
|
||||
connector="multi_modal_projector",
|
||||
tower_model="vision_tower")
|
||||
421
vllm/model_executor/models/interns1_vit.py
Normal file
421
vllm/model_executor/models/interns1_vit.py
Normal file
@ -0,0 +1,421 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
|
||||
# --------------------------------------------------------
|
||||
# InternVL
|
||||
# Copyright (c) 2023 OpenGVLab
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.utils import torch_int
|
||||
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
NORM2FN = {
|
||||
'rms_norm': RMSNorm,
|
||||
'layer_norm': nn.LayerNorm,
|
||||
}
|
||||
|
||||
|
||||
class InternS1VisionPatchEmbeddings(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
image_size, patch_size = config.image_size, config.patch_size
|
||||
num_channels, hidden_size = config.num_channels, config.hidden_size
|
||||
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] //
|
||||
patch_size[0])
|
||||
patch_shape = (image_size[0] // patch_size[0],
|
||||
image_size[1] // patch_size[1])
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
||||
self.num_patches = num_patches
|
||||
self.patch_shape = patch_shape
|
||||
|
||||
self.projection = nn.Conv2d(num_channels,
|
||||
hidden_size,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size)
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, height, width = pixel_values.shape
|
||||
if num_channels != self.num_channels:
|
||||
raise ValueError(
|
||||
"Make sure that the channel dimension of the pixel values "
|
||||
"match with the one set in the configuration.")
|
||||
|
||||
embeddings = self.projection(
|
||||
pixel_values.to(self.projection.weight.dtype))
|
||||
patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
|
||||
embeddings = embeddings.flatten(2).transpose(1, 2)
|
||||
|
||||
return embeddings, (patch_height, patch_width)
|
||||
|
||||
|
||||
class InternS1VisionEmbeddings(nn.Module):
|
||||
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
||||
if config.use_mask_token:
|
||||
self.mask_token = nn.Parameter(
|
||||
torch.zeros(1, 1, config.hidden_size))
|
||||
else:
|
||||
self.mask_token = None
|
||||
self.patch_embeddings = InternS1VisionPatchEmbeddings(config)
|
||||
self.patch_size = config.patch_size
|
||||
self.image_size = (config.image_size if isinstance(
|
||||
config.image_size, Iterable) else
|
||||
(config.image_size, config.image_size))
|
||||
num_patches = self.patch_embeddings.num_patches
|
||||
if config.use_absolute_position_embeddings:
|
||||
self.position_embeddings = nn.Parameter(
|
||||
torch.zeros(1, num_patches + 1, config.hidden_size))
|
||||
else:
|
||||
self.position_embeddings = None
|
||||
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int,
|
||||
width: int) -> torch.Tensor:
|
||||
"""
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
|
||||
images. This method is also adapted to support torch.jit tracing.
|
||||
|
||||
Adapted from:
|
||||
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
|
||||
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
|
||||
""" # noqa: E501
|
||||
|
||||
num_patches = embeddings.shape[1] - 1
|
||||
num_positions = self.position_embeddings.shape[1] - 1
|
||||
|
||||
# always interpolate when tracing to ensure the exported model
|
||||
# works for dynamic input shapes
|
||||
if not torch.jit.is_tracing(
|
||||
) and num_patches == num_positions and height == width:
|
||||
return self.position_embeddings
|
||||
|
||||
class_pos_embed = self.position_embeddings[:, :1]
|
||||
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||
|
||||
dim = embeddings.shape[-1]
|
||||
|
||||
new_height = height // self.patch_size[0]
|
||||
new_width = width // self.patch_size[1]
|
||||
|
||||
sqrt_num_positions = torch_int(num_positions**0.5)
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions,
|
||||
sqrt_num_positions, dim)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed,
|
||||
size=(new_height, new_width),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
|
||||
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
||||
) -> torch.Tensor:
|
||||
_, _, height, width = pixel_values.shape
|
||||
embeddings, (patch_height,
|
||||
patch_width) = self.patch_embeddings(pixel_values)
|
||||
batch_size, seq_len, _ = embeddings.size()
|
||||
|
||||
if bool_masked_pos is not None:
|
||||
mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
|
||||
# replace the masked visual tokens by mask_tokens
|
||||
w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
|
||||
embeddings = embeddings * (1 - w) + mask_tokens * w
|
||||
|
||||
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
||||
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
||||
|
||||
if self.position_embeddings is not None:
|
||||
embeddings = embeddings + self.interpolate_pos_encoding(
|
||||
embeddings, height, width)
|
||||
|
||||
return embeddings, (patch_height, patch_width)
|
||||
|
||||
|
||||
class InternSdpaAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
*,
|
||||
num_dummy_heads: int = 0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
f'embed_dim must be divisible by num_heads '
|
||||
f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
|
||||
f' {self.num_heads}).')
|
||||
|
||||
# Additional dummy heads are used to enable TP for common GPU counts.
|
||||
self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(self.embed_dim,
|
||||
self.num_heads * self.head_dim,
|
||||
bias=config.attention_bias)
|
||||
self.k_proj = nn.Linear(self.embed_dim,
|
||||
self.num_heads * self.head_dim,
|
||||
bias=config.attention_bias)
|
||||
self.v_proj = nn.Linear(self.embed_dim,
|
||||
self.num_heads * self.head_dim,
|
||||
bias=config.attention_bias)
|
||||
|
||||
self.qk_normalization = config.use_qk_norm
|
||||
if self.qk_normalization:
|
||||
self.q_norm = RMSNorm(self.dummy_dim,
|
||||
eps=config.layer_norm_eps,
|
||||
var_hidden_size=self.embed_dim)
|
||||
self.k_norm = RMSNorm(self.dummy_dim,
|
||||
eps=config.layer_norm_eps,
|
||||
var_hidden_size=self.embed_dim)
|
||||
|
||||
self.projection_layer = nn.Linear(self.dummy_dim, self.embed_dim)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
|
||||
q = self.q_proj(x)
|
||||
k = self.k_proj(x)
|
||||
v = self.v_proj(x)
|
||||
|
||||
q = q.view(B, N, self.num_heads, self.head_dim)
|
||||
k = k.view(B, N, self.num_heads, self.head_dim)
|
||||
v = v.view(B, N, self.num_heads, self.head_dim)
|
||||
|
||||
if self.qk_normalization:
|
||||
B_, N_, H_, D_ = q.shape
|
||||
q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_)
|
||||
k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_)
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
|
||||
x = x.transpose(1, 2).reshape(B, N, -1)
|
||||
|
||||
x = self.projection_layer(x)
|
||||
return x
|
||||
|
||||
|
||||
class InternS1VisionMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.activation_fn = get_act_fn(config.hidden_act)
|
||||
self.fc1 = ColumnParallelLinear(config.hidden_size,
|
||||
config.intermediate_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc1")
|
||||
self.fc2 = RowParallelLinear(config.intermediate_size,
|
||||
config.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc2")
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states, _ = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
hidden_states, _ = self.fc2(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class InternS1VisionLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
*,
|
||||
num_dummy_heads: int = 0,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.attention = self._init_attn(config,
|
||||
quant_config,
|
||||
num_dummy_heads=num_dummy_heads,
|
||||
prefix=f"{prefix}.attention")
|
||||
|
||||
self.mlp = InternS1VisionMLP(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
self.layernorm_before = NORM2FN[config.norm_type](
|
||||
config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.layernorm_after = NORM2FN[config.norm_type](
|
||||
config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
init_values = config.layer_scale_init_value
|
||||
self.lambda_1 = nn.Parameter(init_values *
|
||||
torch.ones(config.hidden_size),
|
||||
requires_grad=True)
|
||||
self.lambda_2 = nn.Parameter(init_values *
|
||||
torch.ones(config.hidden_size),
|
||||
requires_grad=True)
|
||||
|
||||
def _init_attn(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
*,
|
||||
num_dummy_heads: int,
|
||||
prefix: str = "",
|
||||
):
|
||||
return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
):
|
||||
hidden_states = hidden_states + self.attention(
|
||||
self.layernorm_before(hidden_states)) * self.lambda_1
|
||||
|
||||
hidden_states = hidden_states + self.mlp(
|
||||
self.layernorm_after(hidden_states)) * self.lambda_2
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class InternS1VisionEncoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
*,
|
||||
num_hidden_layers_override: Optional[int] = None,
|
||||
num_dummy_heads: int = 0,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
|
||||
if num_hidden_layers_override is None:
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
else:
|
||||
num_hidden_layers = num_hidden_layers_override
|
||||
|
||||
self.layer = nn.ModuleList([
|
||||
InternS1VisionLayer(config,
|
||||
quant_config,
|
||||
num_dummy_heads=num_dummy_heads,
|
||||
prefix=f"{prefix}.layer.{layer_idx}")
|
||||
for layer_idx in range(num_hidden_layers)
|
||||
])
|
||||
|
||||
def forward(self, inputs_embeds: torch.Tensor):
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
for encoder_layer in self.layer:
|
||||
hidden_states = encoder_layer(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class InternS1VisionModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
*,
|
||||
num_hidden_layers_override: Optional[int] = None,
|
||||
num_dummy_heads: int = 0,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
|
||||
self.embeddings = InternS1VisionEmbeddings(config)
|
||||
self.encoder = InternS1VisionEncoder(
|
||||
config=config,
|
||||
num_hidden_layers_override=num_hidden_layers_override,
|
||||
num_dummy_heads=num_dummy_heads,
|
||||
prefix=f"{prefix}.encoder",
|
||||
)
|
||||
self.layernorm = (nn.Identity() if config.use_mean_pooling else
|
||||
nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps))
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings.patch_embeddings
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.Tensor] = None,
|
||||
pixel_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
if pixel_values is None and pixel_embeds is None:
|
||||
raise ValueError(
|
||||
'You have to specify pixel_values or pixel_embeds')
|
||||
|
||||
if pixel_embeds is not None:
|
||||
hidden_states = pixel_embeds
|
||||
elif pixel_values is not None:
|
||||
if pixel_values.ndim == 4:
|
||||
hidden_states, _ = self.embeddings(pixel_values)
|
||||
else:
|
||||
raise ValueError(
|
||||
f'wrong pixel_values size: {pixel_values.shape}')
|
||||
|
||||
encoder_outputs = self.encoder(inputs_embeds=hidden_states)
|
||||
encoder_outputs = self.layernorm(encoder_outputs)
|
||||
|
||||
return encoder_outputs
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
@ -203,6 +203,7 @@ _MULTIMODAL_MODELS = {
|
||||
"GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"), # noqa: E501
|
||||
"H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
|
||||
"InternVLChatModel": ("internvl", "InternVLChatModel"),
|
||||
"InternS1ForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"), # noqa: E501
|
||||
"Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
|
||||
"SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"), # noqa: E501
|
||||
"KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user