# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2024 The vLLM team.
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal, TypeAlias
import regex as re
import torch
import torch.nn as nn
from transformers import (
BatchFeature,
CLIPVisionConfig,
PretrainedConfig,
ProcessorMixin,
)
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
ImageEmbeddingItems,
ImageProcessorItems,
ImageSize,
MultiModalDataItems,
)
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
MultiModalPromptUpdates,
PlaceholderFeaturesInfo,
PromptReplacement,
PromptUpdate,
ResolvedPromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .clip import CLIPVisionModel
from .interfaces import (
MultiModalEmbeddings,
SupportsMultiModal,
SupportsPP,
SupportsQuant,
)
from .utils import (
AutoWeightsLoader,
WeightsMapper,
_merge_multimodal_embeddings,
init_vllm_registered_model,
maybe_prefix,
)
logger = init_logger(__name__)
# The following values cannot be found in the HF config, so they are hard-coded here.
_IMAGE_TOKEN_ID = 32044
CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(
dropout=0.0,
hidden_act="quick_gelu",
hidden_size=1024,
image_size=336,
intermediate_size=4096,
num_attention_heads=16,
num_channels=3,
num_hidden_layers=24,
patch_size=14,
projection_dim=768,
)
def _init_img_processor(
hf_config: PretrainedConfig,
quant_config: QuantizationConfig | None,
prefix: str = "",
) -> CLIPVisionModel:
clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
layer_idx = hf_config.img_processor.get("layer_idx", -2)
# Initialize the CLIP only up to the required feature layer
if layer_idx < 0:
num_hidden_layers = clip_config.num_hidden_layers + layer_idx + 1
else:
num_hidden_layers = layer_idx + 1
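# e.g. the default layer_idx of -2 with 24 CLIP layers keeps 23 hidden layers,
# so the truncated model's output matches hidden_states[-2] of the full model.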
img_processor = CLIPVisionModel(
clip_config,
quant_config,
num_hidden_layers_override=num_hidden_layers,
prefix=prefix,
)
return img_processor
class Phi3VImagePixelInputs(TensorSchema):
"""
Dimensions:
- b: Batch size
- n: Number of images
- p: Number of patches
- h: Height of each patch
- w: Width of each patch
"""
type: Literal["pixel_values"] = "pixel_values"
# Supports either a stacked tensor or a list of (p, 3, h, w) tensors
pixel_values: Annotated[
torch.Tensor | list[torch.Tensor],
TensorShape(
"bn", "p", 3, "h", "w", dynamic_dims={"p"}
), # 'p' may vary across items
]
# Stacked tensor with height and width for each image
image_sizes: Annotated[torch.Tensor | None, TensorShape("bn", 2)]
class Phi3VImageEmbeddingInputs(TensorSchema):
"""
Dimensions:
- b: Batch size
- n: Number of images
- f: Image feature size (e.g., number of tokens per image)
- h: Hidden size (must match language model backbone)
"""
type: Literal["image_embeds"] = "image_embeds"
data: Annotated[
torch.Tensor | list[torch.Tensor],
TensorShape("bn", "f", "h"),
]
Phi3VImageInputs: TypeAlias = Phi3VImagePixelInputs | Phi3VImageEmbeddingInputs
class Phi3ImageEmbeddingBase(nn.Module):
def __init__(self) -> None:
super().__init__()
self.layer_idx: int
self.type_feature: str
self.img_processor: CLIPVisionModel
def get_img_features(self, img_embeds: torch.FloatTensor) -> torch.FloatTensor:
TYPE_FEATURE = self.type_feature
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the img_processor
img_feature = self.img_processor(img_embeds)
if TYPE_FEATURE == "patch":
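# Drop the leading CLS token; with the 336px / patch-14 CLIP backbone this
# leaves 24 * 24 = 576 patch tokens per crop.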
patch_feature = img_feature[:, 1:]
return patch_feature
if TYPE_FEATURE == "cls_patch":
return img_feature
raise NotImplementedError
# adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
"""Phi3 Image embedding with HD transform."""
def __init__(
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None,
prefix: str = "",
) -> None:
super().__init__()
# The hidden size may be exposed as either `n_embd` or `hidden_size`
hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size
self.img_processor = _init_img_processor(
config, quant_config, prefix=f"{prefix}.img_processor"
)
image_dim_out = config.img_processor["image_dim_out"]
self.num_img_tokens = config.img_processor["num_img_tokens"]
self.image_dim_out = image_dim_out
# glb_GN and sub_GN for the HD transform; they serve as separator tokens
# (sub_GN as a per-row newline, glb_GN between sub and global features)
self.use_hd_transform = config.embd_layer.get("use_hd_transform", False)
self.with_learnable_separator = config.embd_layer.get(
"with_learnable_separator", False
)
self.hd_transform_order = config.embd_layer.get("hd_transform_order", "glb_sub")
# use_hd_transform and with_learnable_separator must both be enabled;
# only the HD-transform path is supported here
assert self.use_hd_transform and self.with_learnable_separator
# image_dim_out * 4: the 2x2 spatial merge folds four neighboring patches
# into the channel dimension
self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4]))
self.sub_GN = nn.Parameter(torch.empty([1, 1, 1, self.image_dim_out * 4]))
dim_projection = hidden_size
depth = 2
layers = [nn.Linear(image_dim_out * 4, dim_projection)]
for _ in range(1, depth):
layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)])
self.img_projection = nn.Sequential(*layers)
self.type_feature = config.img_processor.get("type_feature", "patch")
def forward(
self, pixel_values: torch.FloatTensor, image_sizes: torch.Tensor
) -> torch.FloatTensor:
"""
Process images and return vision embeddings.
pixel_values: (num_images, num_crops, c, h, w)
output: (num_images, num_img_tokens, hidden_size)
"""
num_images, num_crops, c, h, w = pixel_values.shape
pixel_values = pixel_values.flatten(0, 1)
img_features = self.get_img_features(pixel_values)
img_features = img_features.reshape(
num_images, num_crops, -1, self.image_dim_out
)
image_features_proj = self.hd_feature_transform(img_features, image_sizes)
return image_features_proj
def hd_feature_transform(self, image_features, image_sizes):
"""
image_features: (num_images, num_crops+1, 24*24, 1024)
"""
assert self.hd_transform_order == "sub_glb", (
f"hd_transform_order `{self.hd_transform_order}` not implemented"
)
if isinstance(self.img_projection, nn.Sequential):
target_device = self.img_projection[0].bias.device
target_dtype = self.img_projection[0].bias.dtype
else: # It's a single nn.Linear layer
target_device = self.img_projection.bias.device
target_dtype = self.img_projection.bias.dtype
global_image_features = image_features[:, 0] # (num_images, 24*24, 1024)
# global feature can be viewed as a special HD case with num_crops 1x1
global_image_features_hd = self.reshape_hd_patches_2x2merge(
global_image_features, 1, 1
)
global_image_features_hd_newline = self.add_image_newline(
global_image_features_hd
)
batch_image_features_proj = []
# Loop over each image because image sizes differ
# (the patch arrangement is different for each image)
for i, img_size in enumerate(image_sizes):
h, w = img_size
h_crop = h // 336
w_crop = w // 336
num_crops = h_crop * w_crop
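# image_sizes holds the padded HD resolution, so each 336x336 tile of the
# padded image is one sub-crop (plus the single global crop at index 0).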
# NOTE: the crop dimension of image_features is padded, so only the
# first num_crops entries are real crops for this image
# (num_crops, 24*24, 1024)
sub_image_features = image_features[i, 1 : 1 + num_crops]
sub_image_features_hd = self.reshape_hd_patches_2x2merge(
sub_image_features, h_crop, w_crop
)
sub_image_features_hd_newline = self.add_image_newline(
sub_image_features_hd
)
# [sub features, separator, global features]
image_embeddings = torch.cat(
[
sub_image_features_hd_newline.squeeze(
0
), # (h_crop*12*(w_crop*12+1), 4096)
self.glb_GN.squeeze(0),
global_image_features_hd_newline[i],
]
)
img_proj = self.img_projection(
image_embeddings.to(target_device, target_dtype)
)
batch_image_features_proj.append(img_proj)
return batch_image_features_proj
def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
"""
image_features: (num_images*num_crops, 24*24, 1024)
output: (num_images, h_crop*12, w_crop*12, 4096)
where h_crop*w_crop == num_crops
"""
N, L, C = image_features.shape
assert L == 576 and C == 1024 and N % (h_crop * w_crop) == 0
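# L = 576 = 24 * 24 patch tokens and C = 1024 are fixed by the
# CLIP-ViT-L/14 (336px) backbone configured above.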
num_images = N // (h_crop * w_crop)
H = int(L**0.5)
image_features_hd = (
image_features.reshape(N, H, H, C) # N, 24, 24, 1024
.reshape(N, H // 2, 2, H // 2, 2, C) # N, 12, 2, 12, 2, 1024
.permute(0, 1, 3, 2, 4, 5) # N, 12, 12, 2, 2, 1024
.reshape(N, -1, 4 * C) # N, 144, 4096
.reshape(
num_images, h_crop, w_crop, H // 2, H // 2, -1
) # n_img, h_crop, w_crop, 12, 12, 4096
.permute(0, 1, 3, 2, 4, 5) # n_img, h_crop, 12, w_crop, 12, 4096
.reshape(
num_images, h_crop * H // 2, w_crop * H // 2, 4 * C
) # n_img, h_crop*12, w_crop*12, 4096
)
return image_features_hd
def add_image_newline(self, image_features_hd):
"""
image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
"""
num_images, h, w, hid_dim = image_features_hd.shape
# add the newline token to the HD image feature patches
newline_embeddings = self.sub_GN.expand(
num_images, h, -1, -1
) # (n_img, h, 1, hid_dim)
image_features_hd_newline = torch.cat(
[image_features_hd, newline_embeddings], dim=2
).reshape(num_images, -1, hid_dim)
return image_features_hd_newline
class Phi3VProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None}
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
processor: ProcessorMixin | None = None,
) -> int:
if processor is None:
processor = self.get_hf_processor()
return processor.calc_num_image_tokens_from_image_size( # type: ignore
width=image_width,
height=image_height,
)
def get_image_size_with_most_features(self) -> ImageSize:
# Results in the maximum possible feature size (h:w = 160:1)
return ImageSize(height=8000, width=50)
class Phi3VDummyInputsBuilder(BaseDummyInputsBuilder[Phi3VProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
hf_processor = self.info.get_hf_processor()
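# hf_processor.img_tokens is expected to be the list of image tag strings
# ("<|image_1|>", "<|image_2|>", ...).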
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
return "".join(image_tokens[:num_images])
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
target_width, target_height = self.info.get_image_size_with_most_features()
image_overrides = mm_options.get("image") if mm_options else None
return {
"image": self._get_dummy_images(
width=target_width,
height=target_height,
num_images=num_images,
overrides=image_overrides,
)
}
class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
input_ids = processed_outputs["input_ids"]
assert isinstance(input_ids, torch.Tensor)
# The Phi3v HF processor inserts negative values (-1, -2, etc.) as image
# placeholders in the prompt token IDs, which causes an OverflowError when
# decoding them. Replace them with the image token ID up front.
input_ids.masked_fill_(input_ids < 0, _IMAGE_TOKEN_ID)
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_sizes=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
def get_replacement_phi3v(item_idx: int):
images = mm_items.get_items(
"image", (ImageEmbeddingItems, ImageProcessorItems)
)
if isinstance(images, ImageEmbeddingItems):
num_image_tokens = images.get_feature_size(item_idx)
else:
image_size = images.get_image_size(item_idx)
num_image_tokens = self.info.get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
processor=hf_processor,
)
return [_IMAGE_TOKEN_ID] * num_image_tokens
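# Each <|image_i|> tag (image_tokens[item_idx]) is replaced by the computed
# number of image-placeholder token IDs.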
return [
PromptReplacement(
modality="image",
target=image_tokens.__getitem__,
replacement=get_replacement_phi3v,
)
]
def _recompute_cached_prompt_update(
self,
cached_update: ResolvedPromptUpdate,
new_item_idx: int,
) -> ResolvedPromptUpdate:
new_update = super()._recompute_cached_prompt_update(
cached_update,
new_item_idx,
)
if cached_update.modality == "image":
hf_processor = self.info.get_hf_processor()
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
new_update = new_update.with_target(image_tokens[new_item_idx])
return new_update
def _apply_prompt_updates(
self,
token_ids: list[int],
mm_prompt_updates: MultiModalPromptUpdates,
) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
# Align with the HF processor's behavior when images are present
if len(mm_prompt_updates):
tokenizer = self.info.get_tokenizer()
# To decode token_ids back to the original text, we need to:
# 1. remove the leading BOS token
# 2. remove the space that the tokenizer inserts after each
#    special token
if len(token_ids) and token_ids[0] == tokenizer.bos_token_id:
token_ids = token_ids[1:]
text = tokenizer.decode(token_ids)
for special_tokens in tokenizer.special_tokens_map.values():
if isinstance(special_tokens, str):
text = text.replace(f"{special_tokens} ", special_tokens)
elif isinstance(special_tokens, list):
for special_token in special_tokens:
text = text.replace(f"{special_token} ", special_token)
# Reproduce the HF processor's tokenization behavior
# https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/64f88b6/processing_phi3_v.py#L407
pattern = r"<\|image_\d+\|>"
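# e.g. "a <|image_1|> b" is split into text chunks ["a ", " b"] and image
# tags ["<|image_1|>"], which are re-tokenized and interleaved below.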
prompt_chunks = [
tokenizer(chunk).input_ids for chunk in re.split(pattern, text)
]
image_tags = [
tokenizer(chunk, add_special_tokens=False).input_ids
for chunk in re.findall(pattern, text)
]
if len(prompt_chunks) > len(image_tags):
image_tags.append([])
token_ids = [
e
for sublist in zip(prompt_chunks, image_tags)
for ele in sublist
for e in ele
]
token_ids, placeholders = super()._apply_prompt_updates(
token_ids=token_ids,
mm_prompt_updates=mm_prompt_updates,
)
# Keep the behavior in line with HF processor
if len(mm_prompt_updates) and (
token_ids[:2] == tokenizer.encode("<s> <|image|>", add_special_tokens=False)
):
token_ids = [token_ids[0], *token_ids[2:]]
placeholders = {
modality: [
PlaceholderFeaturesInfo(
modality=p.modality,
item_idx=p.item_idx,
start_idx=p.start_idx - 1,
tokens=p.tokens,
is_embed=p.is_embed,
)
for p in ps
]
for modality, ps in placeholders.items()
}
return token_ids, placeholders
@MULTIMODAL_REGISTRY.register_processor(
Phi3VMultiModalProcessor,
info=Phi3VProcessingInfo,
dummy_inputs=Phi3VDummyInputsBuilder,
)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant):
merge_by_field_config = True
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"model.vision_embed_tokens.wte": "embed_tokens",
"model.vision_embed_tokens.": "vision_embed_tokens.",
"lm_head.": "language_model.lm_head.",
"model.": "language_model.model.",
}
)
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
return f"<|image_{i}|>"
raise ValueError("Only image modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
self.image_token_id = _IMAGE_TOKEN_ID
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "model.embed_tokens"),
)
# TODO: Optionally initialize this to support input embeddings.
self.vision_embed_tokens = Phi3HDImageEmbedding(
config,
self.quant_config,
prefix=maybe_prefix(prefix, "model.vision_embed_tokens"),
)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
# The prefix is intentionally empty because the default prefix of
# LlamaForCausalLM is "model"
prefix="",
# We don't initialize vLLM's LlamaForCausalLM directly so that the
# embedding wrapper can be applied automatically when this model is
# initialized as an embedding model
architectures=["LlamaForCausalLM"],
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
def _parse_and_validate_image_input(
self, **kwargs: object
) -> Phi3VImageInputs | None:
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None and image_embeds is None:
return None
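# Each pixel_values item is expected to be a (num_crops, 3, h, w) tensor;
# resolve_bindings pins h and w to the 336px CLIP input size for validation.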
if pixel_values is not None:
return Phi3VImagePixelInputs(
type="pixel_values",
pixel_values=pixel_values,
image_sizes=image_sizes,
resolve_bindings={
"h": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
"w": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
},
)
if image_embeds is not None:
return Phi3VImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)
raise AssertionError("This line should be unreachable.")
def _process_image_input(
self,
image_input: Phi3VImageInputs,
) -> torch.Tensor:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_embed_tokens is not None
image_embeds = self.vision_embed_tokens(
image_input["pixel_values"], image_input["image_sizes"]
)
return image_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self._get_text_embeddings(
input_ids,
self.embed_tokens,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds
if is_multimodal is None:
raise ValueError(
"`get_input_embeddings` now requires `is_multimodal` arg, "
"please update your model runner according to "
"https://github.com/vllm-project/vllm/pull/16229."
)
return _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
):
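# On non-first pipeline-parallel ranks, hidden states arrive via
# intermediate_tensors, so any provided inputs_embeds are discarded.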
if intermediate_tensors is not None:
inputs_embeds = None
hidden_states = self.language_model.model(
input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
autoloaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
# The HF config doesn't specify whether the embeddings are tied,
# so we detect it by checking whether the weight was loaded
if "embed_tokens.weight" not in autoloaded_weights:
self.embed_tokens = self.language_model.model.embed_tokens
autoloaded_weights.add("embed_tokens.weight")
return autoloaded_weights