[Model] EVS support for nano_nemotron_vl (#26269)
Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
Signed-off-by: tomeras91 <57313761+tomeras91@users.noreply.github.com>
Signed-off-by: Eugene Khvedchenia <ekhvedchenia@nvidia.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Eugene Khvedchenia <ekhvedchenia@nvidia.com>
commit b8f603cebe
parent fc679696f8
@@ -30,6 +30,7 @@ from vllm.model_executor.models.interfaces import (
     IsHybrid,
     MultiModalEmbeddings,
     SupportsMultiModal,
+    SupportsMultiModalPruning,
 )
 from vllm.model_executor.models.internvl import (
     calculate_internvl_targets,
@@ -44,6 +45,10 @@ from vllm.model_executor.models.utils import (
     maybe_prefix,
 )
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.evs import (
+    compute_retained_tokens_count,
+    compute_retention_mask,
+)
 from vllm.multimodal.inputs import (
     MultiModalDataDict,
     MultiModalFieldConfig,
@@ -62,13 +67,20 @@ from vllm.multimodal.processing import (
     PromptReplacement,
     PromptUpdate,
     PromptUpdateDetails,
+    _seq2tokens,
 )
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.radio import RadioConfig
-from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.transformers_utils.tokenizer import (
+    AnyTokenizer,
+    cached_tokenizer_from_config,
+    encode_tokens,
+)
 from vllm.utils.tensor_schema import TensorSchema, TensorShape

+from .utils import _merge_multimodal_embeddings
+
 # Configure PIL to handle large images without warnings
 # This prevents DecompressionBombWarning for legitimate large images
 Image.MAX_IMAGE_PIXELS = None  # Disable the limit entirely
@@ -382,6 +394,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
         max_dynamic_patch: Optional[int] = None,
         dynamic_image_size: Optional[bool] = None,
         video_token: Optional[str] = None,
+        video_pruning_rate: Optional[float] = None,
     ) -> None:
         super().__init__(
             config=config,
@@ -392,6 +405,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
         )
         # add extra video token for video processing
         self.video_token = video_token
+        self.video_pruning_rate = video_pruning_rate

     @property
     def supports_video(self) -> bool:
@@ -446,12 +460,38 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
             ),
         }

-        for pixel_values in pixel_values_lst_video:
-            num_patches = pixel_values.shape[0]
+        image_size: int = self.config.force_image_size
+        patch_size: int = self.config.patch_size
+        downsample_ratio = self.config.downsample_ratio
+        tokens_in_single_frame = int(
+            (image_size * image_size // patch_size**2) * (downsample_ratio**2)
+        )

-            video_repl = self.get_video_repl(
-                self.num_image_token, num_patches, self.video_token
-            )
+        for pixel_values in pixel_values_lst_video:
+            num_frames = pixel_values.shape[0]
+
+            if (
+                self.video_pruning_rate is not None
+                and self.video_pruning_rate > 0.0
+            ):
+                # Start of EVS-specific code
+                num_tokens = compute_retained_tokens_count(
+                    tokens_per_frame=tokens_in_single_frame,
+                    num_frames=num_frames,
+                    q=self.video_pruning_rate,
+                )
+
+                # Here we just need placeholders that won't actually be replaced -
+                # we just need to make sure the total number of tokens is correct
+                # assign all tokens to the first frame
+                tokens_per_frame = [num_tokens] + [0] * (num_frames - 1)
+
+                # End of EVS-specific code
+            else:
+                tokens_per_frame = [tokens_in_single_frame] * num_frames
+
+            video_repl = self.get_video_repl(tokens_per_frame, self.video_token)
+
             text = [t.replace("<video>", video_repl.full, 1) for t in text]
         return text, video_inputs

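The placeholder arithmetic above is easy to sanity-check in isolation. A minimal standalone sketch follows; the config numbers are illustrative assumptions rather than values from a real checkpoint, and the retained-token formula is restated inline instead of being imported from vllm.multimodal.evs:

# Illustrative config values (assumptions, not taken from a real checkpoint).
image_size, patch_size, downsample_ratio = 512, 16, 0.5
tokens_in_single_frame = int(
    (image_size * image_size // patch_size**2) * (downsample_ratio**2)
)
print(tokens_in_single_frame)  # 256 tokens per frame after downsampling

# EVS keeps a (1 - q) fraction of all video tokens, clamped to at least one
# frame's worth - the same rule as compute_retained_tokens_count.
num_frames, q = 8, 0.75
num_tokens = max(
    tokens_in_single_frame, int(tokens_in_single_frame * num_frames * (1 - q))
)
print(num_tokens)  # 512

# Placeholder allocation: only the total has to be right at this stage,
# so every retained token is assigned to the first frame.
tokens_per_frame = [num_tokens] + [0] * (num_frames - 1)
assert sum(tokens_per_frame) == num_tokens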
@@ -501,20 +541,40 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):

         return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)

+    @classmethod
     def get_video_repl(
-        self,
-        feature_size: int,
-        num_patches: Optional[int] = None,
+        cls,
+        tokens_per_frame: list[int],
         video_context_token: str = IMG_CONTEXT,
     ) -> PromptUpdateDetails[str]:
-        repl_features = video_context_token * self.num_image_token
-        repl_features_with_sep = IMG_START + repl_features + IMG_END
-        # num_patches is equal to num_frames
+        """
+        Build the prompt replacement for a video.
+        The replacement returned here is not actually used to replace the placeholder
+        tokens - it is only used to make sure we allocate the correct number
+        of tokens.
+        The actual replacement is done in get_multimodal_embeddings of
+        NemotronH_Nano_VL_V2
+        (specifically in _process_video_input -> _create_final_video_embeddings).
+        There, we create the final embeddings with text embeddings for indicator tokens
+        and video embeddings for video tokens.
+        This single function handles all cases - non-EVS, EVS dummy, and EVS real.
+        The differentiation is done via the tokens_per_frame parameter:
+        - non-EVS case - the same constant value across all frames
+        - EVS dummy - it doesn't matter how tokens are distributed between frames - just
+          make sure the total number of tokens is correct.
+        - EVS real (called from get_real_video_repl_for_evs) - a different value per frame
+        Args:
+            tokens_per_frame (list[int]): number of tokens per frame
+            video_context_token (str): the token to use for the video context
+        """
         repl_full = "".join(
-            [f"Frame{i + 1}: {repl_features_with_sep}" for i in range(num_patches)]
+            [
+                f"Frame{i + 1}: {IMG_START}{video_context_token * num_tokens}{IMG_END}"
+                for i, num_tokens in enumerate(tokens_per_frame)
+            ]
         )

-        return PromptUpdateDetails.select_text(repl_full, video_context_token)
+        return PromptUpdateDetails.from_seq(repl_full)


 class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
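To make the shape of the replacement string concrete, here is a small sketch of what get_video_repl produces for a short tokens_per_frame list. The IMG_START/IMG_END/IMG_CONTEXT stand-ins below are assumed values chosen for readability, not the model's real special tokens:

# Stand-ins for the real special tokens (assumed values, for readability only).
IMG_START, IMG_END, IMG_CONTEXT = "<img>", "</img>", "<image>"

def video_repl_full(tokens_per_frame: list[int]) -> str:
    # Same string construction as get_video_repl, minus PromptUpdateDetails.
    return "".join(
        f"Frame{i + 1}: {IMG_START}{IMG_CONTEXT * n}{IMG_END}"
        for i, n in enumerate(tokens_per_frame)
    )

# Non-EVS case: every frame keeps the same number of tokens.
print(video_repl_full([2, 2]))
# Frame1: <img><image><image></img>Frame2: <img><image><image></img>

# EVS dummy case: only the total matters, so all tokens sit on the first frame.
print(video_repl_full([3, 0]))
# Frame1: <img><image><image><image></img>Frame2: <img></img>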
@@ -605,6 +665,9 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
     def get_video_token(self) -> Optional[str]:
         return IMG_CONTEXT

+    def get_video_pruning_rate(self) -> Optional[float]:
+        return self.ctx.get_mm_config().video_pruning_rate
+
     def get_num_frames_with_most_features(
         self,
         seq_len: int,
@@ -628,6 +691,7 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
             config=self.get_hf_config(),
             tokenizer=self.get_tokenizer(),
             video_token=self.get_video_token(),
+            video_pruning_rate=self.get_video_pruning_rate(),
             **kwargs,
         )

@@ -805,8 +869,26 @@ class NanoNemotronVLMultiModalProcessor(
             if num_patches is not None:
                 assert isinstance(num_patches, int)

+            video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate
+            if video_pruning_rate is not None and video_pruning_rate > 0.0:
+                # Start of EVS-specific code
+                num_tokens = compute_retained_tokens_count(
+                    tokens_per_frame=feature_size,
+                    num_frames=num_patches,
+                    q=video_pruning_rate,
+                )
+                # Here we just need placeholders that won't actually be replaced -
+                # we just need to make sure the total number of tokens is correct
+                # assign all tokens to the first frame
+                tokens_per_frame = [num_tokens] + [0] * (num_patches - 1)
+
+                # End of EVS-specific code
+            else:
+                tokens_per_frame = [feature_size] * num_patches
+
             return hf_processor.get_video_repl(
-                feature_size, num_patches, video_context_token=hf_processor.video_token
+                tokens_per_frame,
+                video_context_token=hf_processor.video_token,
             )

         if self.info.supports_video:
@@ -901,7 +983,9 @@ class NanoNemotronVLDummyInputsBuilder(
     info=NanoNemotronVLProcessingInfo,
     dummy_inputs=NanoNemotronVLDummyInputsBuilder,
 )
-class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, SupportsMultiModal):
+class NemotronH_Nano_VL_V2(
+    nn.Module, HasInnerState, IsHybrid, SupportsMultiModal, SupportsMultiModalPruning
+):
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
         if modality.startswith("image"):
@@ -913,7 +997,7 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, SupportsMultiModa
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
-
+        multimodal_config = vllm_config.model_config.multimodal_config
         image_size = config.force_image_size
         patch_size = config.patch_size
         self.patch_size = patch_size
@@ -924,7 +1008,7 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, SupportsMultiModa
         self.downsample_ratio = config.downsample_ratio
         self.ps_version = config.ps_version
         self.image_tag_type = config.image_tag_type
-
+        self.video_pruning_rate = multimodal_config.video_pruning_rate
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
             hf_config=config.text_config,
@@ -957,6 +1041,7 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, SupportsMultiModa
         self.img_context_token_id = None
         self.video_context_token_id = None
         self.config = config
+        self.model_config = vllm_config.model_config

     def pixel_shuffle(self, x, scale_factor=0.5):
         n, w, h, c = x.size()
@@ -1049,7 +1134,7 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, SupportsMultiModa

     def _process_image_input(
         self, image_input: NanoNemotronVLImageInputs
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, ...]:
         if image_input["type"] == "image_embeds":
             return image_input["data"]

@@ -1071,6 +1156,109 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, SupportsMultiModa
         ]
         return image_embeds.split(image_feature_sizes)

+    def _process_video_input(
+        self, video_input: NanoNemotronVLVideoPixelInputs
+    ) -> tuple[torch.Tensor, ...]:
+        """Process video input and create final embeddings with video content
+        and indicator tokens."""
+        # Get video embeddings using the same processing as images
+        video_embeddings = self._process_image_input(video_input)
+
+        final_video_embeddings: tuple[torch.Tensor, ...] = ()
+
+        image_rows = image_cols = self.config.force_image_size
+        downsample_ratio = self.config.downsample_ratio
+        patch_size = self.config.patch_size
+        rows = int(image_rows * downsample_ratio // patch_size)
+        cols = int(image_cols * downsample_ratio // patch_size)
+        video_pruning_rate = self.video_pruning_rate
+
+        # Calculate video feature dimensions (number of frames and
+        # their feature size (AKA tokens per frame))
+        # TODO: Maybe this can be optimized to avoid the loop?
+        for i, single_video_embeddings in enumerate(video_embeddings):
+            num_frames = video_input["num_patches"][i].item()
+            assert single_video_embeddings.shape[0] % num_frames == 0
+
+            if video_pruning_rate is not None and video_pruning_rate > 0.0:
+                # Start of EVS-specific code
+                retention_mask = compute_retention_mask(
+                    single_video_embeddings,
+                    video_size_thw=(num_frames, rows, cols),
+                    spatial_merge_size=1,
+                    q=video_pruning_rate,
+                )
+
+                # apply retention mask
+                single_video_embeddings = single_video_embeddings[retention_mask]
+
+                # calculate the actual number of retained tokens per frame
+                retention_mask_thw = retention_mask.reshape(num_frames, rows, cols)
+                num_tokens_per_frame = (
+                    retention_mask_thw.sum(dim=(1, 2)).long().tolist()
+                )
+                # End of EVS-specific code
+            else:
+                feature_size = single_video_embeddings.shape[0] // num_frames
+                num_tokens_per_frame = [feature_size] * num_frames
+
+            final_video_embeddings += (
+                self._create_final_video_embeddings(
+                    single_video_embeddings,
+                    num_tokens_per_frame,
+                ),
+            )
+
+        return final_video_embeddings
+
+    def _create_final_video_embeddings(
+        self,
+        video_embeddings: torch.Tensor,
+        num_tokens_per_frame: list[int],
+    ) -> torch.Tensor:
+        """Create final embeddings that combine video embeddings with
+        text embeddings of indicator tokens.
+
+        These final embeddings contain:
+        - Actual video embeddings in positions corresponding to video content
+        - Text embeddings for indicator tokens (<img>, </img>, and
+          frame separation text) in their respective positions
+
+        These embeddings will replace the placeholder embeddings to create
+        input_embeds for the LLM.
+        """
+        device = video_embeddings.device
+
+        # Generate video replacement text and convert to token IDs
+        video_repl_text = NanoNemotronVLProcessor.get_video_repl(
+            num_tokens_per_frame,
+            IMG_CONTEXT,
+        ).full
+
+        tokenizer = cached_tokenizer_from_config(self.model_config)
+        repl_token_ids = torch.tensor(
+            _seq2tokens(tokenizer, video_repl_text), device=device
+        )
+
+        # Get embedding token IDs for image context
+        embed_token_ids = torch.tensor(
+            encode_tokens(tokenizer, IMG_CONTEXT), device=device
+        )
+
+        # Create mask for video embedding positions
+        is_video_embed = torch.isin(repl_token_ids, embed_token_ids)
+
+        # Create final video embeddings, merging text embeddings for indicator
+        # tokens with video embeddings
+        text_embeddings = self.get_language_model().get_input_embeddings(repl_token_ids)
+        final_video_embeddings = _merge_multimodal_embeddings(
+            inputs_embeds=text_embeddings,
+            multimodal_embeddings=video_embeddings,
+            is_multimodal=is_video_embed,
+        )
+
+        return final_video_embeddings
+
     def _parse_and_validate_video_input(
         self, **kwargs: object
     ) -> Optional[NanoNemotronVLVideoPixelInputs]:
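The per-video bookkeeping above reduces to two small tensor operations: reshaping the flat retention mask to (T, rows, cols) to count survivors per frame, and scattering video embeddings into the video-token positions of the embedded replacement sequence. A minimal sketch with toy shapes (the sizes and the random mask are illustrative only):

import torch

# Toy sizes: 4 frames, an 8x8 token grid per frame, hidden size 16 (assumptions).
num_frames, rows, cols, hidden = 4, 8, 8, 16
retention_mask = torch.rand(num_frames * rows * cols) > 0.7  # stand-in for compute_retention_mask

# Per-frame retained-token counts, as in _process_video_input.
retention_mask_thw = retention_mask.reshape(num_frames, rows, cols)
num_tokens_per_frame = retention_mask_thw.sum(dim=(1, 2)).long().tolist()
assert sum(num_tokens_per_frame) == int(retention_mask.sum())

# Merge step, conceptually: video embeddings land on the positions whose
# replacement token id is the <image> context token; every other position
# keeps its text embedding (what _merge_multimodal_embeddings does here).
repl_len = 12
is_video_embed = torch.zeros(repl_len, dtype=torch.bool)
is_video_embed[3:9] = True  # pretend 6 positions are <image> tokens
text_embeddings = torch.zeros(repl_len, hidden)
video_embeddings = torch.randn(6, hidden)
final_embeddings = text_embeddings.clone()
final_embeddings[is_video_embed] = video_embeddings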
@@ -1152,7 +1340,7 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, SupportsMultiModa
                 multimodal_embeddings += vision_embeddings
             if modality == "videos":
                 video_input = modalities["videos"]
-                video_embeddings = self._process_image_input(video_input)
+                video_embeddings = self._process_video_input(video_input)
                 multimodal_embeddings += video_embeddings

         return multimodal_embeddings
@@ -1017,9 +1017,13 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor):
                 and video_pruning_rate is not None
                 and video_pruning_rate > 0.0
             ):
+                T, H, W = map(int, grid_thw)
+                tokens_per_frame = (H // image_processor.merge_size) * (
+                    W // image_processor.merge_size
+                )
                 num_tokens = compute_retained_tokens_count(
-                    grid_thw,
-                    image_processor.merge_size,
+                    tokens_per_frame,
+                    T,
                     video_pruning_rate,
                 )
                 # End of EVS-specific code
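For the Qwen2.5-VL call site, the substantive change is that the caller now derives tokens_per_frame from the video grid before calling the helper. A quick worked example with assumed grid values:

# Assumed example: grid_thw = (frames, height, width) in patch units, merge_size = 2.
grid_thw = (16, 24, 32)
merge_size = 2

T, H, W = map(int, grid_thw)
tokens_per_frame = (H // merge_size) * (W // merge_size)
print(T, tokens_per_frame)  # 16 frames, 12 * 16 = 192 tokens per frame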
@@ -9,12 +9,13 @@
 # license agreement from NVIDIA CORPORATION is strictly prohibited.

 import typing
+from typing import Union

 import torch


 def compute_retained_tokens_count(
-    video_size_thw: torch.LongTensor, spatial_merge_size: int, q: float
+    tokens_per_frame: int, num_frames: int, q: float
 ) -> int:
     """
     Compute the number of retained tokens for a given video.
@@ -22,22 +23,22 @@ def compute_retained_tokens_count(
     regardless of the pruning rate.

     Args:
-        video_size_thw: The size of the video in the format of (T, H, W).
-        spatial_merge_size: The size of the spatial merge.
+        tokens_per_frame: The number of tokens per frame.
+        num_frames: The total number of frames.
         q: The pruning rate.

     Returns:
         The number of retained tokens.
     """
-    T, H, W = map(int, video_size_thw)
-    min_num_tokens = (H // spatial_merge_size) * (W // spatial_merge_size)
-    evs_num_tokens = int(T * min_num_tokens * (1 - q))
+    total_tokens = tokens_per_frame * num_frames
+    evs_num_tokens = int(total_tokens * (1 - q))
+    min_num_tokens = tokens_per_frame
     return max(min_num_tokens, evs_num_tokens)


 def compute_retention_mask(
     video_embeds: torch.Tensor,
-    video_size_thw: torch.LongTensor,
+    video_size_thw: Union[torch.LongTensor, tuple[int, int, int]],
     spatial_merge_size: int,
     q: float,
 ) -> torch.Tensor:
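The refactored count keeps the original guarantee that at least one frame's worth of tokens survives regardless of the pruning rate. A quick boundary check, with the formula restated as a standalone function:

def retained(tokens_per_frame: int, num_frames: int, q: float) -> int:
    total_tokens = tokens_per_frame * num_frames
    evs_num_tokens = int(total_tokens * (1 - q))
    return max(tokens_per_frame, evs_num_tokens)

print(retained(256, 8, 0.0))  # 2048: q=0 keeps every token
print(retained(256, 8, 0.9))  # 256: 2048 * 0.1 = 204 is clamped up to one frame
print(retained(256, 8, 1.0))  # 256: even q=1.0 keeps one frame's worth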
@@ -56,7 +57,7 @@ def compute_retention_mask(
         `torch.Tensor`: The retention mask for the video embeddings of
         `(T * H * W // spatial_merge_size ^ 2)` shape.
     """
-    T, H, W = video_size_thw
+    T, H, W = map(int, video_size_thw)

     # Use reshape instead of einops to avoid graph breaks
     video_embeds = video_embeds.reshape(
@@ -65,7 +66,7 @@ def compute_retention_mask(
         W // spatial_merge_size,
         video_embeds.size(-1),
     )
-
+    tokens_per_frame = (H // spatial_merge_size) * (W // spatial_merge_size)
     # Core EVS
     similarity = torch.nn.functional.cosine_similarity(
         video_embeds[1:, ...], video_embeds[:-1, ...], dim=-1
@@ -80,7 +81,7 @@ def compute_retention_mask(
     dissimilarity_flat = dissimilarity.view(-1)
     order = torch.argsort(dissimilarity_flat, dim=-1, descending=True, stable=True)
     retain_num_tokens = compute_retained_tokens_count(
-        video_size_thw, spatial_merge_size, q
+        tokens_per_frame=tokens_per_frame, num_frames=T, q=q
     )
     topk_indices = order[:retain_num_tokens]

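Putting the visible pieces of compute_retention_mask together, here is a self-contained sketch of the core EVS selection on toy data. It assumes spatial_merge_size=1 and pads the first frame's dissimilarity with 1.0 so that frame is never pruned; that padding is an assumption about the part of the function not shown in this diff:

import torch

def toy_retention_mask(video_embeds: torch.Tensor, thw: tuple[int, int, int], q: float) -> torch.Tensor:
    # video_embeds: (T*H*W, C) flat token embeddings; spatial_merge_size assumed to be 1.
    T, H, W = thw
    x = video_embeds.reshape(T, H, W, video_embeds.size(-1))

    # Dissimilarity of each token to the token at the same position in the previous frame.
    sim = torch.nn.functional.cosine_similarity(x[1:], x[:-1], dim=-1)  # (T-1, H, W)
    dissim = 1.0 - sim
    # Assumption: give the first frame maximal dissimilarity so it is always eligible.
    dissim = torch.cat([torch.ones(1, H, W), dissim], dim=0)

    # Keep the most-dissimilar tokens, but never fewer than one frame's worth.
    tokens_per_frame = H * W
    keep = max(tokens_per_frame, int(tokens_per_frame * T * (1 - q)))
    order = torch.argsort(dissim.view(-1), descending=True, stable=True)

    mask = torch.zeros(T * H * W, dtype=torch.bool)
    mask[order[:keep]] = True
    return mask

embeds = torch.randn(4 * 8 * 8, 32)  # toy video: 4 frames, 8x8 tokens, dim 32
mask = toy_retention_mask(embeds, (4, 8, 8), q=0.75)
print(int(mask.sum()))  # 64 = exactly one frame's worth retained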