mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 07:34:57 +08:00
Nemotron Nano V2 VL + EVS Video Support (#27107)
Signed-off-by: Eugene Khvedchenia <ekhvedchenia@nvidia.com> Signed-off-by: Natan Bagrov <nbagrov@nvidia.com> Signed-off-by: Roger Wang <hey@rogerw.io> Co-authored-by: Natan Bagrov <nbagrov@nvidia.com> Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
parent
1c691f4a71
commit
e93ff6c8b9
@ -14,6 +14,7 @@ from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Annotated, Any, Literal, TypeAlias, TypeVar
|
||||
|
||||
import numpy.typing as npt
|
||||
import regex as re
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms as T
|
||||
@ -21,7 +22,7 @@ from PIL import Image
|
||||
from transformers import BatchFeature, PretrainedConfig, TensorType
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
|
||||
from vllm.model_executor.layers.activation import ReLUSquaredActivation
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
@ -53,12 +54,14 @@ from vllm.multimodal.inputs import (
|
||||
MultiModalFieldConfig,
|
||||
MultiModalKwargs,
|
||||
MultiModalKwargsItems,
|
||||
VideoItem,
|
||||
)
|
||||
from vllm.multimodal.parse import (
|
||||
ImageEmbeddingItems,
|
||||
ImageProcessorItems,
|
||||
ImageSize,
|
||||
MultiModalDataItems,
|
||||
MultiModalDataParser,
|
||||
)
|
||||
from vllm.multimodal.processing import (
|
||||
BaseMultiModalProcessor,
|
||||
@ -91,7 +94,7 @@ IMG_END = "</img>"
|
||||
IMG_CONTEXT = "<image>"
|
||||
|
||||
# Profiling
|
||||
MAX_FRAMES = 16
|
||||
# MAX_FRAMES = 16
|
||||
DEFAULT_NUM_TILES = 12
|
||||
|
||||
|
||||
@ -131,7 +134,8 @@ class NanoNemotronVLVideoPixelInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- bvf: Batch size * number of videos * num_frames
|
||||
- bn: Batch size * number of images
|
||||
- bn: Batch size * number of videos
|
||||
- f: Number of frames
|
||||
- c: Number of channels (3)
|
||||
- h: Height of each video frame
|
||||
- w: Width of each video frame
|
||||
@ -140,6 +144,8 @@ class NanoNemotronVLVideoPixelInputs(TensorSchema):
|
||||
type: Literal["pixel_values_videos"]
|
||||
pixel_values_flat: Annotated[torch.Tensor, TensorShape("bvf", 3, "h", "w")]
|
||||
num_patches: Annotated[torch.Tensor, TensorShape("bn")]
|
||||
frames_indices: Annotated[torch.Tensor, TensorShape("bvf")]
|
||||
frame_duration_ms: Annotated[torch.Tensor, TensorShape("bn")]
|
||||
|
||||
|
||||
class NanoNemotronVLVideoEmbeddingInputs(TensorSchema):
|
||||
@ -251,6 +257,21 @@ def video_to_pixel_values(
|
||||
return torch.stack(frames_tensors)
|
||||
|
||||
|
||||
def input_conditioner(x, norm_mean, norm_std):
|
||||
return (x - norm_mean) / norm_std
|
||||
|
||||
|
||||
def calculate_timestamps(
|
||||
indices: list[int] | torch.Tensor,
|
||||
frame_duration_ms: int,
|
||||
):
|
||||
if not isinstance(indices, list):
|
||||
indices = indices.tolist()
|
||||
|
||||
timestamps = [int(i) * frame_duration_ms / 1000.0 for i in indices]
|
||||
return timestamps
|
||||
|
||||
|
||||
class BaseNanoNemotronVLProcessor(ABC):
|
||||
"""
|
||||
This model doesn't define its own HF processor,
|
||||
@ -344,17 +365,30 @@ class BaseNanoNemotronVLProcessor(ABC):
|
||||
else:
|
||||
pixel_values_lst = self._images_to_pixel_values_lst(images, max_num_tiles)
|
||||
image_inputs = {
|
||||
"pixel_values_flat": torch.cat(pixel_values_lst),
|
||||
"pixel_values_flat": input_conditioner(
|
||||
torch.cat(pixel_values_lst), self.norm_mean, self.norm_std
|
||||
),
|
||||
"image_num_patches": torch.tensor(
|
||||
[len(item) for item in pixel_values_lst]
|
||||
),
|
||||
}
|
||||
|
||||
for pixel_values in pixel_values_lst:
|
||||
assert len(text) == 1, (
|
||||
"hf_processor is called on the output of get_dummy_text, "
|
||||
"which should be a single string"
|
||||
)
|
||||
parts = [x for x in re.split(r"(<image>)", text[0]) if x]
|
||||
assert parts.count("<image>") == len(pixel_values_lst), (
|
||||
"the number of <image> tokens in the text should be the "
|
||||
"same as the number of images"
|
||||
)
|
||||
|
||||
for i, pixel_values in enumerate(pixel_values_lst):
|
||||
num_patches = pixel_values.shape[0]
|
||||
feature_size = num_patches * self.num_image_token
|
||||
image_repl = self.get_image_repl(feature_size, num_patches)
|
||||
text = [t.replace("<image>", image_repl.full, 1) for t in text]
|
||||
parts[i] = parts[i].replace("<image>", image_repl.full)
|
||||
text = ["".join(parts)]
|
||||
return text, image_inputs
|
||||
|
||||
def _make_batch_input(self, input_item: Any | list[Any] | None = None):
|
||||
@ -421,6 +455,18 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
self.video_token = video_token
|
||||
self.video_pruning_rate = video_pruning_rate
|
||||
|
||||
# Pre-tokenize special tokens for video processing
|
||||
# to avoid repeated tokenization
|
||||
self._img_start_token_ids = encode_tokens(
|
||||
tokenizer, IMG_START, add_special_tokens=False
|
||||
)
|
||||
self._img_end_token_ids = encode_tokens(
|
||||
tokenizer, IMG_END, add_special_tokens=False
|
||||
)
|
||||
self._img_context_token_ids = encode_tokens(
|
||||
tokenizer, IMG_CONTEXT, add_special_tokens=False
|
||||
)
|
||||
|
||||
@property
|
||||
def supports_video(self) -> bool:
|
||||
return self.video_token_id is not None
|
||||
@ -454,24 +500,43 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
def _preprocess_video(
|
||||
self,
|
||||
text: list[str],
|
||||
videos: list[npt.NDArray],
|
||||
videos: list[tuple[npt.NDArray, dict[str, Any]]],
|
||||
max_num_tiles: int,
|
||||
dynamic_image_size: bool | None = None,
|
||||
):
|
||||
if len(videos) == 0 or not self.supports_video:
|
||||
video_inputs = {}
|
||||
else:
|
||||
videos_lst = [v[0] for v in videos]
|
||||
video_metadata_lst = [v[1] for v in videos]
|
||||
pixel_values_lst_video = self._videos_to_pixel_values_lst(
|
||||
videos,
|
||||
videos_lst,
|
||||
max_num_tiles=max_num_tiles,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
)
|
||||
|
||||
# We use frame duration in milliseconds (as integer) to ensure
|
||||
# we have consistent timestamps calculation. At preprocessing
|
||||
# fps parameter is given in fp32, while at inference it is bf16
|
||||
# which leads to inaccurate timestamp calculation and causes
|
||||
# timestamp values to differ.In rare cases this causes
|
||||
# mismatching number of output tokens for tokenized frame prefixes
|
||||
frame_duration_ms_lst = [
|
||||
int(1000.0 / metadata["fps"]) for metadata in video_metadata_lst
|
||||
]
|
||||
frames_indices_lst = [
|
||||
metadata["frames_indices"] for metadata in video_metadata_lst
|
||||
]
|
||||
|
||||
video_inputs = {
|
||||
"pixel_values_flat_video": torch.cat(pixel_values_lst_video),
|
||||
"pixel_values_flat_video": input_conditioner(
|
||||
torch.cat(pixel_values_lst_video), self.norm_mean, self.norm_std
|
||||
),
|
||||
"video_num_patches": torch.tensor(
|
||||
[len(item) for item in pixel_values_lst_video]
|
||||
),
|
||||
"frames_indices": frames_indices_lst,
|
||||
"frame_duration_ms": torch.tensor(frame_duration_ms_lst),
|
||||
}
|
||||
|
||||
image_size: int = self.config.force_image_size
|
||||
@ -481,7 +546,12 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
(image_size * image_size // patch_size**2) * (downsample_ratio**2)
|
||||
)
|
||||
|
||||
for pixel_values in pixel_values_lst_video:
|
||||
for pixel_values, video_metadata, frames_indices, frame_duration_ms in zip(
|
||||
pixel_values_lst_video,
|
||||
video_metadata_lst,
|
||||
frames_indices_lst,
|
||||
frame_duration_ms_lst,
|
||||
):
|
||||
num_frames = pixel_values.shape[0]
|
||||
|
||||
if (
|
||||
@ -504,16 +574,29 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
else:
|
||||
tokens_per_frame = [tokens_in_single_frame] * num_frames
|
||||
|
||||
video_repl = self.get_video_repl(tokens_per_frame, self.video_token)
|
||||
video_repl = self.get_video_repl(
|
||||
tokens_per_frame=tokens_per_frame,
|
||||
frames_indices=frames_indices,
|
||||
frame_duration_ms=frame_duration_ms,
|
||||
tokenizer=self.tokenizer,
|
||||
img_start_token_ids=self._img_start_token_ids,
|
||||
img_end_token_ids=self._img_end_token_ids,
|
||||
img_context_token_ids=self._img_context_token_ids,
|
||||
)
|
||||
|
||||
text = [t.replace("<video>", video_repl.full, 1) for t in text]
|
||||
# video_repl.full is a list of token IDs
|
||||
# Convert token IDs back to text for the HF processor flow
|
||||
video_repl_text = self.tokenizer.decode(
|
||||
video_repl.full, skip_special_tokens=False
|
||||
)
|
||||
text = [t.replace("<video>", video_repl_text, 1) for t in text]
|
||||
return text, video_inputs
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: str | list[str] | None = None,
|
||||
images: Image.Image | list[Image.Image] | None = None,
|
||||
videos: npt.NDArray | list[npt.NDArray] | None = None,
|
||||
videos: list[tuple[npt.NDArray, dict[str, Any]]] | None = None,
|
||||
return_tensors: str | TensorType | None = None,
|
||||
max_num_tiles: int | None = None,
|
||||
dynamic_image_size: bool | None = None,
|
||||
@ -558,9 +641,15 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
@classmethod
|
||||
def get_video_repl(
|
||||
cls,
|
||||
*,
|
||||
tokens_per_frame: list[int],
|
||||
video_context_token: str = IMG_CONTEXT,
|
||||
) -> PromptUpdateDetails[str]:
|
||||
frames_indices: list[int],
|
||||
frame_duration_ms: int,
|
||||
tokenizer: AnyTokenizer,
|
||||
img_start_token_ids: list[int],
|
||||
img_end_token_ids: list[int],
|
||||
img_context_token_ids: list[int],
|
||||
) -> PromptUpdateDetails[list[int]]:
|
||||
"""
|
||||
Build prompt replacement for a video.
|
||||
The replacement returned is not actually used to replace the placeholder
|
||||
@ -579,16 +668,52 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
- EVS real (called from get_real_video_repl_for_evs) - different value per frame
|
||||
Args:
|
||||
tokens_per_frame (list[int]): number of tokens per frame
|
||||
video_context_token (str): the token to use for the video context
|
||||
frames_indices (list[int]): frame indices
|
||||
frame_duration_ms (int): duration of each frame in milliseconds
|
||||
tokenizer (AnyTokenizer): tokenizer to use for tokenizing frame separators
|
||||
img_start_token_ids (list[int]): pre-tokenized IMG_START tokens
|
||||
img_end_token_ids (list[int]): pre-tokenized IMG_END tokens
|
||||
img_context_token_ids (list[int]): pre-tokenized IMG_CONTEXT tokens
|
||||
"""
|
||||
repl_full = "".join(
|
||||
[
|
||||
f"Frame{i + 1}: {IMG_START}{video_context_token * num_tokens}{IMG_END}"
|
||||
for i, num_tokens in enumerate(tokens_per_frame)
|
||||
]
|
||||
)
|
||||
# TODO: Add support of frame_duration_ms to be None
|
||||
# At preprocessing step we should allow absent / metadata without
|
||||
# frames_indices field.
|
||||
timestamps_enabled = frame_duration_ms is not None
|
||||
|
||||
return PromptUpdateDetails.from_seq(repl_full)
|
||||
if timestamps_enabled:
|
||||
timestamps = calculate_timestamps(frames_indices, frame_duration_ms)
|
||||
|
||||
assert len(timestamps) == len(tokens_per_frame), (
|
||||
"timestamps and tokens_per_frame must have the same length"
|
||||
)
|
||||
frame_separators = [
|
||||
f"Frame {i + 1} sampled at {timestamp:.2f} seconds: "
|
||||
for i, timestamp in enumerate(timestamps)
|
||||
]
|
||||
else:
|
||||
frame_separators = [
|
||||
f"Frame {i + 1}: " for i, _ in enumerate(tokens_per_frame)
|
||||
]
|
||||
|
||||
# Tokenize frame separator independently
|
||||
frame_separators_tokenized = [
|
||||
_seq2tokens(tokenizer, sep) for sep in frame_separators
|
||||
]
|
||||
|
||||
# Tokenize each component independently to avoid tokenizer merging tokens
|
||||
# across boundaries. This ensures consistent tokenization regardless of
|
||||
# num_tokens_per_frame values.
|
||||
all_token_ids = []
|
||||
for i, num_tokens in enumerate(tokens_per_frame):
|
||||
frame_sep_token_ids = frame_separators_tokenized[i]
|
||||
all_token_ids.extend(frame_sep_token_ids)
|
||||
|
||||
# Add pre-tokenized special tokens
|
||||
all_token_ids.extend(img_start_token_ids)
|
||||
all_token_ids.extend(img_context_token_ids * num_tokens)
|
||||
all_token_ids.extend(img_end_token_ids)
|
||||
|
||||
return PromptUpdateDetails.from_seq(all_token_ids)
|
||||
|
||||
|
||||
class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
|
||||
@ -695,8 +820,6 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
|
||||
max_image_tokens = self.get_max_image_tokens() * max_images
|
||||
max_total_frames = (seq_len - max_image_tokens) // processor.num_image_token
|
||||
max_frames_per_video = max_total_frames // max(max_videos, 1)
|
||||
|
||||
max_frames_per_video = min(max_frames_per_video, MAX_FRAMES)
|
||||
return max(max_frames_per_video, 1)
|
||||
|
||||
def get_hf_processor(self, **kwargs: object) -> NanoNemotronVLProcessor:
|
||||
@ -791,6 +914,9 @@ class NanoNemotronVLMultiModalProcessor(
|
||||
):
|
||||
"""MultiModalProcessor extended for video support"""
|
||||
|
||||
def _get_data_parser(self) -> MultiModalDataParser:
|
||||
return MultiModalDataParser(video_needs_metadata=True)
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
@ -805,6 +931,8 @@ class NanoNemotronVLMultiModalProcessor(
|
||||
"video", video_num_patches
|
||||
),
|
||||
video_num_patches=MultiModalFieldConfig.batched("video"),
|
||||
frames_indices=MultiModalFieldConfig.batched("video"),
|
||||
frame_duration_ms=MultiModalFieldConfig.batched("video"),
|
||||
)
|
||||
else:
|
||||
video_fields = {}
|
||||
@ -835,6 +963,7 @@ class NanoNemotronVLMultiModalProcessor(
|
||||
|
||||
def get_video_replacement_internvl(item_idx: int):
|
||||
feature_size = hf_processor.num_image_token
|
||||
video, metadata = mm_items["video"][item_idx]
|
||||
num_patches = video_num_patches[item_idx]
|
||||
if num_patches is not None:
|
||||
assert isinstance(num_patches, int)
|
||||
@ -856,9 +985,15 @@ class NanoNemotronVLMultiModalProcessor(
|
||||
else:
|
||||
tokens_per_frame = [feature_size] * num_patches
|
||||
|
||||
frame_duration_ms = int(1000 / metadata["fps"])
|
||||
return hf_processor.get_video_repl(
|
||||
tokens_per_frame,
|
||||
video_context_token=hf_processor.video_token,
|
||||
tokens_per_frame=tokens_per_frame,
|
||||
frames_indices=metadata["frames_indices"],
|
||||
frame_duration_ms=frame_duration_ms,
|
||||
tokenizer=hf_processor.tokenizer,
|
||||
img_start_token_ids=hf_processor._img_start_token_ids,
|
||||
img_end_token_ids=hf_processor._img_end_token_ids,
|
||||
img_context_token_ids=hf_processor._img_context_token_ids,
|
||||
)
|
||||
|
||||
if self.info.supports_video:
|
||||
@ -917,6 +1052,37 @@ class NanoNemotronVLDummyInputsBuilder(
|
||||
|
||||
return super().get_dummy_text(mm_counts) + "<video>" * num_videos
|
||||
|
||||
def _get_dummy_videos(
|
||||
self,
|
||||
*,
|
||||
width: int,
|
||||
height: int,
|
||||
num_frames: int,
|
||||
num_videos: int,
|
||||
overrides: VideoDummyOptions | None = None,
|
||||
) -> list[VideoItem]:
|
||||
video = super()._get_dummy_videos(
|
||||
width=width,
|
||||
height=height,
|
||||
num_frames=num_frames,
|
||||
num_videos=1,
|
||||
overrides=overrides,
|
||||
)[0]
|
||||
video_items = []
|
||||
for _ in range(num_videos):
|
||||
video_metadata = {
|
||||
"total_num_frames": num_frames,
|
||||
"fps": 2,
|
||||
"duration": num_frames / 2.0,
|
||||
"video_backend": "opencv_dynamic",
|
||||
"frames_indices": [i for i in range(num_frames)],
|
||||
"do_sample_frames": False,
|
||||
}
|
||||
video_item = (video.copy(), video_metadata)
|
||||
video_items.append(video_item)
|
||||
|
||||
return video_items
|
||||
|
||||
def get_dummy_mm_data(
|
||||
self,
|
||||
seq_len: int,
|
||||
@ -1013,6 +1179,19 @@ class NemotronH_Nano_VL_V2(
|
||||
self.config = config
|
||||
self.model_config = vllm_config.model_config
|
||||
|
||||
# Pre-tokenize special tokens for video processing
|
||||
# to avoid repeated tokenization
|
||||
tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
|
||||
self._img_start_token_ids = encode_tokens(
|
||||
tokenizer, IMG_START, add_special_tokens=False
|
||||
)
|
||||
self._img_end_token_ids = encode_tokens(
|
||||
tokenizer, IMG_END, add_special_tokens=False
|
||||
)
|
||||
self._img_context_token_ids = encode_tokens(
|
||||
tokenizer, IMG_CONTEXT, add_special_tokens=False
|
||||
)
|
||||
|
||||
def pixel_shuffle(self, x, scale_factor=0.5):
|
||||
n, w, h, c = x.size()
|
||||
# N, W, H, C --> N, W, H * scale, C // scale
|
||||
@ -1043,13 +1222,28 @@ class NemotronH_Nano_VL_V2(
|
||||
return x
|
||||
|
||||
def extract_feature(self, pixel_values):
|
||||
vit_embeds = self.vision_model(pixel_values)
|
||||
vit_embeds = vit_embeds.to(dtype=torch.bfloat16)
|
||||
h = w = int(vit_embeds.shape[1] ** 0.5)
|
||||
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
|
||||
vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
|
||||
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
|
||||
vit_embeds = self.mlp1(vit_embeds)
|
||||
# Process images in a micro-batch of at most 128 frames per call
|
||||
# This is done on purpose to ensure peak GPU ram usage of huge batch
|
||||
# (namely for really long videos with EVS ON) won't cause any problems
|
||||
# as we don't support chunked prefill for video media
|
||||
micro_batch_size = 128
|
||||
n = pixel_values.shape[0]
|
||||
vit_embeds_list = []
|
||||
for i in range(0, n, micro_batch_size):
|
||||
vit_embeds = self.vision_model(pixel_values[i : i + micro_batch_size])
|
||||
vit_embeds = vit_embeds.to(dtype=torch.bfloat16)
|
||||
h = w = int(vit_embeds.shape[1] ** 0.5)
|
||||
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
|
||||
vit_embeds = self.pixel_shuffle(
|
||||
vit_embeds, scale_factor=self.downsample_ratio
|
||||
)
|
||||
vit_embeds = vit_embeds.reshape(
|
||||
vit_embeds.shape[0], -1, vit_embeds.shape[-1]
|
||||
)
|
||||
vit_embeds = self.mlp1(vit_embeds)
|
||||
vit_embeds_list.append(vit_embeds)
|
||||
|
||||
vit_embeds = torch.cat(vit_embeds_list, dim=0)
|
||||
return vit_embeds
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
@ -1117,12 +1311,15 @@ class NemotronH_Nano_VL_V2(
|
||||
rows = int(image_rows * downsample_ratio // patch_size)
|
||||
cols = int(image_cols * downsample_ratio // patch_size)
|
||||
video_pruning_rate = self.video_pruning_rate
|
||||
|
||||
video_num_frames = video_input["num_patches"].tolist()
|
||||
video_frames_indices = video_input["frames_indices"].split(video_num_frames)
|
||||
# Calculate video feature dimensions (number of frames and
|
||||
# their feature size (AKA tokens per frame))
|
||||
# TODO: Maybe this can be optimized to avoid the loop?
|
||||
for i, single_video_embeddings in enumerate(video_embeddings):
|
||||
num_frames = video_input["num_patches"][i].item()
|
||||
num_frames = video_num_frames[i]
|
||||
frames_indices = video_frames_indices[i].tolist()
|
||||
frame_duration_ms = video_input["frame_duration_ms"][i].item()
|
||||
assert single_video_embeddings.shape[0] % num_frames == 0
|
||||
|
||||
if video_pruning_rate is not None and video_pruning_rate > 0.0:
|
||||
@ -1151,6 +1348,8 @@ class NemotronH_Nano_VL_V2(
|
||||
self._create_final_video_embeddings(
|
||||
single_video_embeddings,
|
||||
num_tokens_per_frame,
|
||||
frames_indices,
|
||||
frame_duration_ms,
|
||||
),
|
||||
)
|
||||
|
||||
@ -1160,6 +1359,8 @@ class NemotronH_Nano_VL_V2(
|
||||
self,
|
||||
video_embeddings: torch.Tensor,
|
||||
num_tokens_per_frame: list[int],
|
||||
frames_indices: list[int],
|
||||
frame_duration_ms: int,
|
||||
) -> torch.Tensor:
|
||||
"""Create final embeddings that combine video embeddings with
|
||||
text embeddings of indicator tokens.
|
||||
@ -1173,22 +1374,27 @@ class NemotronH_Nano_VL_V2(
|
||||
input_embeds for the LLM.
|
||||
"""
|
||||
device = video_embeddings.device
|
||||
|
||||
# Generate video replacement text and convert to token IDs
|
||||
video_repl_text = NanoNemotronVLProcessor.get_video_repl(
|
||||
num_tokens_per_frame,
|
||||
IMG_CONTEXT,
|
||||
).full
|
||||
|
||||
tokenizer = cached_tokenizer_from_config(self.model_config)
|
||||
repl_token_ids = torch.tensor(
|
||||
_seq2tokens(tokenizer, video_repl_text), device=device
|
||||
|
||||
# Generate video replacement token IDs using get_video_repl
|
||||
# This tokenizes each frame separator independently, then uses pre-tokenized
|
||||
# special tokens to ensure consistent tokenization regardless of
|
||||
# num_tokens_per_frame values.
|
||||
video_repl = NanoNemotronVLProcessor.get_video_repl(
|
||||
tokens_per_frame=num_tokens_per_frame,
|
||||
frames_indices=frames_indices,
|
||||
frame_duration_ms=frame_duration_ms,
|
||||
tokenizer=tokenizer,
|
||||
img_start_token_ids=self._img_start_token_ids,
|
||||
img_end_token_ids=self._img_end_token_ids,
|
||||
img_context_token_ids=self._img_context_token_ids,
|
||||
)
|
||||
|
||||
# Get embedding token IDs for image context
|
||||
embed_token_ids = torch.tensor(
|
||||
encode_tokens(tokenizer, IMG_CONTEXT), device=device
|
||||
)
|
||||
# video_repl.full is a list of token IDs
|
||||
repl_token_ids = torch.tensor(video_repl.full, device=device)
|
||||
|
||||
# Get embedding token IDs for image context (use pre-tokenized version)
|
||||
embed_token_ids = torch.tensor(self._img_context_token_ids, device=device)
|
||||
|
||||
# Create mask for video embedding positions
|
||||
is_video_embed = torch.isin(repl_token_ids, embed_token_ids)
|
||||
@ -1210,6 +1416,8 @@ class NemotronH_Nano_VL_V2(
|
||||
pixel_values_flat_video = kwargs.pop("pixel_values_flat_video", None)
|
||||
video_num_patches = kwargs.pop("video_num_patches", None)
|
||||
video_embeds = kwargs.pop("video_embeds", None)
|
||||
frames_indices = kwargs.pop("frames_indices", None)
|
||||
frame_duration_ms = kwargs.pop("frame_duration_ms", None)
|
||||
|
||||
if pixel_values_flat_video is None and video_embeds is None:
|
||||
return None
|
||||
@ -1221,13 +1429,22 @@ class NemotronH_Nano_VL_V2(
|
||||
)
|
||||
|
||||
if pixel_values_flat_video is not None:
|
||||
if torch.is_tensor(frames_indices):
|
||||
frames_indices = frames_indices.flatten()
|
||||
else:
|
||||
frames_indices = torch.cat([f.flatten() for f in frames_indices], dim=0)
|
||||
|
||||
frame_duration_ms = frame_duration_ms.flatten()
|
||||
expected_h = expected_w = self.config.force_image_size
|
||||
resolve_bindings = {"h": expected_h, "w": expected_w}
|
||||
num_frames = video_num_patches[0].item()
|
||||
resolve_bindings = {"h": expected_h, "w": expected_w, "f": num_frames}
|
||||
|
||||
return NanoNemotronVLVideoPixelInputs(
|
||||
type="pixel_values_videos",
|
||||
pixel_values_flat=pixel_values_flat_video,
|
||||
num_patches=video_num_patches,
|
||||
frames_indices=frames_indices,
|
||||
frame_duration_ms=frame_duration_ms,
|
||||
resolve_bindings=resolve_bindings,
|
||||
)
|
||||
|
||||
|
||||
@ -43,32 +43,6 @@ to_4tuple = _ntuple(4)
|
||||
to_ntuple = _ntuple
|
||||
|
||||
|
||||
class InputConditioner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_scale: float,
|
||||
norm_mean: norm_t,
|
||||
norm_std: norm_t,
|
||||
dtype: torch.dtype = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.dtype = dtype
|
||||
|
||||
self.register_buffer("norm_mean", _to_tensor(norm_mean) / input_scale)
|
||||
self.register_buffer("norm_std", _to_tensor(norm_std) / input_scale)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
y = (x - self.norm_mean) / self.norm_std
|
||||
if self.dtype is not None:
|
||||
y = y.to(self.dtype)
|
||||
return y
|
||||
|
||||
|
||||
def _to_tensor(v: norm_t):
|
||||
return torch.as_tensor(v, dtype=torch.float32).view(-1, 1, 1)
|
||||
|
||||
|
||||
class ClsToken(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -507,11 +481,6 @@ class RadioModel(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.input_conditioner = InputConditioner(
|
||||
input_scale=1.0,
|
||||
norm_mean=config.norm_mean,
|
||||
norm_std=config.norm_std,
|
||||
)
|
||||
self.model = RadioInternVisionModel(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
@ -525,8 +494,7 @@ class RadioModel(nn.Module):
|
||||
pixel_values: torch.Tensor | None = None,
|
||||
pixel_embeds: torch.Tensor | None = None,
|
||||
) -> torch.FloatTensor:
|
||||
x = self.input_conditioner(pixel_values)
|
||||
y = self.model(x)
|
||||
y = self.model(pixel_values)
|
||||
return self._extract_final(y)
|
||||
|
||||
def load_weights(self, weights) -> set[str]:
|
||||
@ -548,6 +516,10 @@ class RadioModel(nn.Module):
|
||||
# Skip buffers not used in vLLM
|
||||
if sub in {"summary_idxs"}:
|
||||
continue
|
||||
if sub.startswith("input_conditioner."):
|
||||
# we normalize in the input processor,
|
||||
# based on norm and std values from the config
|
||||
continue
|
||||
|
||||
vllm_key = None
|
||||
if sub.startswith("model.patch_generator."):
|
||||
|
||||
@ -223,7 +223,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
|
||||
height,
|
||||
)
|
||||
height = min(height, overrides.height)
|
||||
video = np.full((num_frames, width, height, 3), 255)
|
||||
video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
|
||||
return [video] * num_videos
|
||||
|
||||
|
||||
|
||||
@ -13,10 +13,13 @@ import numpy.typing as npt
|
||||
from PIL import Image
|
||||
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .base import MediaIO
|
||||
from .image import ImageMediaIO
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray:
|
||||
num_frames, _, _, channels = frames.shape
|
||||
@ -103,6 +106,7 @@ class OpenCVVideoBackend(VideoLoader):
|
||||
cls,
|
||||
data: bytes,
|
||||
num_frames: int = -1,
|
||||
fps: int = -1,
|
||||
**kwargs,
|
||||
) -> tuple[npt.NDArray, dict[str, Any]]:
|
||||
import cv2
|
||||
@ -116,14 +120,20 @@ class OpenCVVideoBackend(VideoLoader):
|
||||
original_fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
duration = total_frames_num / original_fps if original_fps > 0 else 0
|
||||
|
||||
# resample video to target num_frames
|
||||
full_read = num_frames == -1 or total_frames_num < num_frames
|
||||
if full_read:
|
||||
num_frames = total_frames_num
|
||||
frame_idx = list(range(0, num_frames))
|
||||
# resample video to target num_frames and fps
|
||||
# - the minimum of the two will be used
|
||||
num_frames_to_sample = total_frames_num
|
||||
if num_frames > 0:
|
||||
num_frames_to_sample = min(num_frames, total_frames_num)
|
||||
if fps > 0:
|
||||
num_frames_to_sample = min(num_frames_to_sample, math.floor(duration * fps))
|
||||
num_frames_to_sample = max(1, num_frames_to_sample) # at least one sample
|
||||
|
||||
if num_frames_to_sample == total_frames_num:
|
||||
frame_idx = list(range(0, num_frames_to_sample))
|
||||
else:
|
||||
uniform_sampled_frames = np.linspace(
|
||||
0, total_frames_num - 1, num_frames, dtype=int
|
||||
0, total_frames_num - 1, num_frames_to_sample, dtype=int
|
||||
)
|
||||
frame_idx = uniform_sampled_frames.tolist()
|
||||
|
||||
@ -132,7 +142,7 @@ class OpenCVVideoBackend(VideoLoader):
|
||||
frames = np.empty((len(frame_idx), height, width, 3), dtype=np.uint8)
|
||||
|
||||
i = 0
|
||||
for idx in range(total_frames_num):
|
||||
for idx in range(max(frame_idx) + 1):
|
||||
ok = cap.grab()
|
||||
if not ok:
|
||||
break
|
||||
@ -142,8 +152,8 @@ class OpenCVVideoBackend(VideoLoader):
|
||||
frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
i += 1
|
||||
|
||||
assert i == num_frames, (
|
||||
f"Expected reading {num_frames} frames, "
|
||||
assert i == num_frames_to_sample, (
|
||||
f"Expected reading {num_frames_to_sample} frames, "
|
||||
f"but only loaded {i} frames from video."
|
||||
)
|
||||
|
||||
@ -151,14 +161,14 @@ class OpenCVVideoBackend(VideoLoader):
|
||||
# NOTE(Isotr0py): For models like Qwen3-VL/GLM4.5V, this metadata
|
||||
# can cause incorrect timestamp calculation without num_frames=-1.
|
||||
metadata = {
|
||||
"total_num_frames": num_frames,
|
||||
"fps": num_frames / duration,
|
||||
"total_num_frames": total_frames_num,
|
||||
"fps": original_fps,
|
||||
"duration": duration,
|
||||
"video_backend": "opencv",
|
||||
"frames_indices": list(range(num_frames)),
|
||||
"frames_indices": list(frame_idx),
|
||||
# extra field used to control hf processor's video
|
||||
# sampling behavior
|
||||
"do_sample_frames": num_frames == total_frames_num,
|
||||
"do_sample_frames": num_frames_to_sample == total_frames_num,
|
||||
}
|
||||
|
||||
return frames, metadata
|
||||
|
||||
@ -1735,20 +1735,32 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
pin_memory=self.pin_memory,
|
||||
merge_by_field_config=model.merge_by_field_config,
|
||||
):
|
||||
curr_group_outputs = []
|
||||
|
||||
# EVS-related change.
|
||||
# (ekhvedchenia): Temporary hack to limit peak memory usage when
|
||||
# processing multimodal data.This solves the issue with scheduler
|
||||
# processing multimodal data. This solves the issue with scheduler
|
||||
# putting too many video samples into a single batch. Scheduler
|
||||
# uses pruned vision tokens count to compare it versus compute
|
||||
# budget which is incorrect (Either input media size or non-pruned
|
||||
# output vision tokens count should be considered)
|
||||
curr_group_outputs = []
|
||||
|
||||
if self.is_multimodal_pruning_enabled and modality == "video":
|
||||
micro_batch_size = 1
|
||||
for i in range(0, num_items, micro_batch_size):
|
||||
micro_batch_mm_inputs = dict(
|
||||
(k, v[i : i + micro_batch_size])
|
||||
for k, v in mm_kwargs_group.items()
|
||||
# TODO(ywang96): Fix memory profiling to take EVS into account and
|
||||
# remove this hack.
|
||||
if (
|
||||
self.is_multimodal_pruning_enabled
|
||||
and modality == "video"
|
||||
and num_items > 1
|
||||
):
|
||||
for video_mm_kwargs_item in filter(
|
||||
lambda item: item.modality == "video", mm_kwargs
|
||||
):
|
||||
_, _, micro_batch_mm_inputs = next(
|
||||
group_mm_kwargs_by_modality(
|
||||
[video_mm_kwargs_item],
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
merge_by_field_config=model.merge_by_field_config,
|
||||
)
|
||||
)
|
||||
|
||||
micro_batch_outputs = model.get_multimodal_embeddings(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user