Nemotron Nano V2 VL + EVS Video Support (#27107)

Signed-off-by: Eugene Khvedchenia <ekhvedchenia@nvidia.com>
Signed-off-by: Natan Bagrov <nbagrov@nvidia.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Natan Bagrov <nbagrov@nvidia.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Eugene Khvedchenya 2025-10-20 17:19:11 +03:00 committed by GitHub
parent 1c691f4a71
commit e93ff6c8b9
5 changed files with 317 additions and 106 deletions


@ -14,6 +14,7 @@ from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal, TypeAlias, TypeVar
import numpy.typing as npt
import regex as re
import torch
import torch.nn as nn
import torchvision.transforms as T
@ -21,7 +22,7 @@ from PIL import Image
from transformers import BatchFeature, PretrainedConfig, TensorType
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.model_executor.layers.activation import ReLUSquaredActivation
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@ -53,12 +54,14 @@ from vllm.multimodal.inputs import (
MultiModalFieldConfig,
MultiModalKwargs,
MultiModalKwargsItems,
VideoItem,
)
from vllm.multimodal.parse import (
ImageEmbeddingItems,
ImageProcessorItems,
ImageSize,
MultiModalDataItems,
MultiModalDataParser,
)
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
@ -91,7 +94,7 @@ IMG_END = "</img>"
IMG_CONTEXT = "<image>"
# Profiling
MAX_FRAMES = 16
# MAX_FRAMES = 16
DEFAULT_NUM_TILES = 12
@ -131,7 +134,8 @@ class NanoNemotronVLVideoPixelInputs(TensorSchema):
"""
Dimensions:
- bvf: Batch size * number of videos * num_frames
- bn: Batch size * number of images
- bn: Batch size * number of videos
- f: Number of frames
- c: Number of channels (3)
- h: Height of each video frame
- w: Width of each video frame
@ -140,6 +144,8 @@ class NanoNemotronVLVideoPixelInputs(TensorSchema):
type: Literal["pixel_values_videos"]
pixel_values_flat: Annotated[torch.Tensor, TensorShape("bvf", 3, "h", "w")]
num_patches: Annotated[torch.Tensor, TensorShape("bn")]
frames_indices: Annotated[torch.Tensor, TensorShape("bvf")]
frame_duration_ms: Annotated[torch.Tensor, TensorShape("bn")]
class NanoNemotronVLVideoEmbeddingInputs(TensorSchema):
@ -251,6 +257,21 @@ def video_to_pixel_values(
return torch.stack(frames_tensors)
def input_conditioner(x, norm_mean, norm_std):
return (x - norm_mean) / norm_std
def calculate_timestamps(
indices: list[int] | torch.Tensor,
frame_duration_ms: int,
):
if not isinstance(indices, list):
indices = indices.tolist()
timestamps = [int(i) * frame_duration_ms / 1000.0 for i in indices]
return timestamps
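As a quick illustration of the rationale behind the integer frame duration (an editor's sketch, not part of this diff; the fps value and indices are made up): computing timestamps from an integer millisecond duration gives the same result at preprocessing and at inference, whereas dividing by an fps value that has been cast to bf16 can drift.
# Illustrative sketch only (not part of the PR): integer milliseconds keep
# timestamps reproducible, while bf16 fps can round (e.g. 29.97 -> 30.0).
import torch
fps = 29.97
frame_duration_ms = int(1000.0 / fps)  # 33 ms, exact integer
indices = [0, 15, 30, 45]
# Same arithmetic as calculate_timestamps() above.
timestamps = [i * frame_duration_ms / 1000.0 for i in indices]
# Deriving timestamps from a bf16 fps instead can change the formatted
# "%.2f" string and therefore the tokenized frame-prefix length.
fps_bf16 = torch.tensor(fps, dtype=torch.bfloat16)
drifted = [(i / fps_bf16).item() for i in indices]
print(timestamps, drifted)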
class BaseNanoNemotronVLProcessor(ABC):
"""
This model doesn't define its own HF processor,
@ -344,17 +365,30 @@ class BaseNanoNemotronVLProcessor(ABC):
else:
pixel_values_lst = self._images_to_pixel_values_lst(images, max_num_tiles)
image_inputs = {
"pixel_values_flat": torch.cat(pixel_values_lst),
"pixel_values_flat": input_conditioner(
torch.cat(pixel_values_lst), self.norm_mean, self.norm_std
),
"image_num_patches": torch.tensor(
[len(item) for item in pixel_values_lst]
),
}
for pixel_values in pixel_values_lst:
assert len(text) == 1, (
"hf_processor is called on the output of get_dummy_text, "
"which should be a single string"
)
parts = [x for x in re.split(r"(<image>)", text[0]) if x]
assert parts.count("<image>") == len(pixel_values_lst), (
"the number of <image> tokens in the text should be the "
"same as the number of images"
)
for i, pixel_values in enumerate(pixel_values_lst):
num_patches = pixel_values.shape[0]
feature_size = num_patches * self.num_image_token
image_repl = self.get_image_repl(feature_size, num_patches)
text = [t.replace("<image>", image_repl.full, 1) for t in text]
parts[i] = parts[i].replace("<image>", image_repl.full)
text = ["".join(parts)]
return text, image_inputs
def _make_batch_input(self, input_item: Any | list[Any] | None = None):
@ -421,6 +455,18 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
self.video_token = video_token
self.video_pruning_rate = video_pruning_rate
# Pre-tokenize special tokens for video processing
# to avoid repeated tokenization
self._img_start_token_ids = encode_tokens(
tokenizer, IMG_START, add_special_tokens=False
)
self._img_end_token_ids = encode_tokens(
tokenizer, IMG_END, add_special_tokens=False
)
self._img_context_token_ids = encode_tokens(
tokenizer, IMG_CONTEXT, add_special_tokens=False
)
@property
def supports_video(self) -> bool:
return self.video_token_id is not None
@ -454,24 +500,43 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
def _preprocess_video(
self,
text: list[str],
videos: list[npt.NDArray],
videos: list[tuple[npt.NDArray, dict[str, Any]]],
max_num_tiles: int,
dynamic_image_size: bool | None = None,
):
if len(videos) == 0 or not self.supports_video:
video_inputs = {}
else:
videos_lst = [v[0] for v in videos]
video_metadata_lst = [v[1] for v in videos]
pixel_values_lst_video = self._videos_to_pixel_values_lst(
videos,
videos_lst,
max_num_tiles=max_num_tiles,
dynamic_image_size=dynamic_image_size,
)
# We use the frame duration in milliseconds (as an integer) to ensure
# consistent timestamp calculation. At preprocessing the fps
# parameter is given in fp32, while at inference it is bf16, which
# leads to inaccurate timestamp calculation and causes timestamp
# values to differ. In rare cases this causes a mismatching number
# of output tokens for the tokenized frame prefixes.
frame_duration_ms_lst = [
int(1000.0 / metadata["fps"]) for metadata in video_metadata_lst
]
frames_indices_lst = [
metadata["frames_indices"] for metadata in video_metadata_lst
]
video_inputs = {
"pixel_values_flat_video": torch.cat(pixel_values_lst_video),
"pixel_values_flat_video": input_conditioner(
torch.cat(pixel_values_lst_video), self.norm_mean, self.norm_std
),
"video_num_patches": torch.tensor(
[len(item) for item in pixel_values_lst_video]
),
"frames_indices": frames_indices_lst,
"frame_duration_ms": torch.tensor(frame_duration_ms_lst),
}
image_size: int = self.config.force_image_size
@ -481,7 +546,12 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
(image_size * image_size // patch_size**2) * (downsample_ratio**2)
)
for pixel_values in pixel_values_lst_video:
for pixel_values, video_metadata, frames_indices, frame_duration_ms in zip(
pixel_values_lst_video,
video_metadata_lst,
frames_indices_lst,
frame_duration_ms_lst,
):
num_frames = pixel_values.shape[0]
if (
@ -504,16 +574,29 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
else:
tokens_per_frame = [tokens_in_single_frame] * num_frames
video_repl = self.get_video_repl(tokens_per_frame, self.video_token)
video_repl = self.get_video_repl(
tokens_per_frame=tokens_per_frame,
frames_indices=frames_indices,
frame_duration_ms=frame_duration_ms,
tokenizer=self.tokenizer,
img_start_token_ids=self._img_start_token_ids,
img_end_token_ids=self._img_end_token_ids,
img_context_token_ids=self._img_context_token_ids,
)
text = [t.replace("<video>", video_repl.full, 1) for t in text]
# video_repl.full is a list of token IDs
# Convert token IDs back to text for the HF processor flow
video_repl_text = self.tokenizer.decode(
video_repl.full, skip_special_tokens=False
)
text = [t.replace("<video>", video_repl_text, 1) for t in text]
return text, video_inputs
def __call__(
self,
text: str | list[str] | None = None,
images: Image.Image | list[Image.Image] | None = None,
videos: npt.NDArray | list[npt.NDArray] | None = None,
videos: list[tuple[npt.NDArray, dict[str, Any]]] | None = None,
return_tensors: str | TensorType | None = None,
max_num_tiles: int | None = None,
dynamic_image_size: bool | None = None,
@ -558,9 +641,15 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
@classmethod
def get_video_repl(
cls,
*,
tokens_per_frame: list[int],
video_context_token: str = IMG_CONTEXT,
) -> PromptUpdateDetails[str]:
frames_indices: list[int],
frame_duration_ms: int,
tokenizer: AnyTokenizer,
img_start_token_ids: list[int],
img_end_token_ids: list[int],
img_context_token_ids: list[int],
) -> PromptUpdateDetails[list[int]]:
"""
Build prompt replacement for a video.
The replacement returned is not actually used to replace the placeholder
@ -579,16 +668,52 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
- EVS real (called from get_real_video_repl_for_evs) - different value per frame
Args:
tokens_per_frame (list[int]): number of tokens per frame
video_context_token (str): the token to use for the video context
frames_indices (list[int]): frame indices
frame_duration_ms (int): duration of each frame in milliseconds
tokenizer (AnyTokenizer): tokenizer to use for tokenizing frame separators
img_start_token_ids (list[int]): pre-tokenized IMG_START tokens
img_end_token_ids (list[int]): pre-tokenized IMG_END tokens
img_context_token_ids (list[int]): pre-tokenized IMG_CONTEXT tokens
"""
repl_full = "".join(
[
f"Frame{i + 1}: {IMG_START}{video_context_token * num_tokens}{IMG_END}"
for i, num_tokens in enumerate(tokens_per_frame)
]
)
# TODO: Add support for frame_duration_ms being None.
# At the preprocessing step we should allow metadata that is absent
# or lacks the frames_indices field.
timestamps_enabled = frame_duration_ms is not None
return PromptUpdateDetails.from_seq(repl_full)
if timestamps_enabled:
timestamps = calculate_timestamps(frames_indices, frame_duration_ms)
assert len(timestamps) == len(tokens_per_frame), (
"timestamps and tokens_per_frame must have the same length"
)
frame_separators = [
f"Frame {i + 1} sampled at {timestamp:.2f} seconds: "
for i, timestamp in enumerate(timestamps)
]
else:
frame_separators = [
f"Frame {i + 1}: " for i, _ in enumerate(tokens_per_frame)
]
# Tokenize frame separator independently
frame_separators_tokenized = [
_seq2tokens(tokenizer, sep) for sep in frame_separators
]
# Tokenize each component independently to avoid tokenizer merging tokens
# across boundaries. This ensures consistent tokenization regardless of
# num_tokens_per_frame values.
all_token_ids = []
for i, num_tokens in enumerate(tokens_per_frame):
frame_sep_token_ids = frame_separators_tokenized[i]
all_token_ids.extend(frame_sep_token_ids)
# Add pre-tokenized special tokens
all_token_ids.extend(img_start_token_ids)
all_token_ids.extend(img_context_token_ids * num_tokens)
all_token_ids.extend(img_end_token_ids)
return PromptUpdateDetails.from_seq(all_token_ids)
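To make the per-component tokenization concrete, here is a small editor's sketch (not part of the diff; gpt2 is an arbitrary small tokenizer chosen only for illustration, not the model's own). Encoding a frame separator and the context token jointly may merge tokens across the boundary, while concatenating independently encoded pieces, as get_video_repl does, keeps the per-frame token count deterministic.
# Hypothetical illustration; the tokenizer and strings are placeholders.
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("gpt2")
sep = "Frame 1 sampled at 0.00 seconds: "
ctx = "<image>"
joined_ids = tok.encode(sep + ctx, add_special_tokens=False)
piecewise_ids = tok.encode(sep, add_special_tokens=False) + tok.encode(
    ctx, add_special_tokens=False
)
# The two encodings are not guaranteed to match; the code above always uses
# the piecewise form so the replacement length does not depend on how the
# tokenizer treats characters at the separator/context boundary.
print(len(joined_ids), len(piecewise_ids))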
class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
@ -695,8 +820,6 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = (seq_len - max_image_tokens) // processor.num_image_token
max_frames_per_video = max_total_frames // max(max_videos, 1)
max_frames_per_video = min(max_frames_per_video, MAX_FRAMES)
return max(max_frames_per_video, 1)
def get_hf_processor(self, **kwargs: object) -> NanoNemotronVLProcessor:
@ -791,6 +914,9 @@ class NanoNemotronVLMultiModalProcessor(
):
"""MultiModalProcessor extended for video support"""
def _get_data_parser(self) -> MultiModalDataParser:
return MultiModalDataParser(video_needs_metadata=True)
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
@ -805,6 +931,8 @@ class NanoNemotronVLMultiModalProcessor(
"video", video_num_patches
),
video_num_patches=MultiModalFieldConfig.batched("video"),
frames_indices=MultiModalFieldConfig.batched("video"),
frame_duration_ms=MultiModalFieldConfig.batched("video"),
)
else:
video_fields = {}
@ -835,6 +963,7 @@ class NanoNemotronVLMultiModalProcessor(
def get_video_replacement_internvl(item_idx: int):
feature_size = hf_processor.num_image_token
video, metadata = mm_items["video"][item_idx]
num_patches = video_num_patches[item_idx]
if num_patches is not None:
assert isinstance(num_patches, int)
@ -856,9 +985,15 @@ class NanoNemotronVLMultiModalProcessor(
else:
tokens_per_frame = [feature_size] * num_patches
frame_duration_ms = int(1000 / metadata["fps"])
return hf_processor.get_video_repl(
tokens_per_frame,
video_context_token=hf_processor.video_token,
tokens_per_frame=tokens_per_frame,
frames_indices=metadata["frames_indices"],
frame_duration_ms=frame_duration_ms,
tokenizer=hf_processor.tokenizer,
img_start_token_ids=hf_processor._img_start_token_ids,
img_end_token_ids=hf_processor._img_end_token_ids,
img_context_token_ids=hf_processor._img_context_token_ids,
)
if self.info.supports_video:
@ -917,6 +1052,37 @@ class NanoNemotronVLDummyInputsBuilder(
return super().get_dummy_text(mm_counts) + "<video>" * num_videos
def _get_dummy_videos(
self,
*,
width: int,
height: int,
num_frames: int,
num_videos: int,
overrides: VideoDummyOptions | None = None,
) -> list[VideoItem]:
video = super()._get_dummy_videos(
width=width,
height=height,
num_frames=num_frames,
num_videos=1,
overrides=overrides,
)[0]
video_items = []
for _ in range(num_videos):
video_metadata = {
"total_num_frames": num_frames,
"fps": 2,
"duration": num_frames / 2.0,
"video_backend": "opencv_dynamic",
"frames_indices": [i for i in range(num_frames)],
"do_sample_frames": False,
}
video_item = (video.copy(), video_metadata)
video_items.append(video_item)
return video_items
def get_dummy_mm_data(
self,
seq_len: int,
@ -1013,6 +1179,19 @@ class NemotronH_Nano_VL_V2(
self.config = config
self.model_config = vllm_config.model_config
# Pre-tokenize special tokens for video processing
# to avoid repeated tokenization
tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
self._img_start_token_ids = encode_tokens(
tokenizer, IMG_START, add_special_tokens=False
)
self._img_end_token_ids = encode_tokens(
tokenizer, IMG_END, add_special_tokens=False
)
self._img_context_token_ids = encode_tokens(
tokenizer, IMG_CONTEXT, add_special_tokens=False
)
def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
# N, W, H, C --> N, W, H * scale, C // scale
@ -1043,13 +1222,28 @@ class NemotronH_Nano_VL_V2(
return x
def extract_feature(self, pixel_values):
vit_embeds = self.vision_model(pixel_values)
vit_embeds = vit_embeds.to(dtype=torch.bfloat16)
h = w = int(vit_embeds.shape[1] ** 0.5)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
vit_embeds = self.mlp1(vit_embeds)
# Process images in micro-batches of at most 128 frames per call.
# This is done on purpose to bound the peak GPU RAM usage of huge
# batches (namely for really long videos with EVS enabled), since
# chunked prefill is not supported for video media.
micro_batch_size = 128
n = pixel_values.shape[0]
vit_embeds_list = []
for i in range(0, n, micro_batch_size):
vit_embeds = self.vision_model(pixel_values[i : i + micro_batch_size])
vit_embeds = vit_embeds.to(dtype=torch.bfloat16)
h = w = int(vit_embeds.shape[1] ** 0.5)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
vit_embeds = self.pixel_shuffle(
vit_embeds, scale_factor=self.downsample_ratio
)
vit_embeds = vit_embeds.reshape(
vit_embeds.shape[0], -1, vit_embeds.shape[-1]
)
vit_embeds = self.mlp1(vit_embeds)
vit_embeds_list.append(vit_embeds)
vit_embeds = torch.cat(vit_embeds_list, dim=0)
return vit_embeds
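A back-of-the-envelope sketch of why the 128-frame micro-batch matters (editor's illustration with made-up numbers, not measurements from this PR): activation memory now scales with the micro-batch size instead of the total number of frames.
# Hypothetical sizes for illustration only.
frames_total = 20 * 60 * 2                    # a 20-minute video sampled at 2 fps
micro_batch_size = 128
act_bytes_per_frame = 32 * 2**20              # assume ~32 MiB of ViT activations per frame
peak_unbatched = frames_total * act_bytes_per_frame          # all frames at once
peak_micro_batched = micro_batch_size * act_bytes_per_frame  # bounded by the loop above
print(peak_unbatched / 2**30, "GiB vs", peak_micro_batched / 2**30, "GiB")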
def _parse_and_validate_image_input(
@ -1117,12 +1311,15 @@ class NemotronH_Nano_VL_V2(
rows = int(image_rows * downsample_ratio // patch_size)
cols = int(image_cols * downsample_ratio // patch_size)
video_pruning_rate = self.video_pruning_rate
video_num_frames = video_input["num_patches"].tolist()
video_frames_indices = video_input["frames_indices"].split(video_num_frames)
# Calculate video feature dimensions (number of frames and
# their feature size (AKA tokens per frame))
# TODO: Maybe this can be optimized to avoid the loop?
for i, single_video_embeddings in enumerate(video_embeddings):
num_frames = video_input["num_patches"][i].item()
num_frames = video_num_frames[i]
frames_indices = video_frames_indices[i].tolist()
frame_duration_ms = video_input["frame_duration_ms"][i].item()
assert single_video_embeddings.shape[0] % num_frames == 0
if video_pruning_rate is not None and video_pruning_rate > 0.0:
@ -1151,6 +1348,8 @@ class NemotronH_Nano_VL_V2(
self._create_final_video_embeddings(
single_video_embeddings,
num_tokens_per_frame,
frames_indices,
frame_duration_ms,
),
)
@ -1160,6 +1359,8 @@ class NemotronH_Nano_VL_V2(
self,
video_embeddings: torch.Tensor,
num_tokens_per_frame: list[int],
frames_indices: list[int],
frame_duration_ms: int,
) -> torch.Tensor:
"""Create final embeddings that combine video embeddings with
text embeddings of indicator tokens.
@ -1173,22 +1374,27 @@ class NemotronH_Nano_VL_V2(
input_embeds for the LLM.
"""
device = video_embeddings.device
# Generate video replacement text and convert to token IDs
video_repl_text = NanoNemotronVLProcessor.get_video_repl(
num_tokens_per_frame,
IMG_CONTEXT,
).full
tokenizer = cached_tokenizer_from_config(self.model_config)
repl_token_ids = torch.tensor(
_seq2tokens(tokenizer, video_repl_text), device=device
# Generate video replacement token IDs using get_video_repl
# This tokenizes each frame separator independently, then uses pre-tokenized
# special tokens to ensure consistent tokenization regardless of
# num_tokens_per_frame values.
video_repl = NanoNemotronVLProcessor.get_video_repl(
tokens_per_frame=num_tokens_per_frame,
frames_indices=frames_indices,
frame_duration_ms=frame_duration_ms,
tokenizer=tokenizer,
img_start_token_ids=self._img_start_token_ids,
img_end_token_ids=self._img_end_token_ids,
img_context_token_ids=self._img_context_token_ids,
)
# Get embedding token IDs for image context
embed_token_ids = torch.tensor(
encode_tokens(tokenizer, IMG_CONTEXT), device=device
)
# video_repl.full is a list of token IDs
repl_token_ids = torch.tensor(video_repl.full, device=device)
# Get embedding token IDs for image context (use pre-tokenized version)
embed_token_ids = torch.tensor(self._img_context_token_ids, device=device)
# Create mask for video embedding positions
is_video_embed = torch.isin(repl_token_ids, embed_token_ids)
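The torch.isin mask marks which positions of the replacement sequence correspond to IMG_CONTEXT tokens. Below is a minimal editor's sketch, with toy IDs and shapes, of how such a mask can interleave vision features with the indicator-token text embeddings; the exact scatter used by the model sits outside this hunk.
# Toy example: 99 stands in for the IMG_CONTEXT token id.
import torch
repl_token_ids = torch.tensor([7, 8, 99, 99, 99, 9])
embed_token_ids = torch.tensor([99])
is_video_embed = torch.isin(repl_token_ids, embed_token_ids)
hidden = 4
text_embeds = torch.randn(len(repl_token_ids), hidden)          # indicator-token embeddings
video_embeds = torch.randn(int(is_video_embed.sum()), hidden)   # vision features for the frames
final_embeds = text_embeds.clone()
final_embeds[is_video_embed] = video_embeds  # vision features fill the IMG_CONTEXT slots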
@ -1210,6 +1416,8 @@ class NemotronH_Nano_VL_V2(
pixel_values_flat_video = kwargs.pop("pixel_values_flat_video", None)
video_num_patches = kwargs.pop("video_num_patches", None)
video_embeds = kwargs.pop("video_embeds", None)
frames_indices = kwargs.pop("frames_indices", None)
frame_duration_ms = kwargs.pop("frame_duration_ms", None)
if pixel_values_flat_video is None and video_embeds is None:
return None
@ -1221,13 +1429,22 @@ class NemotronH_Nano_VL_V2(
)
if pixel_values_flat_video is not None:
if torch.is_tensor(frames_indices):
frames_indices = frames_indices.flatten()
else:
frames_indices = torch.cat([f.flatten() for f in frames_indices], dim=0)
frame_duration_ms = frame_duration_ms.flatten()
expected_h = expected_w = self.config.force_image_size
resolve_bindings = {"h": expected_h, "w": expected_w}
num_frames = video_num_patches[0].item()
resolve_bindings = {"h": expected_h, "w": expected_w, "f": num_frames}
return NanoNemotronVLVideoPixelInputs(
type="pixel_values_videos",
pixel_values_flat=pixel_values_flat_video,
num_patches=video_num_patches,
frames_indices=frames_indices,
frame_duration_ms=frame_duration_ms,
resolve_bindings=resolve_bindings,
)


@ -43,32 +43,6 @@ to_4tuple = _ntuple(4)
to_ntuple = _ntuple
class InputConditioner(nn.Module):
def __init__(
self,
input_scale: float,
norm_mean: norm_t,
norm_std: norm_t,
dtype: torch.dtype = None,
):
super().__init__()
self.dtype = dtype
self.register_buffer("norm_mean", _to_tensor(norm_mean) / input_scale)
self.register_buffer("norm_std", _to_tensor(norm_std) / input_scale)
def forward(self, x: torch.Tensor):
y = (x - self.norm_mean) / self.norm_std
if self.dtype is not None:
y = y.to(self.dtype)
return y
def _to_tensor(v: norm_t):
return torch.as_tensor(v, dtype=torch.float32).view(-1, 1, 1)
class ClsToken(nn.Module):
def __init__(
self,
@ -507,11 +481,6 @@ class RadioModel(nn.Module):
super().__init__()
self.config = config
self.input_conditioner = InputConditioner(
input_scale=1.0,
norm_mean=config.norm_mean,
norm_std=config.norm_std,
)
self.model = RadioInternVisionModel(
config=config,
quant_config=quant_config,
@ -525,8 +494,7 @@ class RadioModel(nn.Module):
pixel_values: torch.Tensor | None = None,
pixel_embeds: torch.Tensor | None = None,
) -> torch.FloatTensor:
x = self.input_conditioner(pixel_values)
y = self.model(x)
y = self.model(pixel_values)
return self._extract_final(y)
def load_weights(self, weights) -> set[str]:
@ -548,6 +516,10 @@ class RadioModel(nn.Module):
# Skip buffers not used in vLLM
if sub in {"summary_idxs"}:
continue
if sub.startswith("input_conditioner."):
# normalization is done in the input processor instead,
# using the norm_mean / norm_std values from the config
continue
vllm_key = None
if sub.startswith("model.patch_generator."):


@ -223,7 +223,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
height,
)
height = min(height, overrides.height)
video = np.full((num_frames, width, height, 3), 255)
video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
return [video] * num_videos


@ -13,10 +13,13 @@ import numpy.typing as npt
from PIL import Image
from vllm import envs
from vllm.logger import init_logger
from .base import MediaIO
from .image import ImageMediaIO
logger = init_logger(__name__)
def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray:
num_frames, _, _, channels = frames.shape
@ -103,6 +106,7 @@ class OpenCVVideoBackend(VideoLoader):
cls,
data: bytes,
num_frames: int = -1,
fps: int = -1,
**kwargs,
) -> tuple[npt.NDArray, dict[str, Any]]:
import cv2
@ -116,14 +120,20 @@ class OpenCVVideoBackend(VideoLoader):
original_fps = cap.get(cv2.CAP_PROP_FPS)
duration = total_frames_num / original_fps if original_fps > 0 else 0
# resample video to target num_frames
full_read = num_frames == -1 or total_frames_num < num_frames
if full_read:
num_frames = total_frames_num
frame_idx = list(range(0, num_frames))
# resample video to target num_frames and fps
# - the minimum of the two will be used
num_frames_to_sample = total_frames_num
if num_frames > 0:
num_frames_to_sample = min(num_frames, total_frames_num)
if fps > 0:
num_frames_to_sample = min(num_frames_to_sample, math.floor(duration * fps))
num_frames_to_sample = max(1, num_frames_to_sample) # at least one sample
if num_frames_to_sample == total_frames_num:
frame_idx = list(range(0, num_frames_to_sample))
else:
uniform_sampled_frames = np.linspace(
0, total_frames_num - 1, num_frames, dtype=int
0, total_frames_num - 1, num_frames_to_sample, dtype=int
)
frame_idx = uniform_sampled_frames.tolist()
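For example (illustrative numbers only, mirroring the bound above): a 10-second clip at 30 fps has 300 frames, so requesting num_frames=64 together with fps=2 keeps min(64, floor(10 * 2)) = 20 uniformly spaced frames.
# Standalone sketch of the sampling bound with made-up inputs.
import math
import numpy as np
total_frames_num, original_fps = 300, 30.0
duration = total_frames_num / original_fps
num_frames, fps = 64, 2
num_frames_to_sample = total_frames_num
if num_frames > 0:
    num_frames_to_sample = min(num_frames, total_frames_num)
if fps > 0:
    num_frames_to_sample = min(num_frames_to_sample, math.floor(duration * fps))
num_frames_to_sample = max(1, num_frames_to_sample)
frame_idx = np.linspace(0, total_frames_num - 1, num_frames_to_sample, dtype=int).tolist()
print(num_frames_to_sample, frame_idx[:5])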
@ -132,7 +142,7 @@ class OpenCVVideoBackend(VideoLoader):
frames = np.empty((len(frame_idx), height, width, 3), dtype=np.uint8)
i = 0
for idx in range(total_frames_num):
for idx in range(max(frame_idx) + 1):
ok = cap.grab()
if not ok:
break
@ -142,8 +152,8 @@ class OpenCVVideoBackend(VideoLoader):
frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
i += 1
assert i == num_frames, (
f"Expected reading {num_frames} frames, "
assert i == num_frames_to_sample, (
f"Expected reading {num_frames_to_sample} frames, "
f"but only loaded {i} frames from video."
)
@ -151,14 +161,14 @@ class OpenCVVideoBackend(VideoLoader):
# NOTE(Isotr0py): For models like Qwen3-VL/GLM4.5V, this metadata
# can cause incorrect timestamp calculation without num_frames=-1.
metadata = {
"total_num_frames": num_frames,
"fps": num_frames / duration,
"total_num_frames": total_frames_num,
"fps": original_fps,
"duration": duration,
"video_backend": "opencv",
"frames_indices": list(range(num_frames)),
"frames_indices": list(frame_idx),
# extra field used to control hf processor's video
# sampling behavior
"do_sample_frames": num_frames == total_frames_num,
"do_sample_frames": num_frames_to_sample == total_frames_num,
}
return frames, metadata


@ -1735,20 +1735,32 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pin_memory=self.pin_memory,
merge_by_field_config=model.merge_by_field_config,
):
curr_group_outputs = []
# EVS-related change.
# (ekhvedchenia): Temporary hack to limit peak memory usage when
# processing multimodal data.This solves the issue with scheduler
# processing multimodal data. This solves the issue with scheduler
# putting too many video samples into a single batch. The scheduler
# compares the pruned vision token count against the compute budget,
# which is incorrect (either the input media size or the non-pruned
# output vision token count should be considered).
curr_group_outputs = []
if self.is_multimodal_pruning_enabled and modality == "video":
micro_batch_size = 1
for i in range(0, num_items, micro_batch_size):
micro_batch_mm_inputs = dict(
(k, v[i : i + micro_batch_size])
for k, v in mm_kwargs_group.items()
# TODO(ywang96): Fix memory profiling to take EVS into account and
# remove this hack.
if (
self.is_multimodal_pruning_enabled
and modality == "video"
and num_items > 1
):
for video_mm_kwargs_item in filter(
lambda item: item.modality == "video", mm_kwargs
):
_, _, micro_batch_mm_inputs = next(
group_mm_kwargs_by_modality(
[video_mm_kwargs_item],
device=self.device,
pin_memory=self.pin_memory,
merge_by_field_config=model.merge_by_field_config,
)
)
micro_batch_outputs = model.get_multimodal_embeddings(