# vllm/model_executor/models/llava_next.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import abstractmethod
from collections.abc import Iterable, Mapping
from typing import Annotated, Final, Literal, Protocol, TypeAlias, TypeVar

import torch
import torch.nn as nn
from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor
from transformers.models.llava_next.modeling_llava_next import (
get_anyres_image_grid_shape,
unpad_image,
)

from vllm.config import VllmConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm.multimodal.parse import ImageSize
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .llava import (
BaseLlavaMultiModalProcessor,
BaseLlavaProcessingInfo,
LlavaDummyInputsBuilder,
LlavaLikeConfig,
LlavaMultiModalProjector,
init_vision_tower_for_llava,
)
from .siglip import SiglipVisionModel
from .utils import (
AutoWeightsLoader,
WeightsMapper,
init_vllm_registered_model,
maybe_prefix,
)
from .vision import get_num_selected_vision_tokens


class LlavaNextImagePixelInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- np: Number of patches + 1
- c: Number of channels (3)
- h: Height
- w: Width
Note that `num_patches` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
"""
type: Literal["pixel_values"] = "pixel_values"
pixel_values: Annotated[
torch.Tensor | list[torch.Tensor],
TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"}),
]
    # This should be in `(height, width)` format.
    image_sizes: Annotated[torch.Tensor | None, TensorShape("bn", 2)]
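
    # Illustrative shape example (values are assumptions, not fixed by this
    # schema): with two images whose anyres grids differ, `pixel_values` is
    # passed as a list rather than a stacked tensor, e.g.
    #   [tensor of shape (5, 3, 336, 336), tensor of shape (3, 3, 336, 336)]
    # while `image_sizes` is a (2, 2) tensor of (height, width) pairs. The 336
    # resolution here assumes a CLIP-336-style vision tower.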


class LlavaNextImageEmbeddingInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- ifs: Image feature size
- hs: Hidden size (must match language model backbone)
"""
type: Literal["image_embeds"] = "image_embeds"
data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]


LlavaNextImageInputs: TypeAlias = (
LlavaNextImagePixelInputs | LlavaNextImageEmbeddingInputs
)


class LlavaNextLikeConfig(LlavaLikeConfig, Protocol):
image_grid_pinpoints: Final[list[list[int]]]


class LlavaNextProcessingInfo(BaseLlavaProcessingInfo):
def get_hf_config(self) -> LlavaNextLikeConfig:
return self.ctx.get_hf_config(LlavaNextConfig)
def get_hf_processor(self, **kwargs: object):
hf_processor = self.ctx.get_hf_processor(LlavaNextProcessor, **kwargs)
# In case patch_size is omitted from `processor_config.json`
# e.g. for E5-V: https://huggingface.co/royokong/e5-v
if hf_processor.patch_size is None:
patch_size = self.get_vision_encoder_info().get_patch_size()
hf_processor.patch_size = patch_size
return hf_processor
# Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L113
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self.get_hf_config()
vision_encoder_info = self.get_vision_encoder_info()
base_feature_size = get_num_selected_vision_tokens(
vision_encoder_info.get_num_image_tokens(
image_width=image_width,
image_height=image_height,
),
hf_config.vision_feature_select_strategy,
)
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_size=(image_height, image_width),
grid_pinpoints=hf_config.image_grid_pinpoints,
patch_size=vision_encoder_info.get_image_size(),
)
(
unpadded_feature_size,
newline_feature_size,
) = self._get_num_unpadded_features(
original_height=image_height,
original_width=image_width,
npatches=vision_encoder_info.get_patch_grid_length(),
num_patch_height=num_patch_height,
num_patch_width=num_patch_width,
)
return unpadded_feature_size + newline_feature_size + base_feature_size
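
    # Worked example for the computation above (hedged; exact values depend on
    # the checkpoint): assuming a 336px CLIP tower (24x24 patches, 576 selected
    # tokens) and an anyres grid pinpoint of 672x672 chosen for a 672x672
    # image, the grid is 2x2, so unpadded features = 48 * 48 = 2304 and newline
    # features = 48, giving 2304 + 48 + 576 = 2928 image tokens.
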
# Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
def _get_num_unpadded_features(
self,
*,
original_height: int,
original_width: int,
npatches: int,
num_patch_height: int,
num_patch_width: int,
) -> tuple[int, int]:
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width
aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
if aspect_ratio > current_aspect_ratio:
new_height = int(
round(original_height * (current_width / original_width), 7)
)
padding = (current_height - new_height) // 2
current_height = current_height - (2 * padding)
else:
new_width = int(
round(original_width * (current_height / original_height), 7)
)
padding = (current_width - new_width) // 2
current_width = current_width - (2 * padding)
unpadded_features = current_height * current_width
newline_features = current_height
return (unpadded_features, newline_features)
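
    # Worked example for the unpadding arithmetic above (hypothetical inputs):
    # original_width=800, original_height=300, npatches=24, and a 1x2 grid give
    # current (height, width) = (24, 48). Since 800/300 > 48/24, the height is
    # cropped: new_height = 18, padding = 3, current_height = 18, and the
    # method returns (18 * 48, 18) = (864, 18).
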
def get_image_size_with_most_features(self) -> ImageSize:
hf_config = self.get_hf_config()
largest_feature_size, largest_feature_pinpoint = 0, None
for height, width in hf_config.image_grid_pinpoints:
feat_size = self.get_num_image_tokens(
image_width=width, image_height=height
)
if feat_size > largest_feature_size:
largest_feature_size = feat_size
largest_feature_pinpoint = ImageSize(width=width, height=height)
if largest_feature_size == 0 or largest_feature_pinpoint is None:
raise ValueError("Cannot have a largest feature size of 0!")
return largest_feature_pinpoint


_I = TypeVar("_I", bound=LlavaNextProcessingInfo)


class BaseLlavaNextMultiModalProcessor(BaseLlavaMultiModalProcessor[_I]):
# Copied from BaseMultiModalProcessor
@abstractmethod
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
raise NotImplementedError


class LlavaNextMultiModalProcessor(
BaseLlavaNextMultiModalProcessor[LlavaNextProcessingInfo]
):
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_sizes=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
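
    # Note on the config above (descriptive): `MultiModalFieldConfig.batched("image")`
    # marks each of these processor outputs as having one leading slice per
    # input image, so vLLM can split and regroup them per image when batching
    # requests.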


@MULTIMODAL_REGISTRY.register_processor(
LlavaNextMultiModalProcessor,
info=LlavaNextProcessingInfo,
dummy_inputs=LlavaDummyInputsBuilder,
)
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# mapping for new names in checkpoint saved after transformers v4.52
"model.language_model.": "language_model.model.",
"model.vision_tower.": "vision_tower.",
"model.multi_modal_projector.": "multi_modal_projector.",
"model.image_newline": "image_newline",
"lm_head.": "language_model.lm_head.",
}
)
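
    # Example of the remapping above (the full parameter name is illustrative):
    #   "model.language_model.layers.0.self_attn.q_proj.weight"
    #       -> "language_model.model.layers.0.self_attn.q_proj.weight"
    #   "lm_head.weight" -> "language_model.lm_head.weight"
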
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
return "<image>"
raise ValueError("Only image modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
vision_feature_layer = config.vision_feature_layer
# Determine the layer up to which we will initialize the vision tower
if isinstance(vision_feature_layer, int):
vision_hidden_size = config.vision_config.hidden_size
self.select_layers = None
        # Used for multimodal Granite models to control encoder outputs
elif isinstance(vision_feature_layer, (list, tuple)):
vision_hidden_size = config.vision_config.hidden_size * len(
vision_feature_layer
)
self.select_layers = vision_feature_layer
else:
raise TypeError(
f"vision_layer_feature type: {type(vision_feature_layer)}"
" is not supported"
)
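
        # For example (values are illustrative): an integer setting such as
        # vision_feature_layer = -2 keeps select_layers=None and the plain
        # vision hidden size, whereas a list such as [-24, -2] (as in some
        # Granite multimodal configs) doubles the projector input width and
        # sets select_layers accordingly.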
self.config = config
self.multimodal_config = multimodal_config
        # TODO: Optionally initialize this to support embeddings.
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"),
)
self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size))
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=vision_hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act,
multimodal_projector_bias=config.multimodal_projector_bias,
)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
def _parse_and_validate_image_input(
self, **kwargs: object
) -> LlavaNextImageInputs | None:
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None and image_embeds is None:
return None
if pixel_values is not None:
expected_h = expected_w = self.config.vision_config.image_size
return LlavaNextImagePixelInputs(
type="pixel_values",
pixel_values=pixel_values,
image_sizes=image_sizes,
resolve_bindings={
"h": expected_h,
"w": expected_w,
},
)
if image_embeds is not None:
return LlavaNextImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)
raise AssertionError("This line should be unreachable.")
def _image_pixels_to_features(
self,
vision_tower: CLIPVisionModel | SiglipVisionModel,
pixel_values: torch.Tensor,
) -> torch.Tensor:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
return vision_tower(
pixel_values,
select_layers=self.select_layers,
feature_select_strategy=self.config.vision_feature_select_strategy,
)
# Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
def _merge_image_patch_embeddings(
self, image_size: torch.Tensor, patch_embeddings: torch.Tensor, *, strategy: str
) -> torch.Tensor:
if strategy == "flat":
return patch_embeddings.flatten(0, 1)
if strategy.startswith("spatial"):
height = width = (
self.config.vision_config.image_size
// self.config.vision_config.patch_size
)
base_patch_embeds = patch_embeddings[0]
if height * width != base_patch_embeds.shape[0]:
raise ValueError(
"The number of patches is not consistent with the image size."
)
if patch_embeddings.shape[0] > 1:
other_patch_embeds = patch_embeddings[1:]
# Move to CPU to avoid floating-point errors
orig_height, orig_width = image_size.tolist()
# image_aspect_ratio == "anyres"
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
(orig_height, orig_width),
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
num_patches = num_patch_height * num_patch_width
# Image patches might be padded for batch processing
other_patch_embeds = other_patch_embeds[:num_patches].view(
num_patch_height, num_patch_width, height, width, -1
)
if "unpad" in strategy:
other_patch_embeds = (
other_patch_embeds.permute(4, 0, 2, 1, 3)
.contiguous()
.flatten(1, 2)
.flatten(2, 3)
)
other_patch_embeds = unpad_image(
other_patch_embeds, (orig_height, orig_width)
)
other_patch_embeds = torch.cat(
(
other_patch_embeds,
self.image_newline[:, None, None]
.expand(*other_patch_embeds.shape[:-1], 1)
.to(other_patch_embeds.device),
),
dim=-1,
)
other_patch_embeds = other_patch_embeds.flatten(1, 2).transpose(
0, 1
)
else:
other_patch_embeds = (
other_patch_embeds.permute(0, 2, 1, 3, 4)
.contiguous()
.flatten(0, 3)
)
merged_patch_embeddings = torch.cat(
(base_patch_embeds, other_patch_embeds), dim=0
)
else:
if "unpad" in strategy:
merged_patch_embeddings = torch.cat(
(
base_patch_embeds,
self.image_newline[None].to(base_patch_embeds.device),
),
dim=0,
)
else:
merged_patch_embeddings = base_patch_embeds
return merged_patch_embeddings
raise ValueError(f"Unexpected patch merge strategy: {strategy}")
def _process_image_pixels(
self,
inputs: LlavaNextImagePixelInputs,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
assert self.vision_tower is not None
pixel_values = inputs["pixel_values"]
if isinstance(pixel_values, torch.Tensor):
b, num_patches, c, h, w = pixel_values.shape
stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
stacked_image_features = self._image_pixels_to_features(
self.vision_tower, stacked_pixel_values
)
stacked_patch_embeddings = self.multi_modal_projector(
stacked_image_features
)
return stacked_patch_embeddings.view(
b, num_patches, *stacked_patch_embeddings.shape[1:]
)
num_patches_per_batch = [v.shape[0] for v in pixel_values]
stacked_pixel_values = torch.cat(pixel_values)
stacked_image_features = self._image_pixels_to_features(
self.vision_tower, stacked_pixel_values
)
return torch.split(
self.multi_modal_projector(stacked_image_features), num_patches_per_batch
)
def _process_image_input(
self,
image_input: LlavaNextImageInputs,
) -> torch.Tensor | list[torch.Tensor]:
if image_input["type"] == "image_embeds":
return [image_input["data"]]
patch_embeddings = self._process_image_pixels(image_input)
image_sizes = image_input.get("image_sizes")
if image_sizes is None:
            batch_size = len(image_input["pixel_values"])
vision_config = self.config.vision_config
default_height = default_width = vision_config.image_size
image_sizes = torch.as_tensor(
[[default_height, default_width] for _ in range(batch_size)]
)
return [
self._merge_image_patch_embeddings(
image_sizes[i], patch_features_batch, strategy="spatial_unpad"
)
for i, patch_features_batch in enumerate(patch_embeddings)
]
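
    # In the pixel path, `_process_image_input` above returns one tensor per
    # image with shape (num_image_tokens_i, text_hidden_size); in the embedding
    # path, the precomputed batched tensor is returned wrapped in a
    # single-element list.
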
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor:
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().embed_input_ids(input_ids)
return super().embed_input_ids(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> torch.Tensor | IntermediateTensors:
"""Run forward pass for LlaVA-NeXT.
One key thing to understand is the `input_ids` already accounts for the
positions of the to-be-inserted image embeddings.
Concretely, consider a text prompt:
`"A chat between a curious human and an artificial intelligence
assistant. The assistant gives helpful, detailed, and polite answers to
the human's questions.
USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.
Tokenizer outputs:
`[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
9047, 13566, 29901]`.
To reserve space in KV cache, we have to insert placeholder tokens
before they are inputted to the model, so the input processor prepends
additional image tokens (denoted as `32000`), resulting in:
`[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
319, 1799, 9047, 13566, 29901]`.
Unlike in LLaVA-1.5, the number of image tokens inputted to the language
model depends on the original size of the input image. Including the
original image token in the input, the required number of image tokens
is given by [`LlavaNextProcessingInfo.get_num_image_tokens`][vllm.\
model_executor.models.llava_next.LlavaNextProcessingInfo.get_num_image_tokens].
This way, the `positions` and `attn_metadata` are consistent
with the `input_ids`.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
positions: Position indices for the input tokens.
intermediate_tensors: Intermediate tensors from prior forward pass.
inputs_embeds: Optional tensor of input embeddings.
Info:
[`LlavaNextImageInputs`][vllm.model_executor.models.llava_next.LlavaNextImageInputs]
"""
if intermediate_tensors is not None:
inputs_embeds = None
hidden_states = self.language_model.model(
input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
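

# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of this module). It assumes
# the `llava-hf/llava-v1.6-vicuna-7b-hf` checkpoint and vLLM's offline `LLM`
# API; the prompt template mirrors the example in `forward`'s docstring.
#
#     from vllm import LLM, SamplingParams
#     from PIL import Image
#
#     llm = LLM(model="llava-hf/llava-v1.6-vicuna-7b-hf")
#     image = Image.open("example.jpg")
#     prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
#     outputs = llm.generate(
#         {"prompt": prompt, "multi_modal_data": {"image": image}},
#         SamplingParams(max_tokens=64),
#     )
#     print(outputs[0].outputs[0].text)
# -----------------------------------------------------------------------------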