Nemotron Nano V2 VL + EVS Video Support (#27107)
Signed-off-by: Eugene Khvedchenia <ekhvedchenia@nvidia.com>
Signed-off-by: Natan Bagrov <nbagrov@nvidia.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Natan Bagrov <nbagrov@nvidia.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Parent: 1c691f4a71
Commit: e93ff6c8b9
@@ -14,6 +14,7 @@ from collections.abc import Iterable, Mapping, Sequence
 from typing import Annotated, Any, Literal, TypeAlias, TypeVar
 
 import numpy.typing as npt
+import regex as re
 import torch
 import torch.nn as nn
 import torchvision.transforms as T
@@ -21,7 +22,7 @@ from PIL import Image
 from transformers import BatchFeature, PretrainedConfig, TensorType
 
 from vllm.config import VllmConfig
-from vllm.config.multimodal import BaseDummyOptions
+from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
 from vllm.model_executor.layers.activation import ReLUSquaredActivation
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -53,12 +54,14 @@ from vllm.multimodal.inputs import (
     MultiModalFieldConfig,
     MultiModalKwargs,
     MultiModalKwargsItems,
+    VideoItem,
 )
 from vllm.multimodal.parse import (
     ImageEmbeddingItems,
     ImageProcessorItems,
     ImageSize,
     MultiModalDataItems,
+    MultiModalDataParser,
 )
 from vllm.multimodal.processing import (
     BaseMultiModalProcessor,
@@ -91,7 +94,7 @@ IMG_END = "</img>"
 IMG_CONTEXT = "<image>"
 
 # Profiling
-MAX_FRAMES = 16
+# MAX_FRAMES = 16
 DEFAULT_NUM_TILES = 12
 
 
@@ -131,7 +134,8 @@ class NanoNemotronVLVideoPixelInputs(TensorSchema):
     """
     Dimensions:
         - bvf: Batch size * number of videos * num_frames
-        - bn: Batch size * number of images
+        - bn: Batch size * number of videos
+        - f: Number of frames
         - c: Number of channels (3)
         - h: Height of each video frame
         - w: Width of each video frame
@@ -140,6 +144,8 @@ class NanoNemotronVLVideoPixelInputs(TensorSchema):
     type: Literal["pixel_values_videos"]
     pixel_values_flat: Annotated[torch.Tensor, TensorShape("bvf", 3, "h", "w")]
     num_patches: Annotated[torch.Tensor, TensorShape("bn")]
+    frames_indices: Annotated[torch.Tensor, TensorShape("bvf")]
+    frame_duration_ms: Annotated[torch.Tensor, TensorShape("bn")]
 
 
 class NanoNemotronVLVideoEmbeddingInputs(TensorSchema):
@@ -251,6 +257,21 @@ def video_to_pixel_values(
     return torch.stack(frames_tensors)
 
 
+def input_conditioner(x, norm_mean, norm_std):
+    return (x - norm_mean) / norm_std
+
+
+def calculate_timestamps(
+    indices: list[int] | torch.Tensor,
+    frame_duration_ms: int,
+):
+    if not isinstance(indices, list):
+        indices = indices.tolist()
+
+    timestamps = [int(i) * frame_duration_ms / 1000.0 for i in indices]
+    return timestamps
+
+
 class BaseNanoNemotronVLProcessor(ABC):
     """
     This model doesn't define its own HF processor,
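A minimal sketch of how the two helpers above behave; the mean/std values and the fps used here are illustrative, not taken from this commit:

    import torch

    # input_conditioner: per-channel normalization is now done in the processor
    # (the InputConditioner module is removed from the vision tower in a later hunk).
    pixels = torch.rand(2, 3, 512, 512)
    norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)  # illustrative
    norm_std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)   # illustrative
    conditioned = (pixels - norm_mean) / norm_std

    # calculate_timestamps: frame indices -> seconds. With fps = 2 the processor
    # stores frame_duration_ms = int(1000.0 / 2) = 500, so indices [0, 4, 8]
    # map to [0.0, 2.0, 4.0] seconds.
    indices = [0, 4, 8]
    frame_duration_ms = 500
    timestamps = [int(i) * frame_duration_ms / 1000.0 for i in indices]
    assert timestamps == [0.0, 2.0, 4.0]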
@@ -344,17 +365,30 @@ class BaseNanoNemotronVLProcessor(ABC):
         else:
             pixel_values_lst = self._images_to_pixel_values_lst(images, max_num_tiles)
             image_inputs = {
-                "pixel_values_flat": torch.cat(pixel_values_lst),
+                "pixel_values_flat": input_conditioner(
+                    torch.cat(pixel_values_lst), self.norm_mean, self.norm_std
+                ),
                 "image_num_patches": torch.tensor(
                     [len(item) for item in pixel_values_lst]
                 ),
             }
 
-            for pixel_values in pixel_values_lst:
+            assert len(text) == 1, (
+                "hf_processor is called on the output of get_dummy_text, "
+                "which should be a single string"
+            )
+            parts = [x for x in re.split(r"(<image>)", text[0]) if x]
+            assert parts.count("<image>") == len(pixel_values_lst), (
+                "the number of <image> tokens in the text should be the "
+                "same as the number of images"
+            )
+
+            for i, pixel_values in enumerate(pixel_values_lst):
                 num_patches = pixel_values.shape[0]
                 feature_size = num_patches * self.num_image_token
                 image_repl = self.get_image_repl(feature_size, num_patches)
-                text = [t.replace("<image>", image_repl.full, 1) for t in text]
+                parts[i] = parts[i].replace("<image>", image_repl.full)
+            text = ["".join(parts)]
         return text, image_inputs
 
     def _make_batch_input(self, input_item: Any | list[Any] | None = None):
@@ -421,6 +455,18 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
         self.video_token = video_token
         self.video_pruning_rate = video_pruning_rate
 
+        # Pre-tokenize special tokens for video processing
+        # to avoid repeated tokenization
+        self._img_start_token_ids = encode_tokens(
+            tokenizer, IMG_START, add_special_tokens=False
+        )
+        self._img_end_token_ids = encode_tokens(
+            tokenizer, IMG_END, add_special_tokens=False
+        )
+        self._img_context_token_ids = encode_tokens(
+            tokenizer, IMG_CONTEXT, add_special_tokens=False
+        )
+
     @property
     def supports_video(self) -> bool:
         return self.video_token_id is not None
@@ -454,24 +500,43 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
     def _preprocess_video(
         self,
         text: list[str],
-        videos: list[npt.NDArray],
+        videos: list[tuple[npt.NDArray, dict[str, Any]]],
         max_num_tiles: int,
         dynamic_image_size: bool | None = None,
     ):
         if len(videos) == 0 or not self.supports_video:
             video_inputs = {}
         else:
+            videos_lst = [v[0] for v in videos]
+            video_metadata_lst = [v[1] for v in videos]
             pixel_values_lst_video = self._videos_to_pixel_values_lst(
-                videos,
+                videos_lst,
                 max_num_tiles=max_num_tiles,
                 dynamic_image_size=dynamic_image_size,
            )
 
+            # We use the frame duration in milliseconds (as an integer) to
+            # ensure consistent timestamp calculation. At preprocessing the
+            # fps parameter is given in fp32, while at inference it is bf16,
+            # which leads to inaccurate timestamp calculation and causes
+            # timestamp values to differ. In rare cases this causes a
+            # mismatching number of output tokens for tokenized frame prefixes.
+            frame_duration_ms_lst = [
+                int(1000.0 / metadata["fps"]) for metadata in video_metadata_lst
+            ]
+            frames_indices_lst = [
+                metadata["frames_indices"] for metadata in video_metadata_lst
+            ]
+
             video_inputs = {
                 "pixel_values_flat_video": input_conditioner(
                     torch.cat(pixel_values_lst_video), self.norm_mean, self.norm_std
                 ),
                 "video_num_patches": torch.tensor(
                     [len(item) for item in pixel_values_lst_video]
                 ),
+                "frames_indices": frames_indices_lst,
+                "frame_duration_ms": torch.tensor(frame_duration_ms_lst),
             }
 
         image_size: int = self.config.force_image_size
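The integer frame duration guards against the dtype mismatch described in the comment above (fp32 fps at preprocessing vs bf16 fps at inference); a small sketch with an illustrative fps value:

    import torch

    fps = 29.97
    fps_bf16 = torch.tensor(fps, dtype=torch.bfloat16).item()  # rounds to 30.0

    frame_index = 150
    t_fp32 = frame_index / fps        # ~5.005 -> formats as "5.01"
    t_bf16 = frame_index / fps_bf16   # 5.000  -> formats as "5.00"

    # Different separator strings can tokenize to different token counts, so the
    # processor fixes frame_duration_ms once and both paths reuse it.
    frame_duration_ms = int(1000.0 / fps)                     # 33 on both paths
    t_consistent = frame_index * frame_duration_ms / 1000.0   # 4.95 everywhere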
@@ -481,7 +546,12 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
             (image_size * image_size // patch_size**2) * (downsample_ratio**2)
         )
 
-        for pixel_values in pixel_values_lst_video:
+        for pixel_values, video_metadata, frames_indices, frame_duration_ms in zip(
+            pixel_values_lst_video,
+            video_metadata_lst,
+            frames_indices_lst,
+            frame_duration_ms_lst,
+        ):
             num_frames = pixel_values.shape[0]
 
             if (
@@ -504,16 +574,29 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
             else:
                 tokens_per_frame = [tokens_in_single_frame] * num_frames
 
-            video_repl = self.get_video_repl(tokens_per_frame, self.video_token)
+            video_repl = self.get_video_repl(
+                tokens_per_frame=tokens_per_frame,
+                frames_indices=frames_indices,
+                frame_duration_ms=frame_duration_ms,
+                tokenizer=self.tokenizer,
+                img_start_token_ids=self._img_start_token_ids,
+                img_end_token_ids=self._img_end_token_ids,
+                img_context_token_ids=self._img_context_token_ids,
+            )
 
-            text = [t.replace("<video>", video_repl.full, 1) for t in text]
+            # video_repl.full is a list of token IDs
+            # Convert token IDs back to text for the HF processor flow
+            video_repl_text = self.tokenizer.decode(
+                video_repl.full, skip_special_tokens=False
+            )
+            text = [t.replace("<video>", video_repl_text, 1) for t in text]
         return text, video_inputs
 
     def __call__(
         self,
         text: str | list[str] | None = None,
         images: Image.Image | list[Image.Image] | None = None,
-        videos: npt.NDArray | list[npt.NDArray] | None = None,
+        videos: list[tuple[npt.NDArray, dict[str, Any]]] | None = None,
         return_tensors: str | TensorType | None = None,
         max_num_tiles: int | None = None,
         dynamic_image_size: bool | None = None,
@@ -558,9 +641,15 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
     @classmethod
     def get_video_repl(
         cls,
+        *,
         tokens_per_frame: list[int],
-        video_context_token: str = IMG_CONTEXT,
-    ) -> PromptUpdateDetails[str]:
+        frames_indices: list[int],
+        frame_duration_ms: int,
+        tokenizer: AnyTokenizer,
+        img_start_token_ids: list[int],
+        img_end_token_ids: list[int],
+        img_context_token_ids: list[int],
+    ) -> PromptUpdateDetails[list[int]]:
         """
         Build prompt replacement for a video.
         The replacement returned is not actually used to replace the placeholder
@@ -579,16 +668,52 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
         - EVS real (called from get_real_video_repl_for_evs) - different value per frame
         Args:
             tokens_per_frame (list[int]): number of tokens per frame
-            video_context_token (str): the token to use for the video context
+            frames_indices (list[int]): frame indices
+            frame_duration_ms (int): duration of each frame in milliseconds
+            tokenizer (AnyTokenizer): tokenizer to use for tokenizing frame separators
+            img_start_token_ids (list[int]): pre-tokenized IMG_START tokens
+            img_end_token_ids (list[int]): pre-tokenized IMG_END tokens
+            img_context_token_ids (list[int]): pre-tokenized IMG_CONTEXT tokens
         """
-        repl_full = "".join(
-            [
-                f"Frame{i + 1}: {IMG_START}{video_context_token * num_tokens}{IMG_END}"
-                for i, num_tokens in enumerate(tokens_per_frame)
-            ]
-        )
-
-        return PromptUpdateDetails.from_seq(repl_full)
+        # TODO: Add support for frame_duration_ms being None.
+        # At the preprocessing step we should allow absent metadata /
+        # metadata without the frames_indices field.
+        timestamps_enabled = frame_duration_ms is not None
+
+        if timestamps_enabled:
+            timestamps = calculate_timestamps(frames_indices, frame_duration_ms)
+
+            assert len(timestamps) == len(tokens_per_frame), (
+                "timestamps and tokens_per_frame must have the same length"
+            )
+            frame_separators = [
+                f"Frame {i + 1} sampled at {timestamp:.2f} seconds: "
+                for i, timestamp in enumerate(timestamps)
+            ]
+        else:
+            frame_separators = [
+                f"Frame {i + 1}: " for i, _ in enumerate(tokens_per_frame)
+            ]
+
+        # Tokenize each frame separator independently
+        frame_separators_tokenized = [
+            _seq2tokens(tokenizer, sep) for sep in frame_separators
+        ]
+
+        # Tokenize each component independently to avoid the tokenizer merging
+        # tokens across boundaries. This ensures consistent tokenization
+        # regardless of num_tokens_per_frame values.
+        all_token_ids = []
+        for i, num_tokens in enumerate(tokens_per_frame):
+            frame_sep_token_ids = frame_separators_tokenized[i]
+            all_token_ids.extend(frame_sep_token_ids)
+
+            # Add pre-tokenized special tokens
+            all_token_ids.extend(img_start_token_ids)
+            all_token_ids.extend(img_context_token_ids * num_tokens)
+            all_token_ids.extend(img_end_token_ids)
+
+        return PromptUpdateDetails.from_seq(all_token_ids)
 
 
 class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
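A string-level sketch of the replacement that get_video_repl assembles for one video; the real method works directly on token IDs, and the IMG_START literal used here is assumed:

    IMG_START, IMG_END, IMG_CONTEXT = "<img>", "</img>", "<image>"  # IMG_START assumed

    def video_repl_text_sketch(tokens_per_frame, frames_indices, frame_duration_ms):
        # Per frame: separator text, then IMG_START, num_tokens context tokens, IMG_END.
        timestamps = [i * frame_duration_ms / 1000.0 for i in frames_indices]
        parts = []
        for i, (num_tokens, ts) in enumerate(zip(tokens_per_frame, timestamps)):
            parts.append(f"Frame {i + 1} sampled at {ts:.2f} seconds: ")
            parts.append(IMG_START + IMG_CONTEXT * num_tokens + IMG_END)
        return "".join(parts)

    # Two frames with 2 context tokens each, sampled at indices 0 and 4 of a 2 fps clip:
    print(video_repl_text_sketch([2, 2], [0, 4], 500))
    # Frame 1 sampled at 0.00 seconds: <img><image><image></img>Frame 2 sampled at 2.00 seconds: ...

Tokenizing each separator on its own, and splicing in pre-tokenized special tokens, keeps the boundary between separator and markers stable no matter how many context tokens EVS leaves per frame.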
@@ -695,8 +820,6 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
         max_image_tokens = self.get_max_image_tokens() * max_images
         max_total_frames = (seq_len - max_image_tokens) // processor.num_image_token
         max_frames_per_video = max_total_frames // max(max_videos, 1)
-
-        max_frames_per_video = min(max_frames_per_video, MAX_FRAMES)
         return max(max_frames_per_video, 1)
 
     def get_hf_processor(self, **kwargs: object) -> NanoNemotronVLProcessor:
@@ -791,6 +914,9 @@ class NanoNemotronVLMultiModalProcessor(
 ):
     """MultiModalProcessor extended for video support"""
 
+    def _get_data_parser(self) -> MultiModalDataParser:
+        return MultiModalDataParser(video_needs_metadata=True)
+
     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
@@ -805,6 +931,8 @@ class NanoNemotronVLMultiModalProcessor(
                     "video", video_num_patches
                 ),
                 video_num_patches=MultiModalFieldConfig.batched("video"),
+                frames_indices=MultiModalFieldConfig.batched("video"),
+                frame_duration_ms=MultiModalFieldConfig.batched("video"),
             )
         else:
             video_fields = {}
@@ -835,6 +963,7 @@ class NanoNemotronVLMultiModalProcessor(
 
         def get_video_replacement_internvl(item_idx: int):
             feature_size = hf_processor.num_image_token
+            video, metadata = mm_items["video"][item_idx]
             num_patches = video_num_patches[item_idx]
             if num_patches is not None:
                 assert isinstance(num_patches, int)
@@ -856,9 +985,15 @@ class NanoNemotronVLMultiModalProcessor(
             else:
                 tokens_per_frame = [feature_size] * num_patches
 
+            frame_duration_ms = int(1000 / metadata["fps"])
             return hf_processor.get_video_repl(
-                tokens_per_frame,
-                video_context_token=hf_processor.video_token,
+                tokens_per_frame=tokens_per_frame,
+                frames_indices=metadata["frames_indices"],
+                frame_duration_ms=frame_duration_ms,
+                tokenizer=hf_processor.tokenizer,
+                img_start_token_ids=hf_processor._img_start_token_ids,
+                img_end_token_ids=hf_processor._img_end_token_ids,
+                img_context_token_ids=hf_processor._img_context_token_ids,
             )
 
         if self.info.supports_video:
@@ -917,6 +1052,37 @@ class NanoNemotronVLDummyInputsBuilder(
 
         return super().get_dummy_text(mm_counts) + "<video>" * num_videos
 
+    def _get_dummy_videos(
+        self,
+        *,
+        width: int,
+        height: int,
+        num_frames: int,
+        num_videos: int,
+        overrides: VideoDummyOptions | None = None,
+    ) -> list[VideoItem]:
+        video = super()._get_dummy_videos(
+            width=width,
+            height=height,
+            num_frames=num_frames,
+            num_videos=1,
+            overrides=overrides,
+        )[0]
+        video_items = []
+        for _ in range(num_videos):
+            video_metadata = {
+                "total_num_frames": num_frames,
+                "fps": 2,
+                "duration": num_frames / 2.0,
+                "video_backend": "opencv_dynamic",
+                "frames_indices": [i for i in range(num_frames)],
+                "do_sample_frames": False,
+            }
+            video_item = (video.copy(), video_metadata)
+            video_items.append(video_item)
+
+        return video_items
+
     def get_dummy_mm_data(
         self,
         seq_len: int,
@@ -1013,6 +1179,19 @@ class NemotronH_Nano_VL_V2(
         self.config = config
         self.model_config = vllm_config.model_config
 
+        # Pre-tokenize special tokens for video processing
+        # to avoid repeated tokenization
+        tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
+        self._img_start_token_ids = encode_tokens(
+            tokenizer, IMG_START, add_special_tokens=False
+        )
+        self._img_end_token_ids = encode_tokens(
+            tokenizer, IMG_END, add_special_tokens=False
+        )
+        self._img_context_token_ids = encode_tokens(
+            tokenizer, IMG_CONTEXT, add_special_tokens=False
+        )
+
     def pixel_shuffle(self, x, scale_factor=0.5):
         n, w, h, c = x.size()
         # N, W, H, C --> N, W, H * scale, C // scale
@@ -1043,13 +1222,28 @@ class NemotronH_Nano_VL_V2(
         return x
 
     def extract_feature(self, pixel_values):
-        vit_embeds = self.vision_model(pixel_values)
-        vit_embeds = vit_embeds.to(dtype=torch.bfloat16)
-        h = w = int(vit_embeds.shape[1] ** 0.5)
-        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
-        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
-        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
-        vit_embeds = self.mlp1(vit_embeds)
+        # Process images in a micro-batch of at most 128 frames per call.
+        # This is done on purpose to ensure that the peak GPU RAM usage of a
+        # huge batch (namely for really long videos with EVS on) won't cause
+        # any problems, as we don't support chunked prefill for video media.
+        micro_batch_size = 128
+        n = pixel_values.shape[0]
+        vit_embeds_list = []
+        for i in range(0, n, micro_batch_size):
+            vit_embeds = self.vision_model(pixel_values[i : i + micro_batch_size])
+            vit_embeds = vit_embeds.to(dtype=torch.bfloat16)
+            h = w = int(vit_embeds.shape[1] ** 0.5)
+            vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
+            vit_embeds = self.pixel_shuffle(
+                vit_embeds, scale_factor=self.downsample_ratio
+            )
+            vit_embeds = vit_embeds.reshape(
+                vit_embeds.shape[0], -1, vit_embeds.shape[-1]
+            )
+            vit_embeds = self.mlp1(vit_embeds)
+            vit_embeds_list.append(vit_embeds)
+
+        vit_embeds = torch.cat(vit_embeds_list, dim=0)
         return vit_embeds
 
     def _parse_and_validate_image_input(
@@ -1117,12 +1311,15 @@ class NemotronH_Nano_VL_V2(
         rows = int(image_rows * downsample_ratio // patch_size)
         cols = int(image_cols * downsample_ratio // patch_size)
         video_pruning_rate = self.video_pruning_rate
+        video_num_frames = video_input["num_patches"].tolist()
+        video_frames_indices = video_input["frames_indices"].split(video_num_frames)
         # Calculate video feature dimensions (number of frames and
         # their feature size (AKA tokens per frame))
         # TODO: Maybe this can be optimized to avoid the loop?
         for i, single_video_embeddings in enumerate(video_embeddings):
-            num_frames = video_input["num_patches"][i].item()
+            num_frames = video_num_frames[i]
+            frames_indices = video_frames_indices[i].tolist()
+            frame_duration_ms = video_input["frame_duration_ms"][i].item()
             assert single_video_embeddings.shape[0] % num_frames == 0
 
             if video_pruning_rate is not None and video_pruning_rate > 0.0:
@@ -1151,6 +1348,8 @@ class NemotronH_Nano_VL_V2(
                 self._create_final_video_embeddings(
                     single_video_embeddings,
                     num_tokens_per_frame,
+                    frames_indices,
+                    frame_duration_ms,
                 ),
             )
 
@@ -1160,6 +1359,8 @@ class NemotronH_Nano_VL_V2(
         self,
         video_embeddings: torch.Tensor,
         num_tokens_per_frame: list[int],
+        frames_indices: list[int],
+        frame_duration_ms: int,
     ) -> torch.Tensor:
         """Create final embeddings that combine video embeddings with
         text embeddings of indicator tokens.
@@ -1173,22 +1374,27 @@ class NemotronH_Nano_VL_V2(
         input_embeds for the LLM.
         """
         device = video_embeddings.device
 
-        # Generate video replacement text and convert to token IDs
-        video_repl_text = NanoNemotronVLProcessor.get_video_repl(
-            num_tokens_per_frame,
-            IMG_CONTEXT,
-        ).full
-
         tokenizer = cached_tokenizer_from_config(self.model_config)
-        repl_token_ids = torch.tensor(
-            _seq2tokens(tokenizer, video_repl_text), device=device
+        # Generate video replacement token IDs using get_video_repl.
+        # This tokenizes each frame separator independently, then uses pre-tokenized
+        # special tokens to ensure consistent tokenization regardless of
+        # num_tokens_per_frame values.
+        video_repl = NanoNemotronVLProcessor.get_video_repl(
+            tokens_per_frame=num_tokens_per_frame,
+            frames_indices=frames_indices,
+            frame_duration_ms=frame_duration_ms,
+            tokenizer=tokenizer,
+            img_start_token_ids=self._img_start_token_ids,
+            img_end_token_ids=self._img_end_token_ids,
+            img_context_token_ids=self._img_context_token_ids,
         )
 
-        # Get embedding token IDs for image context
-        embed_token_ids = torch.tensor(
-            encode_tokens(tokenizer, IMG_CONTEXT), device=device
-        )
+        # video_repl.full is a list of token IDs
+        repl_token_ids = torch.tensor(video_repl.full, device=device)
+
+        # Get embedding token IDs for image context (use pre-tokenized version)
+        embed_token_ids = torch.tensor(self._img_context_token_ids, device=device)
 
         # Create mask for video embedding positions
         is_video_embed = torch.isin(repl_token_ids, embed_token_ids)
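The torch.isin mask above marks which positions of the replacement sequence receive vision features; a toy, self-contained sketch of that assembly (the final scatter is a plausible continuation, not code from this commit):

    import torch

    repl_token_ids = torch.tensor([7, 7, 1, 9, 9, 9, 2])   # toy separator / IMG_START / IMG_CONTEXT / IMG_END IDs
    embed_token_ids = torch.tensor([9])                     # stands in for the IMG_CONTEXT token IDs
    is_video_embed = torch.isin(repl_token_ids, embed_token_ids)

    text_embeds = torch.zeros(repl_token_ids.numel(), 4)       # embeddings of the replacement tokens
    video_embeds = torch.ones(int(is_video_embed.sum()), 4)    # flattened vision features

    final_embeds = text_embeds.clone()
    final_embeds[is_video_embed] = video_embeds                # vision features fill the masked slots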
@@ -1210,6 +1416,8 @@ class NemotronH_Nano_VL_V2(
         pixel_values_flat_video = kwargs.pop("pixel_values_flat_video", None)
         video_num_patches = kwargs.pop("video_num_patches", None)
         video_embeds = kwargs.pop("video_embeds", None)
+        frames_indices = kwargs.pop("frames_indices", None)
+        frame_duration_ms = kwargs.pop("frame_duration_ms", None)
 
         if pixel_values_flat_video is None and video_embeds is None:
             return None
@@ -1221,13 +1429,22 @@ class NemotronH_Nano_VL_V2(
         )
 
         if pixel_values_flat_video is not None:
+            if torch.is_tensor(frames_indices):
+                frames_indices = frames_indices.flatten()
+            else:
+                frames_indices = torch.cat([f.flatten() for f in frames_indices], dim=0)
+
+            frame_duration_ms = frame_duration_ms.flatten()
             expected_h = expected_w = self.config.force_image_size
-            resolve_bindings = {"h": expected_h, "w": expected_w}
+            num_frames = video_num_patches[0].item()
+            resolve_bindings = {"h": expected_h, "w": expected_w, "f": num_frames}
 
             return NanoNemotronVLVideoPixelInputs(
                 type="pixel_values_videos",
                 pixel_values_flat=pixel_values_flat_video,
                 num_patches=video_num_patches,
+                frames_indices=frames_indices,
+                frame_duration_ms=frame_duration_ms,
                 resolve_bindings=resolve_bindings,
             )
@@ -43,32 +43,6 @@ to_4tuple = _ntuple(4)
 to_ntuple = _ntuple
 
 
-class InputConditioner(nn.Module):
-    def __init__(
-        self,
-        input_scale: float,
-        norm_mean: norm_t,
-        norm_std: norm_t,
-        dtype: torch.dtype = None,
-    ):
-        super().__init__()
-
-        self.dtype = dtype
-
-        self.register_buffer("norm_mean", _to_tensor(norm_mean) / input_scale)
-        self.register_buffer("norm_std", _to_tensor(norm_std) / input_scale)
-
-    def forward(self, x: torch.Tensor):
-        y = (x - self.norm_mean) / self.norm_std
-        if self.dtype is not None:
-            y = y.to(self.dtype)
-        return y
-
-
-def _to_tensor(v: norm_t):
-    return torch.as_tensor(v, dtype=torch.float32).view(-1, 1, 1)
-
-
 class ClsToken(nn.Module):
     def __init__(
         self,
@@ -507,11 +481,6 @@ class RadioModel(nn.Module):
         super().__init__()
 
         self.config = config
-        self.input_conditioner = InputConditioner(
-            input_scale=1.0,
-            norm_mean=config.norm_mean,
-            norm_std=config.norm_std,
-        )
         self.model = RadioInternVisionModel(
             config=config,
             quant_config=quant_config,
@@ -525,8 +494,7 @@ class RadioModel(nn.Module):
         pixel_values: torch.Tensor | None = None,
         pixel_embeds: torch.Tensor | None = None,
     ) -> torch.FloatTensor:
-        x = self.input_conditioner(pixel_values)
-        y = self.model(x)
+        y = self.model(pixel_values)
         return self._extract_final(y)
 
     def load_weights(self, weights) -> set[str]:
@@ -548,6 +516,10 @@ class RadioModel(nn.Module):
             # Skip buffers not used in vLLM
             if sub in {"summary_idxs"}:
                 continue
+            if sub.startswith("input_conditioner."):
+                # We normalize in the input processor, based on the norm
+                # mean and std values from the config.
+                continue
 
             vllm_key = None
             if sub.startswith("model.patch_generator."):
@@ -223,7 +223,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
                 height,
             )
             height = min(height, overrides.height)
-        video = np.full((num_frames, width, height, 3), 255)
+        video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
         return [video] * num_videos
 
 
@@ -13,10 +13,13 @@ import numpy.typing as npt
 from PIL import Image
 
 from vllm import envs
+from vllm.logger import init_logger
 
 from .base import MediaIO
 from .image import ImageMediaIO
 
+logger = init_logger(__name__)
+
 
 def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray:
     num_frames, _, _, channels = frames.shape
@@ -103,6 +106,7 @@ class OpenCVVideoBackend(VideoLoader):
         cls,
         data: bytes,
         num_frames: int = -1,
+        fps: int = -1,
         **kwargs,
     ) -> tuple[npt.NDArray, dict[str, Any]]:
         import cv2
@@ -116,14 +120,20 @@ class OpenCVVideoBackend(VideoLoader):
         original_fps = cap.get(cv2.CAP_PROP_FPS)
         duration = total_frames_num / original_fps if original_fps > 0 else 0
 
-        # resample video to target num_frames
-        full_read = num_frames == -1 or total_frames_num < num_frames
-        if full_read:
-            num_frames = total_frames_num
-            frame_idx = list(range(0, num_frames))
+        # resample video to target num_frames and fps
+        # - the minimum of the two will be used
+        num_frames_to_sample = total_frames_num
+        if num_frames > 0:
+            num_frames_to_sample = min(num_frames, total_frames_num)
+        if fps > 0:
+            num_frames_to_sample = min(num_frames_to_sample, math.floor(duration * fps))
+        num_frames_to_sample = max(1, num_frames_to_sample)  # at least one sample
+
+        if num_frames_to_sample == total_frames_num:
+            frame_idx = list(range(0, num_frames_to_sample))
         else:
             uniform_sampled_frames = np.linspace(
-                0, total_frames_num - 1, num_frames, dtype=int
+                0, total_frames_num - 1, num_frames_to_sample, dtype=int
            )
             frame_idx = uniform_sampled_frames.tolist()
 
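Worked numbers for the resampling rule above, with illustrative inputs: a 300-frame clip at 30 fps (10 s) requested with num_frames=64 and fps=2, where the tighter fps cap wins.

    import math

    total_frames_num, original_fps = 300, 30.0
    duration = total_frames_num / original_fps                 # 10.0 seconds
    num_frames, fps = 64, 2                                    # caller-requested caps

    num_frames_to_sample = total_frames_num                    # 300
    if num_frames > 0:
        num_frames_to_sample = min(num_frames, total_frames_num)                      # 64
    if fps > 0:
        num_frames_to_sample = min(num_frames_to_sample, math.floor(duration * fps))  # 20
    num_frames_to_sample = max(1, num_frames_to_sample)        # 20 frames, sampled uniformly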
@@ -132,7 +142,7 @@ class OpenCVVideoBackend(VideoLoader):
         frames = np.empty((len(frame_idx), height, width, 3), dtype=np.uint8)
 
         i = 0
-        for idx in range(total_frames_num):
+        for idx in range(max(frame_idx) + 1):
             ok = cap.grab()
             if not ok:
                 break
@@ -142,8 +152,8 @@ class OpenCVVideoBackend(VideoLoader):
                 frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                 i += 1
 
-        assert i == num_frames, (
-            f"Expected reading {num_frames} frames, "
+        assert i == num_frames_to_sample, (
+            f"Expected reading {num_frames_to_sample} frames, "
             f"but only loaded {i} frames from video."
         )
 
@@ -151,14 +161,14 @@ class OpenCVVideoBackend(VideoLoader):
         # NOTE(Isotr0py): For models like Qwen3-VL/GLM4.5V, this metadata
         # can cause incorrect timestamp calculation without num_frames=-1.
         metadata = {
-            "total_num_frames": num_frames,
-            "fps": num_frames / duration,
+            "total_num_frames": total_frames_num,
+            "fps": original_fps,
             "duration": duration,
             "video_backend": "opencv",
-            "frames_indices": list(range(num_frames)),
+            "frames_indices": list(frame_idx),
             # extra field used to control hf processor's video
             # sampling behavior
-            "do_sample_frames": num_frames == total_frames_num,
+            "do_sample_frames": num_frames_to_sample == total_frames_num,
         }
 
         return frames, metadata
@@ -1735,20 +1735,32 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             pin_memory=self.pin_memory,
             merge_by_field_config=model.merge_by_field_config,
         ):
+            curr_group_outputs = []
+
+            # EVS-related change.
             # (ekhvedchenia): Temporary hack to limit peak memory usage when
-            # processing multimodal data.This solves the issue with scheduler
+            # processing multimodal data. This solves the issue with scheduler
             # putting too many video samples into a single batch. Scheduler
             # uses pruned vision tokens count to compare it versus compute
             # budget which is incorrect (Either input media size or non-pruned
             # output vision tokens count should be considered)
-            curr_group_outputs = []
-            if self.is_multimodal_pruning_enabled and modality == "video":
-                micro_batch_size = 1
-                for i in range(0, num_items, micro_batch_size):
-                    micro_batch_mm_inputs = dict(
-                        (k, v[i : i + micro_batch_size])
-                        for k, v in mm_kwargs_group.items()
+            # TODO(ywang96): Fix memory profiling to take EVS into account and
+            # remove this hack.
+            if (
+                self.is_multimodal_pruning_enabled
+                and modality == "video"
+                and num_items > 1
+            ):
+                for video_mm_kwargs_item in filter(
+                    lambda item: item.modality == "video", mm_kwargs
+                ):
+                    _, _, micro_batch_mm_inputs = next(
+                        group_mm_kwargs_by_modality(
+                            [video_mm_kwargs_item],
+                            device=self.device,
+                            pin_memory=self.pin_memory,
+                            merge_by_field_config=model.merge_by_field_config,
+                        )
                     )
 
                     micro_batch_outputs = model.get_multimodal_embeddings(