mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 19:35:01 +08:00
954 lines
31 KiB
Python
954 lines
31 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/modeling_skywork_chat.py
|
|
# --------------------------------------------------------
|
|
# SkyworkR1V
|
|
# Copyright (c) 2025 Skywork
|
|
# Licensed under The MIT License [see LICENSE for details]
|
|
# --------------------------------------------------------
|
|
from collections.abc import Iterable, Mapping, Sequence
|
|
from typing import Annotated, Literal, TypeAlias
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torchvision.transforms as T
|
|
from PIL import Image
|
|
from transformers import BatchFeature, PretrainedConfig, TensorType
|
|
|
|
from vllm.config import VllmConfig
|
|
from vllm.config.multimodal import BaseDummyOptions
|
|
from vllm.model_executor.layers.linear import ReplicatedLinear
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
|
from vllm.model_executor.models.intern_vit import (
|
|
InternVisionModel,
|
|
InternVisionPatchModel,
|
|
)
|
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
|
from vllm.multimodal.image import convert_image_mode
|
|
from vllm.multimodal.inputs import (
|
|
MultiModalDataDict,
|
|
MultiModalFieldConfig,
|
|
MultiModalKwargsItems,
|
|
)
|
|
from vllm.multimodal.parse import (
|
|
ImageEmbeddingItems,
|
|
ImageProcessorItems,
|
|
ImageSize,
|
|
MultiModalDataItems,
|
|
)
|
|
from vllm.multimodal.processing import (
|
|
BaseMultiModalProcessor,
|
|
BaseProcessingInfo,
|
|
PromptReplacement,
|
|
PromptUpdate,
|
|
PromptUpdateDetails,
|
|
)
|
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
|
from vllm.sequence import IntermediateTensors
|
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
|
|
|
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
|
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
|
|
|
|
# Markers used when expanding an image placeholder in the prompt:
# IMG_CONTEXT is repeated once per image feature and the run is wrapped
# between IMG_START and IMG_END.
IMG_START = "<img>"
IMG_END = "</img>"
IMG_CONTEXT = "<IMG_CONTEXT>"

# Per-channel mean / std handed to the normalization step of the image
# transform.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
|
|
|
|
|
|
class SkyworkR1VImagePixelInputs(TensorSchema):
    """
    Image input given as flattened pixel tiles.

    Dimensions:
        - bnp: Batch size * number of images * (1 + num_patches)
        - c: Number of channels (3)
        - h: Height
        - w: Width
        - bn: Batch size * number of images
    """

    # Discriminator distinguishing this schema from the embedding variant.
    type: Literal["pixel_values"] = "pixel_values"

    # All tiles of all images, concatenated along the first dimension.
    pixel_values_flat: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]

    # One entry per image.
    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
|
|
|
|
|
|
class SkyworkR1VImageEmbeddingInputs(TensorSchema):
    """
    Image input given as precomputed embeddings.

    Dimensions:
        - ni: Number of images
        - ifs: Image feature size
        - hs: Hidden size (must match the hidden size of language model
          backbone)
    """

    # Discriminator distinguishing this schema from the pixel variant.
    type: Literal["image_embeds"] = "image_embeds"

    # Either one stacked tensor or a list of tensors.
    data: Annotated[
        torch.Tensor | list[torch.Tensor], TensorShape("ni", "ifs", "hs")
    ]
|
|
|
|
|
|
# Union of the two accepted image-input schemas (raw pixels or embeddings).
SkyworkR1VImageInputs: TypeAlias = (
    SkyworkR1VImagePixelInputs | SkyworkR1VImageEmbeddingInputs
)
|
|
|
|
|
|
# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/
|
|
def build_transform(input_size: int):
    """
    Build the preprocessing pipeline applied to every image tile.

    The pipeline converts the image to RGB, resizes it to a square of
    ``input_size`` pixels with bicubic interpolation, turns it into a tensor
    and normalizes it with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.

    Args:
        input_size: Side length, in pixels, of the square output.

    Returns:
        The composed torchvision transform.
    """
    square = (input_size, input_size)
    pipeline = [
        T.Lambda(lambda img: convert_image_mode(img, "RGB")),
        T.Resize(square, interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(pipeline)
|
|
|
|
|
|
# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/
|
|
def find_closest_aspect_ratio(
    aspect_ratio: float,
    target_ratios: list[tuple[int, int]],
    *,
    width: int,
    height: int,
    image_size: int,
) -> tuple[int, int]:
    """
    Pick the tiling grid whose aspect ratio best matches the image.

    Candidates are scanned in order. A strictly closer candidate always wins.
    A candidate that ties with the current best replaces it only when the
    image area exceeds half the area that candidate's grid would cover.

    Args:
        aspect_ratio: Width / height of the image.
        target_ratios: Candidate ``(columns, rows)`` grids.
        width: Image width in pixels.
        height: Image height in pixels.
        image_size: Side length of one tile in pixels.

    Returns:
        The selected candidate, or ``(1, 1)`` when no candidate qualifies.
    """
    closest_diff = float("inf")
    closest = (1, 1)
    image_area = width * height

    for candidate in target_ratios:
        cols, rows = candidate
        diff = abs(aspect_ratio - cols / rows)

        if diff < closest_diff:
            closest_diff = diff
            closest = candidate
        elif (
            diff == closest_diff
            and image_area > 0.5 * image_size * image_size * cols * rows
        ):
            # Equally good ratio: prefer it only for sufficiently large images.
            closest = candidate

    return closest
|
|
|
|
|
|
def resolve_skyworkr1v_min_max_num(
    *,
    min_dynamic_patch: int,
    max_dynamic_patch: int,
    dynamic_image_size: bool,
    use_thumbnail: bool,
) -> tuple[int, int]:
    """
    Resolve the effective minimum and maximum number of tiles per image.

    Args:
        min_dynamic_patch: Configured lower bound on the tile count.
        max_dynamic_patch: Configured upper bound on the tile count.
        dynamic_image_size: When false, both bounds collapse to 1.
        use_thumbnail: When true and the upper bound is not 1, the upper
            bound grows by one to leave room for the thumbnail.

    Returns:
        ``(min_num, max_num)``.
    """
    if not dynamic_image_size:
        # Dynamic tiling disabled: exactly one tile per image.
        min_dynamic_patch = 1
        max_dynamic_patch = 1

    if use_thumbnail and max_dynamic_patch != 1:
        max_dynamic_patch += 1

    return min_dynamic_patch, max_dynamic_patch
|
|
|
|
|
|
def get_skyworkr1v_target_ratios(
    min_num: int,
    max_num: int,
) -> list[tuple[int, int]]:
    """
    Enumerate every tiling grid whose tile count lies in the given range.

    Args:
        min_num: Smallest allowed number of tiles (inclusive).
        max_num: Largest allowed number of tiles (inclusive).

    Returns:
        The distinct ``(i, j)`` pairs with ``min_num <= i * j <= max_num``,
        sorted by ascending product.
    """
    grids: set[tuple[int, int]] = set()
    for n in range(min_num, max_num + 1):
        for i in range(1, n + 1):
            for j in range(1, n + 1):
                if min_num <= i * j <= max_num:
                    grids.add((i, j))

    return sorted(grids, key=lambda grid: grid[0] * grid[1])
|
|
|
|
|
|
def calculate_skyworkr1v_targets(
    *,
    orig_width: int,
    orig_height: int,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> tuple[int, int, int]:
    """
    Work out how an image will be tiled.

    Args:
        orig_width: Original image width in pixels.
        orig_height: Original image height in pixels.
        target_ratios: Candidate ``(columns, rows)`` grids.
        image_size: Side length of one tile in pixels.
        use_thumbnail: Whether to count one extra block for a thumbnail
            when the grid has more than one tile.

    Returns:
        ``(blocks, target_width, target_height)`` where the target size is
        the grid size expressed in pixels.
    """
    # Choose the grid whose aspect ratio is closest to the image's.
    grid_cols, grid_rows = find_closest_aspect_ratio(
        orig_width / orig_height,
        target_ratios,
        width=orig_width,
        height=orig_height,
        image_size=image_size,
    )

    blocks = grid_cols * grid_rows

    # The thumbnail is only added when the image is actually split.
    if use_thumbnail and blocks != 1:
        blocks += 1

    return blocks, image_size * grid_cols, image_size * grid_rows
|
|
|
|
|
|
def dynamic_preprocess_skyworkr1v(
    image: Image.Image,
    *,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> list[Image.Image]:
    """
    Resize an image to its best-fitting grid and cut it into square tiles.

    Args:
        image: Image to split.
        target_ratios: Candidate ``(columns, rows)`` grids.
        image_size: Side length of one tile in pixels.
        use_thumbnail: Whether to append a whole-image thumbnail of
            ``image_size`` x ``image_size`` when more than one tile results.

    Returns:
        The tiles in row-major order, optionally followed by the thumbnail.
    """
    orig_width, orig_height = image.size

    # The thumbnail is appended below, so it is excluded from the count here.
    blocks, target_width, target_height = calculate_skyworkr1v_targets(
        orig_width=orig_width,
        orig_height=orig_height,
        target_ratios=target_ratios,
        image_size=image_size,
        use_thumbnail=False,
    )

    resized_img = image.resize((target_width, target_height))
    tiles_per_row = target_width // image_size

    tiles = []
    for index in range(blocks):
        row, col = divmod(index, tiles_per_row)
        left = col * image_size
        upper = row * image_size
        tiles.append(
            resized_img.crop((left, upper, left + image_size, upper + image_size))
        )

    assert len(tiles) == blocks

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))

    return tiles
|
|
|
|
|
|
# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B
|
|
def image_to_pixel_values_skyworkr1v(
    image: Image.Image,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
) -> torch.Tensor:
    """
    Turn one image into a stacked tensor of preprocessed tiles.

    Args:
        image: Image to convert.
        input_size: Side length of one tile in pixels.
        min_num: Smallest allowed number of tiles.
        max_num: Largest allowed number of tiles.
        use_thumbnail: Whether a thumbnail tile may be appended.

    Returns:
        The transformed tiles stacked along a new first dimension.
    """
    tiles = dynamic_preprocess_skyworkr1v(
        image,
        target_ratios=get_skyworkr1v_target_ratios(min_num, max_num),
        image_size=input_size,
        use_thumbnail=use_thumbnail,
    ) if False else None
    # NOTE: the transform is built before tiling, matching the call order
    # of the preprocessing helpers.
    target_ratios = get_skyworkr1v_target_ratios(min_num, max_num)
    transform = build_transform(input_size=input_size)
    tiles = dynamic_preprocess_skyworkr1v(
        image,
        target_ratios=target_ratios,
        image_size=input_size,
        use_thumbnail=use_thumbnail,
    )

    return torch.stack([transform(tile) for tile in tiles])
|
|
|
|
|
|
class SkyworkR1VProcessor:
    """
    Hand-written processor, since this model does not define its own HF one.

    It tiles images into pixel tensors, expands each ``<image>`` placeholder
    of the prompt into the image-token sequence, and tokenizes the result.

    The code to insert image tokens is based on:
    https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/modeling_skywork_chat.py#L252
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: AnyTokenizer,
        *,
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
    ) -> None:
        """
        Store the config and tokenizer and resolve the tiling settings.

        Args:
            config: Model config; supplies the vision image/patch sizes, the
                downsample ratio, the thumbnail flag and the tiling defaults.
            tokenizer: Tokenizer used for the prompt text.
            min_dynamic_patch: Overrides ``config.min_dynamic_patch``.
            max_dynamic_patch: Overrides ``config.max_dynamic_patch``.
            dynamic_image_size: Overrides ``config.dynamic_image_size``.
        """
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer

        tile_size: int = config.vision_config.image_size
        vit_patch_size: int = config.vision_config.patch_size

        # Explicit arguments win; otherwise fall back to the model config.
        if min_dynamic_patch is None:
            min_dynamic_patch = config.min_dynamic_patch
        assert isinstance(min_dynamic_patch, int)

        if max_dynamic_patch is None:
            max_dynamic_patch = config.max_dynamic_patch
        assert isinstance(max_dynamic_patch, int)

        if dynamic_image_size is None:
            dynamic_image_size = config.dynamic_image_size
        assert isinstance(dynamic_image_size, bool)

        # Number of image tokens produced for a single tile.
        patches_per_side = tile_size // vit_patch_size
        self.num_image_token = int(
            patches_per_side**2 * (config.downsample_ratio**2)
        )
        self.image_size = tile_size
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail: bool = config.use_thumbnail

    @property
    def image_token_id(self) -> int:
        """Vocabulary id of the ``IMG_CONTEXT`` token."""
        vocab = self.tokenizer.get_vocab()
        return vocab[IMG_CONTEXT]

    def get_image_repl(
        self,
        feature_size: int,
        num_patches: int | None,
    ) -> PromptUpdateDetails[str]:
        """
        Build the text that replaces one image placeholder.

        Args:
            feature_size: Number of ``IMG_CONTEXT`` tokens to emit.
            num_patches: Not used by this implementation.

        Returns:
            Details for ``IMG_START`` + repeated ``IMG_CONTEXT`` + ``IMG_END``
            with ``IMG_CONTEXT`` as the selected text.
        """
        full_text = f"{IMG_START}{IMG_CONTEXT * feature_size}{IMG_END}"
        return PromptUpdateDetails.select_text(full_text, IMG_CONTEXT)

    def resolve_min_max_num(
        self,
        *,
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
        use_thumbnail: bool | None = None,
    ) -> tuple[int, int]:
        """
        Resolve the tile-count bounds, defaulting to the instance settings.

        Any argument left as ``None`` is taken from the matching attribute
        of this processor.

        Returns:
            ``(min_num, max_num)``.
        """
        if min_dynamic_patch is None:
            min_dynamic_patch = self.min_dynamic_patch
        if max_dynamic_patch is None:
            max_dynamic_patch = self.max_dynamic_patch
        if dynamic_image_size is None:
            dynamic_image_size = self.dynamic_image_size
        if use_thumbnail is None:
            use_thumbnail = self.use_thumbnail

        return resolve_skyworkr1v_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )

    def resolve_target_ratios(
        self,
        *,
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
        use_thumbnail: bool | None = None,
    ) -> list[tuple[int, int]]:
        """
        List the candidate tiling grids for the resolved tile-count bounds.

        Arguments have the same meaning as in ``resolve_min_max_num``.
        """
        bounds = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )
        min_num, max_num = bounds

        return get_skyworkr1v_target_ratios(min_num, max_num)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """
        Compute how many image tokens an image of the given size produces.

        Args:
            image_width: Image width in pixels.
            image_height: Image height in pixels.

        Returns:
            Block count (thumbnail included when enabled) multiplied by the
            per-tile token count.
        """
        # The thumbnail is accounted for by calculate_skyworkr1v_targets,
        # so it must not also widen the candidate grids.
        candidate_grids = self.resolve_target_ratios(use_thumbnail=False)

        targets = calculate_skyworkr1v_targets(
            orig_width=image_width,
            orig_height=image_height,
            image_size=self.image_size,
            target_ratios=candidate_grids,
            use_thumbnail=self.use_thumbnail,
        )
        num_patches, _, _ = targets

        return num_patches * self.num_image_token

    def _images_to_pixel_values_lst(
        self,
        images: list[Image.Image],
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
    ) -> list[torch.Tensor]:
        """
        Convert each image into its stacked tile tensor.

        Returns:
            One tensor per input image, in input order.
        """
        # The thumbnail is handled inside image_to_pixel_values_skyworkr1v,
        # so it is excluded from the bounds here.
        min_num, max_num = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=False,
        )

        pixel_values_lst = []
        for img in images:
            pixel_values_lst.append(
                image_to_pixel_values_skyworkr1v(
                    img,
                    input_size=self.image_size,
                    min_num=min_num,
                    max_num=max_num,
                    use_thumbnail=self.use_thumbnail,
                )
            )
        return pixel_values_lst

    def __call__(
        self,
        text: str | list[str] | None = None,
        images: Image.Image | list[Image.Image] | None = None,
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
        return_tensors: str | TensorType | None = None,
    ) -> BatchFeature:
        """
        Process prompt text and images into model inputs.

        For each image, in order, the first remaining ``<image>`` occurrence
        in every prompt is replaced by that image's token sequence. The
        prompts are then tokenized.

        Args:
            text: One prompt, a list of prompts, or ``None`` for no text.
            images: One image, a list of images, or ``None`` for no images.
            min_dynamic_patch: Optional override of the tiling lower bound.
            max_dynamic_patch: Optional override of the tiling upper bound.
            dynamic_image_size: Optional override of dynamic tiling.
            return_tensors: Tensor type forwarded to ``BatchFeature``.

        Returns:
            The tokenizer output merged with ``pixel_values_flat`` and
            ``image_num_patches`` when images were given.
        """
        # Normalize both inputs to lists.
        if text is None:
            prompts = []
        elif isinstance(text, list):
            prompts = text
        else:
            prompts = [text]

        if images is None:
            image_list = []
        elif isinstance(images, list):
            image_list = images
        else:
            image_list = [images]

        if not image_list:
            image_inputs = {}
        else:
            pixel_values_lst = self._images_to_pixel_values_lst(
                image_list,
                min_dynamic_patch=min_dynamic_patch,
                max_dynamic_patch=max_dynamic_patch,
                dynamic_image_size=dynamic_image_size,
            )
            patch_counts = [len(tiles) for tiles in pixel_values_lst]
            image_inputs = {
                "pixel_values_flat": torch.cat(pixel_values_lst),
                "image_num_patches": torch.tensor(patch_counts),
            }

            for tiles in pixel_values_lst:
                num_patches = tiles.shape[0]
                image_repl = self.get_image_repl(
                    num_patches * self.num_image_token, num_patches
                )
                # Consume exactly one placeholder per prompt for this image.
                prompts = [
                    prompt.replace("<image>", image_repl.full, 1)
                    for prompt in prompts
                ]

        text_inputs = self.tokenizer(prompts)

        return BatchFeature(
            {**text_inputs, **image_inputs},
            tensor_type=return_tensors,
        )
|
|
|
|
|
|
class SkyworkR1VProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Skywork-R1V: processor access and size queries."""

    def get_hf_processor(self, **kwargs: object) -> SkyworkR1VProcessor:
        """Create a ``SkyworkR1VProcessor`` from the HF config and tokenizer."""
        hf_config = self.get_hf_config()
        tokenizer = self.get_tokenizer()
        return self.ctx.init_processor(
            SkyworkR1VProcessor,
            config=hf_config,
            tokenizer=tokenizer,
            **kwargs,
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality limits; ``None`` is given for images."""
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: SkyworkR1VProcessor | None,
    ) -> int:
        """
        Number of image tokens for an image of the given size.

        Args:
            image_width: Image width in pixels.
            image_height: Image height in pixels.
            processor: Processor to query; a default one is created when
                ``None``.
        """
        active_processor = (
            self.get_hf_processor() if processor is None else processor
        )

        return active_processor.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """
        Find the image size that yields the most image tokens.

        Every candidate grid is turned into a pixel size by scaling it with
        the tile size. The first size with the strictly largest token count
        is returned.

        Raises:
            ValueError: If no candidate yields a positive token count.
        """
        processor = self.get_hf_processor()

        tile_size = processor.image_size
        candidate_grids = processor.resolve_target_ratios()

        best_tokens = 0
        best_size = None
        for grid_w, grid_h in candidate_grids:
            width = tile_size * grid_w
            height = tile_size * grid_h

            tokens = self.get_num_image_tokens(
                image_width=width,
                image_height=height,
                processor=processor,
            )
            if tokens > best_tokens:
                best_tokens = tokens
                best_size = ImageSize(width=width, height=height)

        if best_tokens == 0 or best_size is None:
            raise ValueError("Cannot have a largest feature size of 0!")

        return best_size
|
|
|
|
|
|
class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[SkyworkR1VProcessingInfo]):
    """Builds dummy prompt text and dummy images for Skywork-R1V."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``<image>`` placeholder per requested image."""
        image_count = mm_counts.get("image", 0)
        return "<image>" * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """
        Build dummy images at the size that yields the most image tokens.

        Args:
            seq_len: Not used by this implementation.
            mm_counts: Number of items requested per modality.
            mm_options: Optional per-modality overrides; only the ``image``
                entry is read.

        Returns:
            A dict with the single key ``"image"``.
        """
        max_width, max_height = self.info.get_image_size_with_most_features()
        image_count = mm_counts.get("image", 0)

        overrides = mm_options.get("image") if mm_options else None

        dummy_images = self._get_dummy_images(
            width=max_width,
            height=max_height,
            num_images=image_count,
            overrides=overrides,
        )
        return {"image": dummy_images}
|
|
|
|
|
|
class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[SkyworkR1VProcessingInfo]):
    """Multimodal processor wiring ``SkyworkR1VProcessor`` into the framework."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the base processor call and attach ``image_token_id``."""
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        processor = self.info.get_hf_processor(**mm_kwargs)

        # Since there may be extra tokens in the feature placeholders,
        # we need to pass the image token ID to the model to select the
        # tokens to merge from the vision encoder outputs
        outputs["image_token_id"] = torch.tensor(processor.image_token_id)

        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output field maps onto image items."""
        patches_per_image = hf_inputs.get("image_num_patches", torch.empty(0))
        image_count = len(patches_per_image)

        return {
            "pixel_values_flat": MultiModalFieldConfig.flat_from_sizes(
                "image", patches_per_image
            ),
            "image_num_patches": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
            "image_token_id": MultiModalFieldConfig.shared("image", image_count),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """
        Build the prompt replacement that expands each ``<image>`` target.

        The replacement text for an item is computed lazily from the item's
        feature size (embeddings) or from its pixel size (raw images).
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        mm_data = out_mm_kwargs.get_data()
        if "image_num_patches" in mm_data:
            patches_tensor = mm_data["image_num_patches"]
            assert isinstance(patches_tensor, torch.Tensor)
            patches_per_image = patches_tensor.tolist()
        elif "image_embeds" in mm_data:
            # TODO: Use image size information in dictionary embedding inputs
            # to compute num_patches (similar to Qwen2-VL)
            patches_per_image = [None] * len(mm_data["image_embeds"])
        else:
            patches_per_image = []

        def build_replacement(item_idx: int):
            items = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(items, ImageEmbeddingItems):
                feature_size = items.get_feature_size(item_idx)
            else:
                size = items.get_image_size(item_idx)
                feature_size = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                    processor=processor,
                )

            num_patches = patches_per_image[item_idx]
            assert num_patches is None or isinstance(num_patches, int)

            return processor.get_image_repl(feature_size, num_patches)

        image_replacement = PromptReplacement(
            modality="image",
            target="<image>",
            replacement=build_replacement,
        )
        return [image_replacement]
|
|
|
|
|
|
@MULTIMODAL_REGISTRY.register_processor(
    SkyworkR1VMultiModalProcessor,
    info=SkyworkR1VProcessingInfo,
    dummy_inputs=SkyworkR1VDummyInputsBuilder,
)
class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
    """
    Skywork-R1V chat model: a vision encoder, an MLP projector (``mlp1``)
    and a language model.
    """

    merge_by_field_config = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """
        Return the prompt placeholder for a modality.

        Raises:
            ValueError: If ``modality`` does not start with ``"image"``.
        """
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")

        return "<image>"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """
        Build the vision model, language model and projector.

        Args:
            vllm_config: Supplies the HF config, the quantization config and
                the multimodal config.
            prefix: Prefix under which the submodules are registered.
        """
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.multimodal_config = model_config.multimodal_config
        self._patch_quant_config(config, quant_config)

        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        patches_per_side = image_size // patch_size
        self.num_image_token = int(
            patches_per_side**2 * (config.downsample_ratio**2)
        )
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        self.llm_arch_name = config.text_config.architectures[0]
        self.is_mono = self.llm_arch_name == "SkyworkLM2VEForCausalLM"
        self.vision_model = self._init_vision_model(
            config,
            quant_config=quant_config,
            is_mono=self.is_mono,
            prefix=maybe_prefix(prefix, "vision_model"),
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.mlp1 = self._init_mlp1(
            config,
            quant_config,
            prefix=maybe_prefix(prefix, "mlp1"),
        )

        # Filled in while parsing inputs / computing embeddings.
        self.img_context_token_id = None
        self.visual_token_mask = None
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _patch_quant_config(
        self, config: PretrainedConfig, quant_config: QuantizationConfig
    ):
        """
        Exclude the vision model from AWQ conversion when nothing is excluded.

        The AWQ models from OpenGVLab are missing `modules_to_not_convert`,
        so the entry is added back here. This only applies when the config
        is an ``AWQConfig``, its exclusion list is empty, and the text config
        carries a ``quantization_config``.
        """
        if not isinstance(quant_config, AWQConfig):
            return

        llm_quant_config = getattr(config.text_config, "quantization_config", None)
        if quant_config.modules_to_not_convert or llm_quant_config is None:
            return

        quant_config.modules_to_not_convert.append("vision_model")

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None,
        *,
        is_mono: bool,
        prefix: str,
    ):
        """
        Create the vision encoder.

        For the mono architecture an ``InternVisionPatchModel`` is returned.
        Otherwise an ``InternVisionModel`` is built, truncated to the layers
        needed to reach ``config.select_layer``.
        """
        if is_mono:
            return InternVisionPatchModel(config.vision_config)

        feature_layer = config.select_layer
        if feature_layer < 0:
            # Negative index counts back from the last hidden layer.
            layers_needed = (
                config.vision_config.num_hidden_layers + feature_layer + 1
            )
        else:
            layers_needed = feature_layer + 1

        return InternVisionModel(
            config.vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=layers_needed,
            prefix=prefix,
        )

    def _init_mlp1(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig,
        prefix: str = "",
    ) -> nn.Module:
        """
        Create the projector from vision features to the LLM hidden size.

        Layout: LayerNorm -> linear -> GELU -> linear. The input width is the
        ViT hidden size multiplied by ``int(1 / downsample_ratio) ** 2``.
        """
        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.text_config.hidden_size
        projector_in = vit_hidden_size * int(1 / self.downsample_ratio) ** 2

        layers = [
            nn.LayerNorm(projector_in),
            ReplicatedLinear(
                projector_in,
                llm_hidden_size,
                return_bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.1",
            ),
            nn.GELU(),
            ReplicatedLinear(
                llm_hidden_size,
                llm_hidden_size,
                return_bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.3",
            ),
        ]
        return nn.Sequential(*layers)

    def pixel_shuffle(self, x, scale_factor=0.5):
        """
        Rearrange a 4-D feature map, trading spatial size for channels.

        Both spatial dimensions are scaled by ``scale_factor`` and the last
        dimension is divided by ``scale_factor ** 2``. Unless ``ps_version``
        is ``"v1"``, the two spatial dimensions are swapped once more at the
        end.
        """
        n, w, h, c = x.size()
        scaled_h = int(h * scale_factor)
        scaled_w = int(w * scale_factor)

        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, scaled_h, int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(n, scaled_h, scaled_w, int(c / (scale_factor * scale_factor)))

        if self.ps_version != "v1":
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Encode pixel tiles and project them through ``mlp1``.

        The first token of the vision output is dropped, the rest is laid
        out as a square grid, pixel-shuffled and flattened again before the
        projection.
        """
        features = self.vision_model(pixel_values=pixel_values)
        features = features[:, 1:, :]

        side = int(features.shape[1] ** 0.5)
        features = features.reshape(features.shape[0], side, side, -1)
        features = self.pixel_shuffle(features, scale_factor=self.downsample_ratio)
        features = features.reshape(features.shape[0], -1, features.shape[-1])
        return self.mlp1(features)

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> SkyworkR1VImageInputs | None:
        """
        Turn raw keyword inputs into a validated image-input schema.

        Returns ``None`` when neither pixels nor embeddings are present.
        Embeddings take precedence over pixels. For pixel inputs the image
        token id is read from ``kwargs["image_token_id"]`` and remembered in
        ``self.img_context_token_id``.
        """
        pixel_values_flat = kwargs.pop("pixel_values_flat", None)
        image_num_patches = kwargs.pop("image_num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values_flat is None and image_embeds is None:
            return None

        if image_embeds is not None:
            return SkyworkR1VImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        token_id = kwargs["image_token_id"]
        if isinstance(token_id, torch.Tensor):
            # Collapse the tensor to its single distinct value.
            token_id = token_id.flatten().unique().item()

        assert isinstance(token_id, int)
        self.img_context_token_id = token_id

        if pixel_values_flat is not None:
            vision_config = self.config.vision_config
            return SkyworkR1VImagePixelInputs(
                type="pixel_values",
                pixel_values_flat=pixel_values_flat,
                num_patches=image_num_patches,
                resolve_bindings={
                    "h": vision_config.image_size,
                    "w": vision_config.image_size,
                },
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: SkyworkR1VImageInputs,
    ) -> torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor, ...]:
        """
        Produce image embeddings for a validated image input.

        Precomputed embeddings are returned unchanged. Pixel inputs are
        encoded, and the result is either returned as one batch (single
        image) or split into one tensor per image.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_model is not None

        embeds = self.extract_feature(image_input["pixel_values_flat"])

        patches_per_image = image_input["num_patches"]

        # Only one image in the current batch
        if len(patches_per_image) == 1:
            flat = embeds.view(-1, self.config.text_config.hidden_size)
            return flat.unsqueeze(0)

        # NOTE: Image embeddings are split into separate tensors for each image
        # by the size of each embedding.
        tokens_per_patch = embeds.shape[1]
        embeds = embeds.view(-1, self.config.text_config.hidden_size)
        split_sizes = [count * tokens_per_patch for count in patches_per_image]
        return embeds.split(split_sizes)

    def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
        """
        Record which input positions hold the image-context token.

        The mask is only computed for the mono architecture; otherwise it is
        cleared.
        """
        if not self.is_mono:
            self.visual_token_mask = None
            return

        is_image_token = input_ids == self.img_context_token_id
        self.visual_token_mask = is_image_token.reshape(-1, 1)

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model."""
        return self.language_model

    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return image embeddings for the given inputs, or ``[]`` if none."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """
        Embed the input ids, merging multimodal embeddings when provided.

        When non-empty multimodal embeddings are passed, the visual token
        mask is refreshed first. The actual embedding is delegated to the
        parent implementation.
        """
        has_mm_embeddings = (
            multimodal_embeddings is not None and len(multimodal_embeddings) > 0
        )
        if has_mm_embeddings:
            self._set_visual_token_mask(input_ids)

        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().get_input_embeddings(input_ids)

        return super().get_input_embeddings(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> IntermediateTensors:
        """
        Run the language model's inner ``model``.

        When intermediate tensors are supplied, ``input_ids`` and
        ``inputs_embeds`` are discarded. A pending visual token mask is
        forwarded once and then cleared.
        """
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None

        forward_kwargs = dict(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        # Only required if the model is mono-architecture
        if self.visual_token_mask is not None:
            forward_kwargs["visual_token_mask"] = self.visual_token_mask
            self.visual_token_mask = None

        return self.language_model.model(**forward_kwargs)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """
        Load weights, skipping entries under the listed prefixes.

        Returns:
            Whatever the ``AutoWeightsLoader`` reports as loaded.
        """
        ignored_prefixes = [
            "action_embed",
            "temporal_embed",
            "track_embed",
            "track_embed_decoder",
            "box_token",
            "cg_criterion",
            "cg_model",
            "loc_encoder",
            "loc_decoder",
            "sam",
            "temporal_token",
            "track_token",
        ]
        weights_loader = AutoWeightsLoader(self, skip_prefixes=ignored_prefixes)
        return weights_loader.load_weights(weights)
|