# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Literal

import torch
import torch.nn as nn
from transformers import BatchFeature, LlavaNextVideoConfig, LlavaNextVideoProcessor

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
    ImageSize,
    MultiModalDataItems,
    VideoEmbeddingItems,
    VideoProcessorItems,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.collection_utils import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .llava import init_vision_tower_for_llava
from .siglip import SiglipVisionModel
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
from .vision import get_vision_encoder_info


class LlavaNextVideoPixelInputs(TensorSchema):
    """
    Dimensions:
        - bn: Batch size * number of videos
        - f: Number of frames
        - c: Number of channels (3)
        - h: Height of each frame
        - w: Width of each frame

    Note that `f` may be different for each batch, in which case the data is
    passed as a list instead of a batched tensor.

    Note that only one video input per batch is supported.
""" type: Literal["pixel_values_videos"] = "pixel_values_videos" pixel_values_videos: Annotated[ torch.Tensor | list[torch.Tensor], TensorShape("bn", "f", 3, "h", "w", dynamic_dims={"f"}), ] class LlavaNextVideoProcessingInfo(BaseProcessingInfo): def get_hf_config(self): return self.ctx.get_hf_config(LlavaNextVideoConfig) def get_vision_encoder_info(self): return get_vision_encoder_info(self.get_hf_config()) def get_hf_processor(self, **kwargs: object): return self.ctx.get_hf_processor(LlavaNextVideoProcessor, **kwargs) def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"video": 1} def get_image_size_with_most_features(self) -> ImageSize: vision_encoder_info = self.get_vision_encoder_info() width = height = vision_encoder_info.get_image_size() return ImageSize(width=width, height=height) def _get_num_frame_tokens( self, *, image_width: int, image_height: int, ) -> int: hf_config = self.get_hf_config() spatial_pool_stride = hf_config.spatial_pool_stride vision_encoder_info = self.get_vision_encoder_info() patch_grid_length = vision_encoder_info.get_patch_grid_length() pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride) return pooled_grid_length * pooled_grid_length def get_num_video_tokens( self, *, image_width: int, image_height: int, num_frames: int, ) -> int: num_frame_tokens = self._get_num_frame_tokens( image_width=image_width, image_height=image_height, ) return num_frame_tokens * num_frames def _get_max_video_frames(self, max_tokens: int) -> int: target_width, target_height = self.get_image_size_with_most_features() num_frames = 0 while True: next_num_frames = num_frames + 1 next_max_tokens = self.get_num_video_tokens( image_width=target_width, image_height=target_height, num_frames=next_num_frames, ) if next_max_tokens > max_tokens: break num_frames = next_num_frames return num_frames def get_num_frames_with_most_features( self, seq_len: int, mm_counts: Mapping[str, int], ) -> int: max_videos = mm_counts.get("video", 0) max_total_frames = self._get_max_video_frames(seq_len) return max(max_total_frames // max(max_videos, 1), 1) class LlavaNextVideoDummyInputsBuilder( BaseDummyInputsBuilder[LlavaNextVideoProcessingInfo] ): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_videos = mm_counts.get("video", 0) processor = self.info.get_hf_processor() video_token = processor.video_token return video_token * num_videos def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_videos = mm_counts.get("video", 0) target_width, target_height = self.info.get_image_size_with_most_features() target_num_frames = self.info.get_num_frames_with_most_features( seq_len, mm_counts ) video_overrides = mm_options.get("video") if mm_options else None return { "video": self._get_dummy_videos( width=target_width, height=target_height, num_frames=target_num_frames, num_videos=num_videos, overrides=video_overrides, ) } class LlavaNextVideoMultiModalProcessor( BaseMultiModalProcessor[LlavaNextVideoProcessingInfo] ): def _get_mm_fields_config( self, hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: return dict(pixel_values_videos=MultiModalFieldConfig.batched("video")) def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() video_token_id = 

        def get_replacement(item_idx: int):
            videos = mm_items.get_items(
                "video", (VideoEmbeddingItems, VideoProcessorItems)
            )

            if isinstance(videos, VideoEmbeddingItems):
                num_video_tokens = videos.get_feature_size(item_idx)
            else:
                image_size = videos.get_frame_size(item_idx)
                num_video_tokens = self.info.get_num_video_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    num_frames=videos.get_num_frames(item_idx),
                )

            # Repeat the video token so the prompt contains one placeholder
            # position per video feature.
            return [video_token_id] * num_video_tokens

        return [
            PromptReplacement(
                modality="video",
                target=[video_token_id],
                replacement=get_replacement,
            ),
        ]


# Adapted from transformers modeling_llava_next_video.py
class LlavaNextVideoPooler(nn.Module):
    def __init__(self, config: LlavaNextVideoConfig):
        super().__init__()

        mode = config.spatial_pool_mode
        stride = config.spatial_pool_stride
        image_size = config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.image_size = image_size // patch_size**2

        if mode == "average":
            self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride)
        elif mode == "max":
            self.pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
        else:
            # TODO: Support Conv2d pooling layer, need to load weights
            raise ValueError(
                f"Unknown pooling mode: {mode}. Expected [`average`, `max`]"
            )

    def forward(self, image_features: torch.Tensor):
        ori_width = int(
            math.sqrt(image_features.shape[1] * self.image_size // self.image_size)
        )
        ori_height = int(ori_width * self.image_size // self.image_size)

        # `image_features` is (batch, num_patches, dim); recover the square
        # patch grid so the features can be pooled spatially.
        batch_size, _, dim = image_features.shape
        image_features_spatial = image_features.view(
            batch_size, ori_height, ori_height, dim
        ).permute(0, 3, 1, 2)
        image_features_spatial = self.pool(image_features_spatial)

        return image_features_spatial.flatten(2).transpose(1, 2).contiguous()


class LlavaNextMultiModalProjector(nn.Module):
    def __init__(
        self,
        vision_hidden_size: int,
        text_hidden_size: int,
        projector_hidden_act: str,
        multimodal_projector_bias: bool,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(
            vision_hidden_size, text_hidden_size, bias=multimodal_projector_bias
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = nn.Linear(
            text_hidden_size, text_hidden_size, bias=multimodal_projector_bias
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


@MULTIMODAL_REGISTRY.register_processor(
    LlavaNextVideoMultiModalProcessor,
    info=LlavaNextVideoProcessingInfo,
    dummy_inputs=LlavaNextVideoDummyInputsBuilder,
)
class LlavaNextVideoForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsPP
):
    merge_by_field_config = True

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "model.image_newline": "image_newline",
            "lm_head.": "language_model.lm_head.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("image"):
            return "<image>"
        if modality.startswith("video"):
            return "<video>"