mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-07 05:15:43 +08:00
[Model] EVS support for nano_nemotron_vl (#26269)
Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
Signed-off-by: tomeras91 <57313761+tomeras91@users.noreply.github.com>
Signed-off-by: Eugene Khvedchenia <ekhvedchenia@nvidia.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Eugene Khvedchenia <ekhvedchenia@nvidia.com>
This commit is contained in:
parent
fc679696f8
commit
b8f603cebe
@ -30,6 +30,7 @@ from vllm.model_executor.models.interfaces import (
|
|||||||
IsHybrid,
|
IsHybrid,
|
||||||
MultiModalEmbeddings,
|
MultiModalEmbeddings,
|
||||||
SupportsMultiModal,
|
SupportsMultiModal,
|
||||||
|
SupportsMultiModalPruning,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.models.internvl import (
|
from vllm.model_executor.models.internvl import (
|
||||||
calculate_internvl_targets,
|
calculate_internvl_targets,
|
||||||
@ -44,6 +45,10 @@ from vllm.model_executor.models.utils import (
|
|||||||
maybe_prefix,
|
maybe_prefix,
|
||||||
)
|
)
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
|
from vllm.multimodal.evs import (
|
||||||
|
compute_retained_tokens_count,
|
||||||
|
compute_retention_mask,
|
||||||
|
)
|
||||||
from vllm.multimodal.inputs import (
|
from vllm.multimodal.inputs import (
|
||||||
MultiModalDataDict,
|
MultiModalDataDict,
|
||||||
MultiModalFieldConfig,
|
MultiModalFieldConfig,
|
||||||
@ -62,13 +67,20 @@ from vllm.multimodal.processing import (
|
|||||||
PromptReplacement,
|
PromptReplacement,
|
||||||
PromptUpdate,
|
PromptUpdate,
|
||||||
PromptUpdateDetails,
|
PromptUpdateDetails,
|
||||||
|
_seq2tokens,
|
||||||
)
|
)
|
||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.transformers_utils.configs.radio import RadioConfig
|
from vllm.transformers_utils.configs.radio import RadioConfig
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import (
|
||||||
|
AnyTokenizer,
|
||||||
|
cached_tokenizer_from_config,
|
||||||
|
encode_tokens,
|
||||||
|
)
|
||||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
|
from .utils import _merge_multimodal_embeddings
|
||||||
|
|
||||||
# Configure PIL to handle large images without warnings
|
# Configure PIL to handle large images without warnings
|
||||||
# This prevents DecompressionBombWarning for legitimate large images
|
# This prevents DecompressionBombWarning for legitimate large images
|
||||||
Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely
|
Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely
|
||||||
@ -382,6 +394,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
|||||||
max_dynamic_patch: Optional[int] = None,
|
max_dynamic_patch: Optional[int] = None,
|
||||||
dynamic_image_size: Optional[bool] = None,
|
dynamic_image_size: Optional[bool] = None,
|
||||||
video_token: Optional[str] = None,
|
video_token: Optional[str] = None,
|
||||||
|
video_pruning_rate: Optional[float] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
config=config,
|
config=config,
|
||||||
@ -392,6 +405,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
|||||||
)
|
)
|
||||||
# add extra video token for video processing
|
# add extra video token for video processing
|
||||||
self.video_token = video_token
|
self.video_token = video_token
|
||||||
|
self.video_pruning_rate = video_pruning_rate
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def supports_video(self) -> bool:
|
def supports_video(self) -> bool:
|
||||||
@ -446,12 +460,38 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
for pixel_values in pixel_values_lst_video:
|
image_size: int = self.config.force_image_size
|
||||||
num_patches = pixel_values.shape[0]
|
patch_size: int = self.config.patch_size
|
||||||
|
downsample_ratio = self.config.downsample_ratio
|
||||||
|
tokens_in_single_frame = int(
|
||||||
|
(image_size * image_size // patch_size**2) * (downsample_ratio**2)
|
||||||
|
)
|
||||||
|
|
||||||
|
for pixel_values in pixel_values_lst_video:
|
||||||
|
num_frames = pixel_values.shape[0]
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.video_pruning_rate is not None
|
||||||
|
and self.video_pruning_rate > 0.0
|
||||||
|
):
|
||||||
|
# Start of EVS-specific code
|
||||||
|
num_tokens = compute_retained_tokens_count(
|
||||||
|
tokens_per_frame=tokens_in_single_frame,
|
||||||
|
num_frames=num_frames,
|
||||||
|
q=self.video_pruning_rate,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Here we just need placeholders that won't actually be replaced -
|
||||||
|
# we just need to make sure the total number of tokens is correct
|
||||||
|
# so we simply assign all tokens to the first frame
|
||||||
|
tokens_per_frame = [num_tokens] + [0] * (num_frames - 1)
|
||||||
|
|
||||||
|
# End of EVS-specific code
|
||||||
|
else:
|
||||||
|
tokens_per_frame = [tokens_in_single_frame] * num_frames
|
||||||
|
|
||||||
|
video_repl = self.get_video_repl(tokens_per_frame, self.video_token)
|
||||||
|
|
||||||
video_repl = self.get_video_repl(
|
|
||||||
self.num_image_token, num_patches, self.video_token
|
|
||||||
)
|
|
||||||
text = [t.replace("<video>", video_repl.full, 1) for t in text]
|
text = [t.replace("<video>", video_repl.full, 1) for t in text]
|
||||||
return text, video_inputs
|
return text, video_inputs
|
||||||
|
|
||||||
@ -501,20 +541,40 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
|||||||
|
|
||||||
return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
|
return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
def get_video_repl(
|
def get_video_repl(
|
||||||
self,
|
cls,
|
||||||
feature_size: int,
|
tokens_per_frame: list[int],
|
||||||
num_patches: Optional[int] = None,
|
|
||||||
video_context_token: str = IMG_CONTEXT,
|
video_context_token: str = IMG_CONTEXT,
|
||||||
) -> PromptUpdateDetails[str]:
|
) -> PromptUpdateDetails[str]:
|
||||||
repl_features = video_context_token * self.num_image_token
|
"""
|
||||||
repl_features_with_sep = IMG_START + repl_features + IMG_END
|
Build prompt replacement for a video.
|
||||||
# num_patches is equal to num_frames
|
The replacement returned is not actually used to replace the placeholder
|
||||||
|
tokens - it's just used to make sure we allocate the correct number
|
||||||
|
of tokens.
|
||||||
|
Actual replacement is done in get_multimodal_embeddings of
|
||||||
|
NemotronH_Nano_VL_V2
|
||||||
|
(specifically in _process_video_input -> _create_final_video_embeddings).
|
||||||
|
There, we create the final embeddings with text embeddings for indicator tokens
|
||||||
|
and video embeddings for video tokens.
|
||||||
|
This is a single function that handles all cases - non EVS, EVS dummy, EVS real.
|
||||||
|
The differentiation is done via tokens_per_frame parameter.
|
||||||
|
- non EVS case - the same constant value across all frames
|
||||||
|
- EVS dummy - Doesn't matter how tokens are distributed between frames - just
|
||||||
|
make sure the total number of tokens is correct.
|
||||||
|
- EVS real (called from get_real_video_repl_for_evs) - different value per frame
|
||||||
|
Args:
|
||||||
|
tokens_per_frame (list[int]): number of tokens per frame
|
||||||
|
video_context_token (str): the token to use for the video context
|
||||||
|
"""
|
||||||
repl_full = "".join(
|
repl_full = "".join(
|
||||||
[f"Frame{i + 1}: {repl_features_with_sep}" for i in range(num_patches)]
|
[
|
||||||
|
f"Frame{i + 1}: {IMG_START}{video_context_token * num_tokens}{IMG_END}"
|
||||||
|
for i, num_tokens in enumerate(tokens_per_frame)
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
return PromptUpdateDetails.select_text(repl_full, video_context_token)
|
return PromptUpdateDetails.from_seq(repl_full)
|
||||||
|
|
||||||
|
|
||||||
class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
|
class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
|
||||||
@ -605,6 +665,9 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
|
|||||||
def get_video_token(self) -> Optional[str]:
|
def get_video_token(self) -> Optional[str]:
|
||||||
return IMG_CONTEXT
|
return IMG_CONTEXT
|
||||||
|
|
||||||
|
def get_video_pruning_rate(self) -> Optional[float]:
|
||||||
|
return self.ctx.get_mm_config().video_pruning_rate
|
||||||
|
|
||||||
def get_num_frames_with_most_features(
|
def get_num_frames_with_most_features(
|
||||||
self,
|
self,
|
||||||
seq_len: int,
|
seq_len: int,
|
||||||
@ -628,6 +691,7 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
|
|||||||
config=self.get_hf_config(),
|
config=self.get_hf_config(),
|
||||||
tokenizer=self.get_tokenizer(),
|
tokenizer=self.get_tokenizer(),
|
||||||
video_token=self.get_video_token(),
|
video_token=self.get_video_token(),
|
||||||
|
video_pruning_rate=self.get_video_pruning_rate(),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -805,8 +869,26 @@ class NanoNemotronVLMultiModalProcessor(
|
|||||||
if num_patches is not None:
|
if num_patches is not None:
|
||||||
assert isinstance(num_patches, int)
|
assert isinstance(num_patches, int)
|
||||||
|
|
||||||
|
video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate
|
||||||
|
if video_pruning_rate is not None and video_pruning_rate > 0.0:
|
||||||
|
# Start of EVS-specific code
|
||||||
|
num_tokens = compute_retained_tokens_count(
|
||||||
|
tokens_per_frame=feature_size,
|
||||||
|
num_frames=num_patches,
|
||||||
|
q=video_pruning_rate,
|
||||||
|
)
|
||||||
|
# Here we just need placeholders that won't actually be replaced -
|
||||||
|
# we just need to make sure the total number of tokens is correct
|
||||||
|
# so we simply assign all tokens to the first frame
|
||||||
|
tokens_per_frame = [num_tokens] + [0] * (num_patches - 1)
|
||||||
|
|
||||||
|
# End of EVS-specific code
|
||||||
|
else:
|
||||||
|
tokens_per_frame = [feature_size] * num_patches
|
||||||
|
|
||||||
return hf_processor.get_video_repl(
|
return hf_processor.get_video_repl(
|
||||||
feature_size, num_patches, video_context_token=hf_processor.video_token
|
tokens_per_frame,
|
||||||
|
video_context_token=hf_processor.video_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.info.supports_video:
|
if self.info.supports_video:
|
||||||
@ -901,7 +983,9 @@ class NanoNemotronVLDummyInputsBuilder(
|
|||||||
info=NanoNemotronVLProcessingInfo,
|
info=NanoNemotronVLProcessingInfo,
|
||||||
dummy_inputs=NanoNemotronVLDummyInputsBuilder,
|
dummy_inputs=NanoNemotronVLDummyInputsBuilder,
|
||||||
)
|
)
|
||||||
class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, SupportsMultiModal):
|
class NemotronH_Nano_VL_V2(
|
||||||
|
nn.Module, HasInnerState, IsHybrid, SupportsMultiModal, SupportsMultiModalPruning
|
||||||
|
):
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||||
if modality.startswith("image"):
|
if modality.startswith("image"):
|
||||||
@ -913,7 +997,7 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, SupportsMultiModa
|
|||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
|
multimodal_config = vllm_config.model_config.multimodal_config
|
||||||
image_size = config.force_image_size
|
image_size = config.force_image_size
|
||||||
patch_size = config.patch_size
|
patch_size = config.patch_size
|
||||||
self.patch_size = patch_size
|
self.patch_size = patch_size
|
||||||
@ -924,7 +1008,7 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, SupportsMultiModa
|
|||||||
self.downsample_ratio = config.downsample_ratio
|
self.downsample_ratio = config.downsample_ratio
|
||||||
self.ps_version = config.ps_version
|
self.ps_version = config.ps_version
|
||||||
self.image_tag_type = config.image_tag_type
|
self.image_tag_type = config.image_tag_type
|
||||||
|
self.video_pruning_rate = multimodal_config.video_pruning_rate
|
||||||
self.language_model = init_vllm_registered_model(
|
self.language_model = init_vllm_registered_model(
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
hf_config=config.text_config,
|
hf_config=config.text_config,
|
||||||
@ -957,6 +1041,7 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, SupportsMultiModa
|
|||||||
self.img_context_token_id = None
|
self.img_context_token_id = None
|
||||||
self.video_context_token_id = None
|
self.video_context_token_id = None
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.model_config = vllm_config.model_config
|
||||||
|
|
||||||
def pixel_shuffle(self, x, scale_factor=0.5):
|
def pixel_shuffle(self, x, scale_factor=0.5):
|
||||||
n, w, h, c = x.size()
|
n, w, h, c = x.size()
|
||||||
@ -1049,7 +1134,7 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, SupportsMultiModa
|
|||||||
|
|
||||||
def _process_image_input(
|
def _process_image_input(
|
||||||
self, image_input: NanoNemotronVLImageInputs
|
self, image_input: NanoNemotronVLImageInputs
|
||||||
) -> torch.Tensor:
|
) -> tuple[torch.Tensor, ...]:
|
||||||
if image_input["type"] == "image_embeds":
|
if image_input["type"] == "image_embeds":
|
||||||
return image_input["data"]
|
return image_input["data"]
|
||||||
|
|
||||||
@ -1071,6 +1156,109 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, SupportsMultiModa
|
|||||||
]
|
]
|
||||||
return image_embeds.split(image_feature_sizes)
|
return image_embeds.split(image_feature_sizes)
|
||||||
|
|
||||||
|
def _process_video_input(
|
||||||
|
self, video_input: NanoNemotronVLVideoPixelInputs
|
||||||
|
) -> tuple[torch.Tensor, ...]:
|
||||||
|
"""Process video input and create final embeddings with video content
|
||||||
|
and indicator tokens."""
|
||||||
|
# Get video embeddings using the same processing as images
|
||||||
|
video_embeddings = self._process_image_input(video_input)
|
||||||
|
|
||||||
|
final_video_embeddings: tuple[torch.Tensor, ...] = ()
|
||||||
|
|
||||||
|
image_rows = image_cols = self.config.force_image_size
|
||||||
|
downsample_ratio = self.config.downsample_ratio
|
||||||
|
patch_size = self.config.patch_size
|
||||||
|
rows = int(image_rows * downsample_ratio // patch_size)
|
||||||
|
cols = int(image_cols * downsample_ratio // patch_size)
|
||||||
|
video_pruning_rate = self.video_pruning_rate
|
||||||
|
|
||||||
|
# Calculate video feature dimensions (number of frames and
|
||||||
|
# their feature size (AKA tokens per frame))
|
||||||
|
# TODO: Maybe this can be optimized to avoid the loop?
|
||||||
|
for i, single_video_embeddings in enumerate(video_embeddings):
|
||||||
|
num_frames = video_input["num_patches"][i].item()
|
||||||
|
assert single_video_embeddings.shape[0] % num_frames == 0
|
||||||
|
|
||||||
|
if video_pruning_rate is not None and video_pruning_rate > 0.0:
|
||||||
|
# Start of EVS-specific code
|
||||||
|
retention_mask = compute_retention_mask(
|
||||||
|
single_video_embeddings,
|
||||||
|
video_size_thw=(num_frames, rows, cols),
|
||||||
|
spatial_merge_size=1,
|
||||||
|
q=video_pruning_rate,
|
||||||
|
)
|
||||||
|
|
||||||
|
# apply retention mask
|
||||||
|
single_video_embeddings = single_video_embeddings[retention_mask]
|
||||||
|
|
||||||
|
# calculate the actual number of retained tokens per frame
|
||||||
|
retention_mask_thw = retention_mask.reshape(num_frames, rows, cols)
|
||||||
|
num_tokens_per_frame = (
|
||||||
|
retention_mask_thw.sum(dim=(1, 2)).long().tolist()
|
||||||
|
)
|
||||||
|
# End of EVS-specific code
|
||||||
|
else:
|
||||||
|
feature_size = single_video_embeddings.shape[0] // num_frames
|
||||||
|
num_tokens_per_frame = [feature_size] * num_frames
|
||||||
|
|
||||||
|
final_video_embeddings += (
|
||||||
|
self._create_final_video_embeddings(
|
||||||
|
single_video_embeddings,
|
||||||
|
num_tokens_per_frame,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return final_video_embeddings
|
||||||
|
|
||||||
|
def _create_final_video_embeddings(
|
||||||
|
self,
|
||||||
|
video_embeddings: torch.Tensor,
|
||||||
|
num_tokens_per_frame: list[int],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Create final embeddings that combine video embeddings with
|
||||||
|
text embeddings of indicator tokens.
|
||||||
|
|
||||||
|
These final embeddings contain:
|
||||||
|
- Actual video embeddings in positions corresponding to video content
|
||||||
|
- Text embeddings for indicator tokens (<img>, </img>, and
|
||||||
|
frame separation text) in their respective positions
|
||||||
|
|
||||||
|
These embeddings will replace the placeholder embeddings to create
|
||||||
|
input_embeds for the LLM.
|
||||||
|
"""
|
||||||
|
device = video_embeddings.device
|
||||||
|
|
||||||
|
# Generate video replacement text and convert to token IDs
|
||||||
|
video_repl_text = NanoNemotronVLProcessor.get_video_repl(
|
||||||
|
num_tokens_per_frame,
|
||||||
|
IMG_CONTEXT,
|
||||||
|
).full
|
||||||
|
|
||||||
|
tokenizer = cached_tokenizer_from_config(self.model_config)
|
||||||
|
repl_token_ids = torch.tensor(
|
||||||
|
_seq2tokens(tokenizer, video_repl_text), device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get embedding token IDs for image context
|
||||||
|
embed_token_ids = torch.tensor(
|
||||||
|
encode_tokens(tokenizer, IMG_CONTEXT), device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create mask for video embedding positions
|
||||||
|
is_video_embed = torch.isin(repl_token_ids, embed_token_ids)
|
||||||
|
|
||||||
|
# Create final video embeddings, merging text embeddings for indicator
|
||||||
|
# tokens with video embeddings
|
||||||
|
text_embeddings = self.get_language_model().get_input_embeddings(repl_token_ids)
|
||||||
|
final_video_embeddings = _merge_multimodal_embeddings(
|
||||||
|
inputs_embeds=text_embeddings,
|
||||||
|
multimodal_embeddings=video_embeddings,
|
||||||
|
is_multimodal=is_video_embed,
|
||||||
|
)
|
||||||
|
|
||||||
|
return final_video_embeddings
|
||||||
|
|
||||||
def _parse_and_validate_video_input(
|
def _parse_and_validate_video_input(
|
||||||
self, **kwargs: object
|
self, **kwargs: object
|
||||||
) -> Optional[NanoNemotronVLVideoPixelInputs]:
|
) -> Optional[NanoNemotronVLVideoPixelInputs]:
|
||||||
@ -1152,7 +1340,7 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, SupportsMultiModa
|
|||||||
multimodal_embeddings += vision_embeddings
|
multimodal_embeddings += vision_embeddings
|
||||||
if modality == "videos":
|
if modality == "videos":
|
||||||
video_input = modalities["videos"]
|
video_input = modalities["videos"]
|
||||||
video_embeddings = self._process_image_input(video_input)
|
video_embeddings = self._process_video_input(video_input)
|
||||||
multimodal_embeddings += video_embeddings
|
multimodal_embeddings += video_embeddings
|
||||||
|
|
||||||
return multimodal_embeddings
|
return multimodal_embeddings
|
||||||
|
|||||||
@ -1017,9 +1017,13 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor):
|
|||||||
and video_pruning_rate is not None
|
and video_pruning_rate is not None
|
||||||
and video_pruning_rate > 0.0
|
and video_pruning_rate > 0.0
|
||||||
):
|
):
|
||||||
|
T, H, W = map(int, grid_thw)
|
||||||
|
tokens_per_frame = (H // image_processor.merge_size) * (
|
||||||
|
W // image_processor.merge_size
|
||||||
|
)
|
||||||
num_tokens = compute_retained_tokens_count(
|
num_tokens = compute_retained_tokens_count(
|
||||||
grid_thw,
|
tokens_per_frame,
|
||||||
image_processor.merge_size,
|
T,
|
||||||
video_pruning_rate,
|
video_pruning_rate,
|
||||||
)
|
)
|
||||||
# End of EVS-specific code
|
# End of EVS-specific code
|
||||||
|
|||||||
@ -9,12 +9,13 @@
|
|||||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||||
|
|
||||||
import typing
|
import typing
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def compute_retained_tokens_count(
|
def compute_retained_tokens_count(
|
||||||
video_size_thw: torch.LongTensor, spatial_merge_size: int, q: float
|
tokens_per_frame: int, num_frames: int, q: float
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Compute the number of retained tokens for a given video.
|
Compute the number of retained tokens for a given video.
|
||||||
@ -22,22 +23,22 @@ def compute_retained_tokens_count(
|
|||||||
regardless of the pruning rate.
|
regardless of the pruning rate.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
video_size_thw: The size of the video in the format of (T, H, W).
|
tokens_per_frame: The number of tokens per frame.
|
||||||
spatial_merge_size: The size of the spatial merge.
|
num_frames: The total number of frames.
|
||||||
q: The pruning rate.
|
q: The pruning rate.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The number of retained tokens.
|
The number of retained tokens.
|
||||||
"""
|
"""
|
||||||
T, H, W = map(int, video_size_thw)
|
total_tokens = tokens_per_frame * num_frames
|
||||||
min_num_tokens = (H // spatial_merge_size) * (W // spatial_merge_size)
|
evs_num_tokens = int(total_tokens * (1 - q))
|
||||||
evs_num_tokens = int(T * min_num_tokens * (1 - q))
|
min_num_tokens = tokens_per_frame
|
||||||
return max(min_num_tokens, evs_num_tokens)
|
return max(min_num_tokens, evs_num_tokens)
|
||||||
|
|
||||||
|
|
||||||
def compute_retention_mask(
|
def compute_retention_mask(
|
||||||
video_embeds: torch.Tensor,
|
video_embeds: torch.Tensor,
|
||||||
video_size_thw: torch.LongTensor,
|
video_size_thw: Union[torch.LongTensor, tuple[int, int, int]],
|
||||||
spatial_merge_size: int,
|
spatial_merge_size: int,
|
||||||
q: float,
|
q: float,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@ -56,7 +57,7 @@ def compute_retention_mask(
|
|||||||
`torch.Tensor`: The retention mask for the video embeddings of
|
`torch.Tensor`: The retention mask for the video embeddings of
|
||||||
`(T * H * W // spatial_merge_size ** 2)` shape.
|
`(T * H * W // spatial_merge_size ** 2)` shape.
|
||||||
"""
|
"""
|
||||||
T, H, W = video_size_thw
|
T, H, W = map(int, video_size_thw)
|
||||||
|
|
||||||
# Use reshape instead of einops to avoid graph breaks
|
# Use reshape instead of einops to avoid graph breaks
|
||||||
video_embeds = video_embeds.reshape(
|
video_embeds = video_embeds.reshape(
|
||||||
@ -65,7 +66,7 @@ def compute_retention_mask(
|
|||||||
W // spatial_merge_size,
|
W // spatial_merge_size,
|
||||||
video_embeds.size(-1),
|
video_embeds.size(-1),
|
||||||
)
|
)
|
||||||
|
tokens_per_frame = (H // spatial_merge_size) * (W // spatial_merge_size)
|
||||||
# Core EVS
|
# Core EVS
|
||||||
similarity = torch.nn.functional.cosine_similarity(
|
similarity = torch.nn.functional.cosine_similarity(
|
||||||
video_embeds[1:, ...], video_embeds[:-1, ...], dim=-1
|
video_embeds[1:, ...], video_embeds[:-1, ...], dim=-1
|
||||||
@ -80,7 +81,7 @@ def compute_retention_mask(
|
|||||||
dissimilarity_flat = dissimilarity.view(-1)
|
dissimilarity_flat = dissimilarity.view(-1)
|
||||||
order = torch.argsort(dissimilarity_flat, dim=-1, descending=True, stable=True)
|
order = torch.argsort(dissimilarity_flat, dim=-1, descending=True, stable=True)
|
||||||
retain_num_tokens = compute_retained_tokens_count(
|
retain_num_tokens = compute_retained_tokens_count(
|
||||||
video_size_thw, spatial_merge_size, q
|
tokens_per_frame=tokens_per_frame, num_frames=T, q=q
|
||||||
)
|
)
|
||||||
topk_indices = order[:retain_num_tokens]
|
topk_indices = order[:retain_num_tokens]
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user