diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py
index 77d77e7b9f86..86fc1d6046ce 100644
--- a/vllm/model_executor/models/nano_nemotron_vl.py
+++ b/vllm/model_executor/models/nano_nemotron_vl.py
@@ -14,6 +14,7 @@ from collections.abc import Iterable, Mapping, Sequence
 from typing import Annotated, Any, Literal, TypeAlias, TypeVar
 
 import numpy.typing as npt
+import regex as re
 import torch
 import torch.nn as nn
 import torchvision.transforms as T
@@ -21,7 +22,7 @@ from PIL import Image
 from transformers import BatchFeature, PretrainedConfig, TensorType
 
 from vllm.config import VllmConfig
-from vllm.config.multimodal import BaseDummyOptions
+from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
 from vllm.model_executor.layers.activation import ReLUSquaredActivation
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -53,12 +54,14 @@ from vllm.multimodal.inputs import (
     MultiModalFieldConfig,
     MultiModalKwargs,
     MultiModalKwargsItems,
+    VideoItem,
 )
 from vllm.multimodal.parse import (
     ImageEmbeddingItems,
     ImageProcessorItems,
     ImageSize,
     MultiModalDataItems,
+    MultiModalDataParser,
 )
 from vllm.multimodal.processing import (
     BaseMultiModalProcessor,
@@ -91,7 +94,7 @@ IMG_END = "</img>"
 IMG_CONTEXT = "<image>"
 
 # Profiling
-MAX_FRAMES = 16
+# MAX_FRAMES = 16
 DEFAULT_NUM_TILES = 12
 
 
@@ -131,7 +134,8 @@ class NanoNemotronVLVideoPixelInputs(TensorSchema):
     """
     Dimensions:
         - bvf: Batch size * number of videos * num_frames
-        - bn: Batch size * number of images
+        - bn: Batch size * number of videos
+        - f: Number of frames
         - c: Number of channels (3)
         - h: Height of each video frame
         - w: Width of each video frame
@@ -140,6 +144,8 @@ class NanoNemotronVLVideoPixelInputs(TensorSchema):
     type: Literal["pixel_values_videos"]
     pixel_values_flat: Annotated[torch.Tensor, TensorShape("bvf", 3, "h", "w")]
     num_patches: Annotated[torch.Tensor, TensorShape("bn")]
+    frames_indices: Annotated[torch.Tensor, TensorShape("bvf")]
+    frame_duration_ms: Annotated[torch.Tensor, TensorShape("bn")]
 
 
 class NanoNemotronVLVideoEmbeddingInputs(TensorSchema):
@@ -251,6 +257,21 @@ def video_to_pixel_values(
     return torch.stack(frames_tensors)
 
 
+def input_conditioner(x, norm_mean, norm_std):
+    return (x - norm_mean) / norm_std
+
+
+def calculate_timestamps(
+    indices: list[int] | torch.Tensor,
+    frame_duration_ms: int,
+):
+    if not isinstance(indices, list):
+        indices = indices.tolist()
+
+    timestamps = [int(i) * frame_duration_ms / 1000.0 for i in indices]
+    return timestamps
+
+
 class BaseNanoNemotronVLProcessor(ABC):
     """
     This model doesn't define its own HF processor,
@@ -344,17 +365,30 @@ class BaseNanoNemotronVLProcessor(ABC):
         else:
             pixel_values_lst = self._images_to_pixel_values_lst(images, max_num_tiles)
             image_inputs = {
-                "pixel_values_flat": torch.cat(pixel_values_lst),
+                "pixel_values_flat": input_conditioner(
+                    torch.cat(pixel_values_lst), self.norm_mean, self.norm_std
+                ),
                 "image_num_patches": torch.tensor(
                     [len(item) for item in pixel_values_lst]
                 ),
             }
 
-            for pixel_values in pixel_values_lst:
+            assert len(text) == 1, (
+                "hf_processor is called on the output of get_dummy_text, "
+                "which should be a single string"
+            )
+            parts = [x for x in re.split(r"(<image>)", text[0]) if x]
+            assert parts.count("<image>") == len(pixel_values_lst), (
+                "the number of <image> tokens in the text should be the "
+                "same as the number of images"
+            )
+
+            for i, pixel_values in enumerate(pixel_values_lst):
                 num_patches = pixel_values.shape[0]
                 feature_size = num_patches * self.num_image_token
                 image_repl = self.get_image_repl(feature_size, num_patches)
-                text = [t.replace("<image>", image_repl.full, 1) for t in text]
+                parts[i] = parts[i].replace("<image>", image_repl.full)
+            text = ["".join(parts)]
         return text, image_inputs
 
     def _make_batch_input(self, input_item: Any | list[Any] | None = None):
@@ -421,6 +455,18 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
         self.video_token = video_token
         self.video_pruning_rate = video_pruning_rate
 
+        # Pre-tokenize special tokens for video processing
+        # to avoid repeated tokenization
+        self._img_start_token_ids = encode_tokens(
+            tokenizer, IMG_START, add_special_tokens=False
+        )
+        self._img_end_token_ids = encode_tokens(
+            tokenizer, IMG_END, add_special_tokens=False
+        )
+        self._img_context_token_ids = encode_tokens(
+            tokenizer, IMG_CONTEXT, add_special_tokens=False
+        )
+
     @property
     def supports_video(self) -> bool:
         return self.video_token_id is not None
@@ -454,24 +500,43 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
     def _preprocess_video(
         self,
         text: list[str],
-        videos: list[npt.NDArray],
+        videos: list[tuple[npt.NDArray, dict[str, Any]]],
         max_num_tiles: int,
        dynamic_image_size: bool | None = None,
     ):
         if len(videos) == 0 or not self.supports_video:
             video_inputs = {}
         else:
+            videos_lst = [v[0] for v in videos]
+            video_metadata_lst = [v[1] for v in videos]
             pixel_values_lst_video = self._videos_to_pixel_values_lst(
-                videos,
+                videos_lst,
                 max_num_tiles=max_num_tiles,
                 dynamic_image_size=dynamic_image_size,
             )
 
+            # We use the frame duration in milliseconds (as an integer) to
+            # ensure consistent timestamp calculation. At preprocessing the
+            # fps parameter is given in fp32, while at inference it is bf16,
+            # which leads to inaccurate timestamp calculation and causes
+            # timestamp values to differ. In rare cases this causes a
+            # mismatched number of output tokens for tokenized frame prefixes.
+            frame_duration_ms_lst = [
+                int(1000.0 / metadata["fps"]) for metadata in video_metadata_lst
+            ]
+            frames_indices_lst = [
+                metadata["frames_indices"] for metadata in video_metadata_lst
+            ]
+
             video_inputs = {
-                "pixel_values_flat_video": torch.cat(pixel_values_lst_video),
+                "pixel_values_flat_video": input_conditioner(
+                    torch.cat(pixel_values_lst_video), self.norm_mean, self.norm_std
+                ),
                 "video_num_patches": torch.tensor(
                     [len(item) for item in pixel_values_lst_video]
                 ),
+                "frames_indices": frames_indices_lst,
+                "frame_duration_ms": torch.tensor(frame_duration_ms_lst),
             }
 
         image_size: int = self.config.force_image_size
@@ -481,7 +546,12 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
             (image_size * image_size // patch_size**2) * (downsample_ratio**2)
         )
 
-        for pixel_values in pixel_values_lst_video:
+        for pixel_values, video_metadata, frames_indices, frame_duration_ms in zip(
+            pixel_values_lst_video,
+            video_metadata_lst,
+            frames_indices_lst,
+            frame_duration_ms_lst,
+        ):
             num_frames = pixel_values.shape[0]
 
             if (
@@ -504,16 +574,29 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
             else:
                 tokens_per_frame = [tokens_in_single_frame] * num_frames
 
-            video_repl = self.get_video_repl(tokens_per_frame, self.video_token)
+            video_repl = self.get_video_repl(
+                tokens_per_frame=tokens_per_frame,
+                frames_indices=frames_indices,
+                frame_duration_ms=frame_duration_ms,
+                tokenizer=self.tokenizer,
+                img_start_token_ids=self._img_start_token_ids,
+                img_end_token_ids=self._img_end_token_ids,
+                img_context_token_ids=self._img_context_token_ids,
+            )
 
-            text = [t.replace("