[Model] EVS support for nano_nemotron_vl (#26269)

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
Signed-off-by: tomeras91 <57313761+tomeras91@users.noreply.github.com>
Signed-off-by: Eugene Khvedchenia <ekhvedchenia@nvidia.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Eugene Khvedchenia <ekhvedchenia@nvidia.com>
This commit is contained in:
tomeras91 2025-10-06 19:23:37 +03:00 committed by GitHub
parent fc679696f8
commit b8f603cebe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 225 additions and 32 deletions

View File

@ -30,6 +30,7 @@ from vllm.model_executor.models.interfaces import (
IsHybrid,
MultiModalEmbeddings,
SupportsMultiModal,
SupportsMultiModalPruning,
)
from vllm.model_executor.models.internvl import (
calculate_internvl_targets,
@ -44,6 +45,10 @@ from vllm.model_executor.models.utils import (
maybe_prefix,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.evs import (
compute_retained_tokens_count,
compute_retention_mask,
)
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
@ -62,13 +67,20 @@ from vllm.multimodal.processing import (
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
_seq2tokens,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.radio import RadioConfig
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer import (
AnyTokenizer,
cached_tokenizer_from_config,
encode_tokens,
)
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .utils import _merge_multimodal_embeddings
# Configure PIL to handle large images without warnings
# This prevents DecompressionBombWarning for legitimate large images
Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely
@ -382,6 +394,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
video_token: Optional[str] = None,
video_pruning_rate: Optional[float] = None,
) -> None:
super().__init__(
config=config,
@ -392,6 +405,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
)
# add extra video token for video processing
self.video_token = video_token
self.video_pruning_rate = video_pruning_rate
@property
def supports_video(self) -> bool:
@ -446,12 +460,38 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
),
}
for pixel_values in pixel_values_lst_video:
num_patches = pixel_values.shape[0]
image_size: int = self.config.force_image_size
patch_size: int = self.config.patch_size
downsample_ratio = self.config.downsample_ratio
tokens_in_single_frame = int(
(image_size * image_size // patch_size**2) * (downsample_ratio**2)
)
for pixel_values in pixel_values_lst_video:
num_frames = pixel_values.shape[0]
if (
self.video_pruning_rate is not None
and self.video_pruning_rate > 0.0
):
# Start of EVS-specific code
num_tokens = compute_retained_tokens_count(
tokens_per_frame=tokens_in_single_frame,
num_frames=num_frames,
q=self.video_pruning_rate,
)
# Here we just need placeholders that won't actually be replaced -
# we just need to make sure the total number of tokens is correct
# assign all tokens to the first frame
tokens_per_frame = [num_tokens] + [0] * (num_frames - 1)
# End of EVS-specific code
else:
tokens_per_frame = [tokens_in_single_frame] * num_frames
video_repl = self.get_video_repl(tokens_per_frame, self.video_token)
video_repl = self.get_video_repl(
self.num_image_token, num_patches, self.video_token
)
text = [t.replace("<video>", video_repl.full, 1) for t in text]
return text, video_inputs
@ -501,20 +541,40 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
@classmethod
def get_video_repl(
self,
feature_size: int,
num_patches: Optional[int] = None,
cls,
tokens_per_frame: list[int],
video_context_token: str = IMG_CONTEXT,
) -> PromptUpdateDetails[str]:
repl_features = video_context_token * self.num_image_token
repl_features_with_sep = IMG_START + repl_features + IMG_END
# num_patches is equal to num_frames
"""
Build prompt replacement for a video.
The replacement returned is not actually used to replace the placeholder
tokens - it's just used to make sure we allocate the correct number
of tokens.
Actual replacement is done in get_multimodal_embeddings of
NemotronH_Nano_VL_V2
(specifically in _process_video_input -> _create_final_video_embeddings).
There, we create the final embeddings with text embeddings for indicator tokens
and video embeddings for video tokens.
This is a single function that handles all cases - non EVS, EVS dummy, EVS real.
The differentiation is done via tokens_per_frame parameter.
- non EVS case - constant value same value across all frames
- EVS dummy - Doesn't matter how tokens are distributed between frames - just
make sure the total number of tokens is correct.
- EVS real (called from get_real_video_repl_for_evs) - different value per frame
Args:
tokens_per_frame (list[int]): number of tokens per frame
video_context_token (str): the token to use for the video context
"""
repl_full = "".join(
[f"Frame{i + 1}: {repl_features_with_sep}" for i in range(num_patches)]
[
f"Frame{i + 1}: {IMG_START}{video_context_token * num_tokens}{IMG_END}"
for i, num_tokens in enumerate(tokens_per_frame)
]
)
return PromptUpdateDetails.select_text(repl_full, video_context_token)
return PromptUpdateDetails.from_seq(repl_full)
class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
@ -605,6 +665,9 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
def get_video_token(self) -> Optional[str]:
return IMG_CONTEXT
def get_video_pruning_rate(self) -> Optional[float]:
return self.ctx.get_mm_config().video_pruning_rate
def get_num_frames_with_most_features(
self,
seq_len: int,
@ -628,6 +691,7 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
config=self.get_hf_config(),
tokenizer=self.get_tokenizer(),
video_token=self.get_video_token(),
video_pruning_rate=self.get_video_pruning_rate(),
**kwargs,
)
@ -805,8 +869,26 @@ class NanoNemotronVLMultiModalProcessor(
if num_patches is not None:
assert isinstance(num_patches, int)
video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate
if video_pruning_rate is not None and video_pruning_rate > 0.0:
# Start of EVS-specific code
num_tokens = compute_retained_tokens_count(
tokens_per_frame=feature_size,
num_frames=num_patches,
q=video_pruning_rate,
)
# Here we just need placeholders that won't actually be replaced -
# we just need to make sure the total number of tokens is correct
# assign all tokens to the first frame
tokens_per_frame = [num_tokens] + [0] * (num_patches - 1)
# End of EVS-specific code
else:
tokens_per_frame = [feature_size] * num_patches
return hf_processor.get_video_repl(
feature_size, num_patches, video_context_token=hf_processor.video_token
tokens_per_frame,
video_context_token=hf_processor.video_token,
)
if self.info.supports_video:
@ -901,7 +983,9 @@ class NanoNemotronVLDummyInputsBuilder(
info=NanoNemotronVLProcessingInfo,
dummy_inputs=NanoNemotronVLDummyInputsBuilder,
)
class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, SupportsMultiModal):
class NemotronH_Nano_VL_V2(
nn.Module, HasInnerState, IsHybrid, SupportsMultiModal, SupportsMultiModalPruning
):
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
if modality.startswith("image"):
@ -913,7 +997,7 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, SupportsMultiModa
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config
image_size = config.force_image_size
patch_size = config.patch_size
self.patch_size = patch_size
@ -924,7 +1008,7 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, SupportsMultiModa
self.downsample_ratio = config.downsample_ratio
self.ps_version = config.ps_version
self.image_tag_type = config.image_tag_type
self.video_pruning_rate = multimodal_config.video_pruning_rate
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
@ -957,6 +1041,7 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, SupportsMultiModa
self.img_context_token_id = None
self.video_context_token_id = None
self.config = config
self.model_config = vllm_config.model_config
def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
@ -1049,7 +1134,7 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, SupportsMultiModa
def _process_image_input(
self, image_input: NanoNemotronVLImageInputs
) -> torch.Tensor:
) -> tuple[torch.Tensor, ...]:
if image_input["type"] == "image_embeds":
return image_input["data"]
@ -1071,6 +1156,109 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, SupportsMultiModa
]
return image_embeds.split(image_feature_sizes)
def _process_video_input(
self, video_input: NanoNemotronVLVideoPixelInputs
) -> tuple[torch.Tensor, ...]:
"""Process video input and create final embeddings with video content
and indicator tokens."""
# Get video embeddings using the same processing as images
video_embeddings = self._process_image_input(video_input)
final_video_embeddings: tuple[torch.Tensor, ...] = ()
image_rows = image_cols = self.config.force_image_size
downsample_ratio = self.config.downsample_ratio
patch_size = self.config.patch_size
rows = int(image_rows * downsample_ratio // patch_size)
cols = int(image_cols * downsample_ratio // patch_size)
video_pruning_rate = self.video_pruning_rate
# Calculate video feature dimensions (number of frames and
# their feature size (AKA tokens per frame))
# TODO: Maybe this can be optimized to avoid the loop?
for i, single_video_embeddings in enumerate(video_embeddings):
num_frames = video_input["num_patches"][i].item()
assert single_video_embeddings.shape[0] % num_frames == 0
if video_pruning_rate is not None and video_pruning_rate > 0.0:
# Start of EVS-specific code
retention_mask = compute_retention_mask(
single_video_embeddings,
video_size_thw=(num_frames, rows, cols),
spatial_merge_size=1,
q=video_pruning_rate,
)
# apply retention mask
single_video_embeddings = single_video_embeddings[retention_mask]
# calculate the actual number of retained tokens per frame
retention_mask_thw = retention_mask.reshape(num_frames, rows, cols)
num_tokens_per_frame = (
retention_mask_thw.sum(dim=(1, 2)).long().tolist()
)
# End of EVS-specific code
else:
feature_size = single_video_embeddings.shape[0] // num_frames
num_tokens_per_frame = [feature_size] * num_frames
final_video_embeddings += (
self._create_final_video_embeddings(
single_video_embeddings,
num_tokens_per_frame,
),
)
return final_video_embeddings
def _create_final_video_embeddings(
self,
video_embeddings: torch.Tensor,
num_tokens_per_frame: list[int],
) -> torch.Tensor:
"""Create final embeddings that combine video embeddings with
text embeddings of indicator tokens.
These final embeddings contain:
- Actual video embeddings in positions corresponding to video content
- Text embeddings for indicator tokens (<img>, </img>, and
frame separation text) in their respective positions
These embeddings will replace the placeholder embeddings to create
input_embeds for the LLM.
"""
device = video_embeddings.device
# Generate video replacement text and convert to token IDs
video_repl_text = NanoNemotronVLProcessor.get_video_repl(
num_tokens_per_frame,
IMG_CONTEXT,
).full
tokenizer = cached_tokenizer_from_config(self.model_config)
repl_token_ids = torch.tensor(
_seq2tokens(tokenizer, video_repl_text), device=device
)
# Get embedding token IDs for image context
embed_token_ids = torch.tensor(
encode_tokens(tokenizer, IMG_CONTEXT), device=device
)
# Create mask for video embedding positions
is_video_embed = torch.isin(repl_token_ids, embed_token_ids)
# Create final video embeddings, merging text embeddings for indicator
# tokens with video embeddings
text_embeddings = self.get_language_model().get_input_embeddings(repl_token_ids)
final_video_embeddings = _merge_multimodal_embeddings(
inputs_embeds=text_embeddings,
multimodal_embeddings=video_embeddings,
is_multimodal=is_video_embed,
)
return final_video_embeddings
def _parse_and_validate_video_input(
self, **kwargs: object
) -> Optional[NanoNemotronVLVideoPixelInputs]:
@ -1152,7 +1340,7 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, SupportsMultiModa
multimodal_embeddings += vision_embeddings
if modality == "videos":
video_input = modalities["videos"]
video_embeddings = self._process_image_input(video_input)
video_embeddings = self._process_video_input(video_input)
multimodal_embeddings += video_embeddings
return multimodal_embeddings

View File

@ -1017,9 +1017,13 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor):
and video_pruning_rate is not None
and video_pruning_rate > 0.0
):
T, H, W = map(int, grid_thw)
tokens_per_frame = (H // image_processor.merge_size) * (
W // image_processor.merge_size
)
num_tokens = compute_retained_tokens_count(
grid_thw,
image_processor.merge_size,
tokens_per_frame,
T,
video_pruning_rate,
)
# End of EVS-specific code

View File

@ -9,12 +9,13 @@
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import typing
from typing import Union
import torch
def compute_retained_tokens_count(
video_size_thw: torch.LongTensor, spatial_merge_size: int, q: float
tokens_per_frame: int, num_frames: int, q: float
) -> int:
"""
Compute the number of retained tokens for a given video.
@ -22,22 +23,22 @@ def compute_retained_tokens_count(
regardless of the pruning rate.
Args:
video_size_thw: The size of the video in the format of (T, H, W).
spatial_merge_size: The size of the spatial merge.
tokens_per_frame: The number of tokens per frame.
num_frames: The total number of frames.
q: The pruning rate.
Returns:
The number of retained tokens.
"""
T, H, W = map(int, video_size_thw)
min_num_tokens = (H // spatial_merge_size) * (W // spatial_merge_size)
evs_num_tokens = int(T * min_num_tokens * (1 - q))
total_tokens = tokens_per_frame * num_frames
evs_num_tokens = int(total_tokens * (1 - q))
min_num_tokens = tokens_per_frame
return max(min_num_tokens, evs_num_tokens)
def compute_retention_mask(
video_embeds: torch.Tensor,
video_size_thw: torch.LongTensor,
video_size_thw: Union[torch.LongTensor, tuple[int, int, int]],
spatial_merge_size: int,
q: float,
) -> torch.Tensor:
@ -56,7 +57,7 @@ def compute_retention_mask(
`torch.Tensor`: The retention mask for the video embeddings of
`(T * H * W // spatial_merge_size ^ 2)` shape.
"""
T, H, W = video_size_thw
T, H, W = map(int, video_size_thw)
# Use reshape instead of einops to avoid graph breaks
video_embeds = video_embeds.reshape(
@ -65,7 +66,7 @@ def compute_retention_mask(
W // spatial_merge_size,
video_embeds.size(-1),
)
tokens_per_frame = (H // spatial_merge_size) * (W // spatial_merge_size)
# Core EVS
similarity = torch.nn.functional.cosine_similarity(
video_embeds[1:, ...], video_embeds[:-1, ...], dim=-1
@ -80,7 +81,7 @@ def compute_retention_mask(
dissimilarity_flat = dissimilarity.view(-1)
order = torch.argsort(dissimilarity_flat, dim=-1, descending=True, stable=True)
retain_num_tokens = compute_retained_tokens_count(
video_size_thw, spatial_merge_size, q
tokens_per_frame=tokens_per_frame, num_frames=T, q=q
)
topk_indices = order[:retain_num_tokens]