diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py
index 6a2328f950b84..fbc298b812498 100644
--- a/vllm/model_executor/models/llava_onevision.py
+++ b/vllm/model_executor/models/llava_onevision.py
@@ -25,7 +25,6 @@ from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
 from vllm.multimodal.processing import PromptReplacement, PromptUpdate
 from vllm.multimodal.profiling import ProcessorInputs
 from vllm.sequence import IntermediateTensors
-from vllm.utils import is_list_of
 
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@@ -44,7 +43,7 @@ class LlavaOnevisionVideoPixelInputs(TypedDict):
     type: Literal["pixel_values_videos"]
     pixel_values_videos: Union[torch.Tensor, list[torch.Tensor]]
     """
-    Shape: `(batch_size, num_videos, num_frames, num_channels, height, width)`
+    Shape: `(batch_size * num_videos, num_frames, num_channels, height, width)`
 
     Note that `num_videos` may be different for each batch, and 'num_frames'
    may be different for each video, in which case the data is passed as a
@@ -580,7 +579,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         return LlavaOnevisionVideoPixelInputs(
             type="pixel_values_videos",
-            pixel_values_videos=pixel_values_videos,
+            pixel_values_videos=flatten_bn(pixel_values_videos),
         )
 
     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
@@ -768,22 +767,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
             for i, patch_features_batch in enumerate(patch_embeddings)
         ]
 
-    def _add_image_newline(
-        self,
-        video_features: torch.Tensor,
-        videos: int = 1,
-        frames: int = 1,
-        strategy: str = "one_token",
-    ) -> torch.Tensor:
-        if strategy == "one_token":
-            video_features = video_features.reshape(
-                videos, frames * video_features.shape[1], -1)
-            image_newline = self.image_newline[None, None, :].repeat(
-                videos, 1, 1).to(video_features.device)
-            video_features = torch.cat((video_features, image_newline), dim=1)
-            return video_features
-        raise ValueError(f"Unexpected video newline strategy: {strategy}")
-
     def _video_pixels_to_features(
         self,
         vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
@@ -807,33 +790,43 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
         video_pixels = inputs["pixel_values_videos"]
 
         if isinstance(video_pixels, torch.Tensor):
-            b, num_videos, frames, c, h, w = video_pixels.shape
-            pixel_values = video_pixels.view(b * num_videos * frames, c, h, w)
-            stacked_embeddings = self._video_pixels_to_features(
-                self.vision_tower, pixel_values)
-            stacked_embeddings = self._add_image_newline(stacked_embeddings,
-                                                         videos=b * num_videos,
-                                                         frames=frames,
-                                                         strategy="one_token")
-            return stacked_embeddings
-        elif is_list_of(video_pixels, torch.Tensor):
-            stacked_embeddings = []
-            for video_pixel in video_pixels:
-                num_videos, frames, c, h, w = video_pixel.shape
-                pixel_values = video_pixel.view(num_videos * frames, c, h, w)
-                embeddings = self._video_pixels_to_features(
-                    self.vision_tower, pixel_values)
-                embeddings = self._add_image_newline(embeddings,
-                                                     videos=num_videos,
-                                                     frames=frames,
-                                                     strategy="one_token")
-                stacked_embeddings.append(embeddings)
-            return stacked_embeddings
-        else:
-            raise ValueError(
-                f"Unsupported type of video input {type(video_pixels)}")
+            total_videos, frames, c, h, w = video_pixels.shape
+            video_pixels_flat = video_pixels.view(total_videos * frames, c, h,
+                                                  w)
 
-    def apply_pooling(self, image_features, stride=2):
+            embeddings_flat = self._video_pixels_to_features(
+                self.vision_tower, video_pixels_flat)
+
+            embeddings_flat = embeddings_flat.reshape(
+                total_videos, frames * embeddings_flat.shape[1], -1)
+
+            image_newline = self.image_newline[None, None, :].expand(
+                total_videos, -1, -1)
+            return torch.cat((embeddings_flat, image_newline), dim=1)
+
+        frames_per_video = [len(video) for video in video_pixels]
+        video_pixels_flat = torch.cat(video_pixels)
+
+        embeddings_flat = self._video_pixels_to_features(
+            self.vision_tower, video_pixels_flat)
+
+        image_newline = self.image_newline[None, None, :]
+
+        return [
+            torch.cat(
+                (
+                    embeds.reshape(1, num_frame * embeddings_flat.shape[1],
+                                   -1),
+                    image_newline,
+                ),
+                dim=1,
+            ) for num_frame, embeds in zip(
+                frames_per_video,
+                torch.split(embeddings_flat, frames_per_video),
+            )
+        ]
+
+    def apply_pooling(self, image_features: torch.Tensor, stride: int = 2):
         vision_config = self.config.vision_config
         height = width = vision_config.image_size // vision_config.patch_size
         batch_frames, _, dim = image_features.shape
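
For reference, a minimal standalone sketch (not part of the patch) of the new list-input path in _process_video_pixels. It uses random tensors in place of the vision tower output; tokens_per_frame and hidden are illustrative values only, and image_newline here merely stands in for self.image_newline.

# Illustrative sketch only: fakes the vision tower output with random tensors.
# Shapes (tokens_per_frame, hidden) are assumed for the example, not taken
# from the model config.
import torch

tokens_per_frame, hidden = 4, 8
frames_per_video = [3, 5]  # two videos with different frame counts

# In the patched code this would be the vision tower applied to
# torch.cat(video_pixels); here we generate per-frame embeddings directly.
embeddings_flat = torch.randn(sum(frames_per_video), tokens_per_frame, hidden)
image_newline = torch.randn(hidden)[None, None, :]  # stand-in for self.image_newline

# Regroup the flattened per-frame embeddings back into one sequence per video
# and append the single newline token, mirroring the new return expression.
outputs = [
    torch.cat(
        (embeds.reshape(1, num_frame * embeddings_flat.shape[1], -1),
         image_newline),
        dim=1,
    ) for num_frame, embeds in zip(frames_per_video,
                                   torch.split(embeddings_flat,
                                               frames_per_video))
]

# Each video yields (1, num_frames * tokens_per_frame + 1, hidden).
assert [tuple(o.shape) for o in outputs] == [(1, 13, 8), (1, 21, 8)]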