mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-26 10:14:29 +08:00
Migrate LlavaNextVideoPixelInputs to TensorSchema (#21843)
Signed-off-by: Benji Beck <benjibeck@meta.com>
This commit is contained in:
parent
d1af8b7be9
commit
a554991748
@ -3,7 +3,7 @@
|
||||
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Literal, Optional, TypedDict, Union
|
||||
from typing import Annotated, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -25,6 +25,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import is_list_of
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .llava import init_vision_tower_for_llava
|
||||
@ -35,17 +36,25 @@ from .utils import (AutoWeightsLoader, WeightsMapper,
|
||||
from .vision import get_vision_encoder_info
|
||||
|
||||
|
||||
class LlavaNextVideoPixelInputs(TypedDict):
|
||||
type: Literal["pixel_values_videos"]
|
||||
data: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
Shape: `(batch_size, num_frames, num_channels, height, width)`
|
||||
class LlavaNextVideoPixelInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- bs: Batch size
|
||||
- nv: Number of videos
|
||||
- nf: Number of frames
|
||||
- nc: Number of channels (3)
|
||||
- h: Height of each frame
|
||||
- w: Width of each frame
|
||||
|
||||
Note that `num_frames` may be different for each batch, in which case
|
||||
the data is passed as a list instead of a batched tensor.
|
||||
|
||||
Note that it only supports one video input for one batch.
|
||||
"""
|
||||
type: Literal["pixel_values_videos"] = "pixel_values_videos"
|
||||
|
||||
data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
|
||||
TensorShape("bs", "nv", "nf", 3, "h", "w")]
|
||||
|
||||
|
||||
class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
|
||||
@ -320,27 +329,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.model.make_empty_intermediate_tensors)
|
||||
|
||||
def _validate_video_pixel_values(
|
||||
self, data: Union[torch.Tensor, list[torch.Tensor]]
|
||||
) -> Union[torch.Tensor, list[torch.Tensor]]:
|
||||
|
||||
h = w = self.config.vision_config.image_size
|
||||
expected_dims = (3, h, w)
|
||||
|
||||
def _validate_shape(d: torch.Tensor):
|
||||
actual_dims = tuple(d.shape[2:])
|
||||
|
||||
if actual_dims != expected_dims:
|
||||
expected_expr = ("num_frames", *map(str, expected_dims))
|
||||
raise ValueError(
|
||||
"The expected shape of pixel values in each video frame "
|
||||
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
|
||||
|
||||
for d in data:
|
||||
_validate_shape(d)
|
||||
|
||||
return data
|
||||
|
||||
def _parse_and_validate_video_input(
|
||||
self, **kwargs: object) -> Optional[LlavaNextVideoPixelInputs]:
|
||||
"""
|
||||
@ -355,14 +343,13 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
if pixel_values_videos is None:
|
||||
return None
|
||||
|
||||
if not isinstance(pixel_values_videos, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel_values_videos. "
|
||||
f"Got type: {type(pixel_values_videos)}")
|
||||
|
||||
return LlavaNextVideoPixelInputs(
|
||||
type="pixel_values_videos",
|
||||
data=pixel_values_videos,
|
||||
)
|
||||
expected_h = expected_w = self.config.vision_config.image_size
|
||||
return LlavaNextVideoPixelInputs(type="pixel_values_videos",
|
||||
data=pixel_values_videos,
|
||||
resolve_bindings={
|
||||
"h": expected_h,
|
||||
"w": expected_w,
|
||||
})
|
||||
|
||||
def _select_image_features(self, image_features: torch.Tensor, *,
|
||||
strategy: str) -> torch.Tensor:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user