EVS Support (Video tokens pruning) (#22980)

Signed-off-by: Eugene Khvedchenia <ekhvedchenia@nvidia.com>
Signed-off-by: Eugene Khvedchenya <ekhvedchenya@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Eugene Khvedchenya 2025-09-26 06:54:54 +03:00 committed by GitHub
parent 983056e456
commit 392edee34a
8 changed files with 783 additions and 39 deletions

View File

@ -0,0 +1,132 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.multimodal.video import sample_frames_from_video
from ....conftest import VIDEO_ASSETS
models = ["Qwen/Qwen2.5-VL-3B-Instruct"]
target_dtype = "bfloat16"
VIDEO_PLACEHOLDER = "<|vision_start|><|video_pad|><|vision_end|>"
def qwen2_5_vl_chat_template(*query):
return f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{''.join(query)}<|im_end|><|im_start|>assistant\n" # noqa: E501
VIDEO_PROMPTS = VIDEO_ASSETS.prompts({
"baby_reading":
qwen2_5_vl_chat_template(
VIDEO_PLACEHOLDER,
"Describe this video with a short sentence ",
"(no more than 20 words)",
),
})
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("video_pruning_rate", [0.0, 0.75])
@pytest.mark.parametrize("num_frames", [16])
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
def test_qwen2_5_vl_evs_functionality(vllm_runner, video_assets, model,
video_pruning_rate: float,
num_frames: int, dtype: str,
max_tokens: int) -> None:
"""Test EVS (Efficient Video Sampling) functionality with different
pruning rates.
"""
# Sample frames from video assets
sampled_vids = [
sample_frames_from_video(asset.np_ndarrays, num_frames)
for asset in video_assets
]
prompts = [VIDEO_PROMPTS[0]]
videos = [sampled_vids[0]]
# Initialize model with EVS configuration
with vllm_runner(model,
runner="generate",
max_model_len=4000,
max_num_seqs=1,
dtype=dtype,
limit_mm_per_prompt={"video": 1},
tensor_parallel_size=1,
video_pruning_rate=video_pruning_rate) as vllm_model:
# Generate output - this should not crash
outputs = vllm_model.generate_greedy(prompts,
max_tokens,
videos=videos)
# Basic validation that we got a response
assert len(outputs) == 1
output_ids, output_text = outputs[0]
# Ensure we got some output
assert len(output_ids) > 0
assert len(output_text) > 0
# Ensure the output is a string
assert isinstance(output_text, str)
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("video_pruning_rate", [0.0, 0.75])
@pytest.mark.parametrize("num_frames", [16])
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
def test_qwen2_5_vl_evs_batched_videos(vllm_runner, video_assets, model,
video_pruning_rate: float,
num_frames: int, dtype: str,
max_tokens: int) -> None:
"""Test EVS functionality with batched videos.
This test validates that:
1. The model handles batched video inputs correctly with EVS
2. Both pruning configurations work with multiple videos
3. The model doesn't crash when processing multiple videos simultaneously
"""
# Sample frames from video assets
sampled_vids = [
sample_frames_from_video(asset.np_ndarrays, num_frames)
for asset in video_assets
]
# Test batched videos
prompts = [VIDEO_PROMPTS[0], VIDEO_PROMPTS[0]]
videos = [sampled_vids[0],
sampled_vids[0]] # Use same video twice for testing
# Initialize model with EVS configuration
with vllm_runner(model,
runner="generate",
max_model_len=4000,
max_num_seqs=2,
dtype=dtype,
limit_mm_per_prompt={"video": 2},
tensor_parallel_size=1,
video_pruning_rate=video_pruning_rate) as vllm_model:
# Generate output - this should not crash
outputs = vllm_model.generate_greedy(prompts,
max_tokens,
videos=videos)
# Basic validation that we got responses for both videos
assert len(outputs) == 2
for output_ids, output_text in outputs:
# Ensure we got some output for each video
assert len(output_ids) > 0
assert len(output_text) > 0
# Ensure the output is a string
assert isinstance(output_text, str)

View File

@ -283,6 +283,7 @@ class ModelConfig:
mm_encoder_tp_mode: InitVar[Optional[MMEncoderTPMode]] = None
interleave_mm_strings: InitVar[Optional[bool]] = None
skip_mm_profiling: InitVar[Optional[bool]] = None
video_pruning_rate: InitVar[Optional[float]] = None
def compute_hash(self) -> str:
"""
@ -311,6 +312,7 @@ class ModelConfig:
factors.append(self.override_generation_config)
factors.append(self.rope_scaling)
factors.append(self.rope_theta)
factors.append(self.video_pruning_rate)
# hf_config can control how the model looks!
try:
@ -338,17 +340,19 @@ class ModelConfig:
return hashlib.sha256(str(factors).encode()).hexdigest()
def __post_init__(
self,
# Multimodal config init vars
limit_mm_per_prompt: Optional[dict[str, int]],
media_io_kwargs: Optional[dict[str, dict[str, Any]]],
mm_processor_kwargs: Optional[dict[str, Any]],
mm_processor_cache_gb: Optional[float],
mm_processor_cache_type: Optional[MMCacheType],
mm_shm_cache_max_object_size_mb: Optional[int],
mm_encoder_tp_mode: Optional[MMEncoderTPMode],
interleave_mm_strings: Optional[bool],
skip_mm_profiling: Optional[bool]) -> None:
self,
# Multimodal config init vars
limit_mm_per_prompt: Optional[dict[str, int]],
media_io_kwargs: Optional[dict[str, dict[str, Any]]],
mm_processor_kwargs: Optional[dict[str, Any]],
mm_processor_cache_gb: Optional[float],
mm_processor_cache_type: Optional[MMCacheType],
mm_shm_cache_max_object_size_mb: Optional[int],
mm_encoder_tp_mode: Optional[MMEncoderTPMode],
interleave_mm_strings: Optional[bool],
skip_mm_profiling: Optional[bool],
video_pruning_rate: Optional[float],
) -> None:
# Set the default seed to 0 in V1.
# NOTE(woosuk): In V0, we set the default seed to None because the
# driver worker shares the same process as the user process, and thus
@ -612,6 +616,7 @@ class ModelConfig:
mm_encoder_tp_mode=mm_encoder_tp_mode,
interleave_mm_strings=interleave_mm_strings,
skip_mm_profiling=skip_mm_profiling,
video_pruning_rate=video_pruning_rate,
)
mm_config_kwargs = {

View File

@ -78,6 +78,11 @@ class MultiModalConfig:
This reduces engine startup time but shifts the responsibility to users for
estimating the peak memory usage of the activation of multimodal encoder and
embedding cache."""
video_pruning_rate: Optional[float] = None
"""Sets pruning rate for video pruning via Efficient Video Sampling.
Value sits in range [0;1) and determines fraction of media tokens
from each video to be pruned.
"""
def compute_hash(self) -> str:
"""
@ -118,3 +123,7 @@ class MultiModalConfig:
"""
kwargs = self.mm_processor_kwargs or {}
return kwargs | dict(inference_kwargs)
def is_multimodal_pruning_enabled(self):
return (self.video_pruning_rate is not None
and self.video_pruning_rate > 0)
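
A minimal sketch of the new knob in isolation (assumes MultiModalConfig is importable from vllm.config, as elsewhere in vLLM, and that its other fields keep their defaults):

from vllm.config import MultiModalConfig

# Pruning is considered enabled only when a rate greater than 0 is set.
assert MultiModalConfig(video_pruning_rate=0.75).is_multimodal_pruning_enabled()
assert not MultiModalConfig(video_pruning_rate=0.0).is_multimodal_pruning_enabled()
assert not MultiModalConfig().is_multimodal_pruning_enabled()  # default is None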

View File

@ -391,6 +391,7 @@ class EngineArgs:
mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
io_processor_plugin: Optional[str] = None
skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
video_pruning_rate: float = MultiModalConfig.video_pruning_rate
# LoRA fields
enable_lora: bool = False
enable_lora_bias: bool = LoRAConfig.bias_enabled
@ -813,6 +814,9 @@ class EngineArgs:
multimodal_group.add_argument("--skip-mm-profiling",
**multimodal_kwargs["skip_mm_profiling"])
multimodal_group.add_argument(
"--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"])
# LoRA related configs
lora_kwargs = get_kwargs(LoRAConfig)
lora_group = parser.add_argument_group(
@ -1032,6 +1036,7 @@ class EngineArgs:
model_impl=self.model_impl,
override_attention_dtype=self.override_attention_dtype,
logits_processors=self.logits_processors,
video_pruning_rate=self.video_pruning_rate,
io_processor_plugin=self.io_processor_plugin,
)
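
With the argument wired through EngineArgs, EVS is enabled with a single flag; a hedged sketch of serving and offline use (model name and pruning rate are illustrative):

# CLI: prune 75% of the video tokens of each video via Efficient Video Sampling.
#   vllm serve Qwen/Qwen2.5-VL-3B-Instruct --video-pruning-rate 0.75

# Offline sketch: the LLM entrypoint forwards keyword arguments to EngineArgs,
# which passes video_pruning_rate down to ModelConfig / MultiModalConfig.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    video_pruning_rate=0.75,           # fraction of video tokens to prune, in [0, 1)
    limit_mm_per_prompt={"video": 1},
)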

View File

@ -115,6 +115,42 @@ class SupportsMultiModal(Protocol):
...
@runtime_checkable
class SupportsMultiModalPruning(Protocol):
"""The interface required for models that support returning both input
embeddings and positions. A model may require custom positions for dynamic
pruning of multimodal embeddings.
"""
supports_multimodal_pruning: ClassVar[Literal[True]] = True
def recompute_mrope_positions(
self, input_ids: list[int],
multimodal_embeddings: MultiModalEmbeddings,
mrope_positions: torch.LongTensor, num_computed_tokens: int
) -> tuple[MultiModalEmbeddings, Tensor, int]:
"""
Update part of the input mrope positions (starting at the
num_computed_tokens index). The original mrope_positions are computed
for the unpruned sequence and become incorrect once pruning occurs,
so after pruning media tokens we must reflect this in the
mrope_positions before feeding them to the LLM.
Args:
input_ids: (N,) All input tokens of the prompt, containing the
entire sequence.
multimodal_embeddings: Tuple of multimodal embeddings that
fit into the prefill chunk that is being processed.
mrope_positions: Existing mrope positions (3, N) for the entire
sequence.
num_computed_tokens: The number of tokens computed so far.
Returns:
Tuple of (multimodal_embeddings, mrope_positions,
mrope_position_delta).
"""
...
@overload
def supports_multimodal(
model: type[object]) -> TypeIs[type[SupportsMultiModal]]:
@ -142,6 +178,25 @@ def supports_multimodal_encoder_tp_data(
return getattr(model, "supports_encoder_tp_data", False)
@overload
def supports_multimodal_pruning(
model: type[object]) -> TypeIs[type[SupportsMultiModalPruning]]:
...
@overload
def supports_multimodal_pruning(
model: object) -> TypeIs[SupportsMultiModalPruning]:
...
def supports_multimodal_pruning(
model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsMultiModalPruning]],
TypeIs[SupportsMultiModalPruning]]:
return getattr(model, "supports_multimodal_pruning", False)
@runtime_checkable
class SupportsScoreTemplate(Protocol):
"""The interface required for all models that support score template."""

View File

@ -25,9 +25,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping
from collections.abc import Iterable, Mapping, Sequence
from functools import lru_cache, partial
from typing import Annotated, Callable, Literal, Optional, Union
from typing import Annotated, Any, Callable, Literal, Optional, Union
import torch
import torch.nn as nn
@ -58,7 +58,13 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm.multimodal.evs import (compute_mrope_for_media,
compute_retained_tokens_count,
compute_retention_mask,
recompute_mrope_positions)
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import PromptReplacement, PromptUpdate
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
@ -66,7 +72,8 @@ from vllm.utils import is_pin_memory_available
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP, SupportsQuant)
SupportsMultiModal, SupportsMultiModalPruning,
SupportsPP, SupportsQuant)
from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder
from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo,
apply_rotary_pos_emb_vision)
@ -86,9 +93,9 @@ class Qwen2_5_VLImagePixelInputs(TensorSchema):
- np: Number of patches
- ni: Number of images
- cps: Number of channels * patch_size * patch_size
Historical context:
- pixel_values shape: (num_patches, num_channels * patch_size *
- pixel_values shape: (num_patches, num_channels * patch_size *
patch_size)
- image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
format
@ -112,7 +119,7 @@ class Qwen2_5_VLImageEmbeddingInputs(TensorSchema):
- nf: Number of image features
- hs: Hidden size
- ni: Number of images
Historical context:
- image_embeds shape: (num_image_features, hidden_size)
- num_image_features varies based on the number and resolution of the
@ -143,11 +150,11 @@ class Qwen2_5_VLVideoPixelInputs(TensorSchema):
Dimensions:
- np: Number of patches
- nv: Number of videos
- ctps: Number of channels * temporal_patch_size * patch_size *
- ctps: Number of channels * temporal_patch_size * patch_size *
patch_size
Historical context:
- pixel_values_videos shape: (num_patches, num_channels *
- pixel_values_videos shape: (num_patches, num_channels *
temporal_patch_size * patch_size * patch_size)
- video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
format
@ -179,7 +186,7 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
- nf: Number of video features
- hs: Hidden size
- nv: Number of videos
Historical context:
- video_embeds shape: (num_video_features, hidden_size)
- num_video_features varies based on the number and resolution of the
@ -905,6 +912,55 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor):
second_per_grid_ts=MultiModalFieldConfig.batched("video"),
)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_processor = self.info.get_image_processor(
**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
placeholder = {
"image": vocab[hf_processor.image_token],
"video": vocab[hf_processor.video_token],
}
merge_length = image_processor.merge_size**2
def get_replacement_qwen2vl(item_idx: int, modality: str):
out_item = out_mm_kwargs[modality][item_idx]
grid_thw = out_item[f"{modality}_grid_thw"].data
assert isinstance(grid_thw, torch.Tensor)
num_tokens = int(grid_thw.prod()) // merge_length
# EVS-specific code
video_pruning_rate = self.info.ctx.get_mm_config(
).video_pruning_rate
if (modality == "video" and video_pruning_rate is not None
and video_pruning_rate > 0.0):
num_tokens = compute_retained_tokens_count(
grid_thw,
image_processor.merge_size,
video_pruning_rate,
)
# End of EVS-specific code
return [placeholder[modality]] * num_tokens
return [
PromptReplacement(
modality=modality,
target=[placeholder[modality]],
replacement=partial(get_replacement_qwen2vl,
modality=modality),
) for modality in ("image", "video")
]
@MULTIMODAL_REGISTRY.register_processor(
Qwen2_5_VLMultiModalProcessor,
@ -912,7 +968,8 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor):
dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsLoRA, SupportsPP,
SupportsQuant):
SupportsQuant,
SupportsMultiModalPruning):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
@ -949,6 +1006,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.config = config
self.multimodal_config = multimodal_config
self.video_pruning_rate = multimodal_config.video_pruning_rate
self.is_multimodal_pruning_enabled = (
multimodal_config.is_multimodal_pruning_enabled())
if multimodal_config.get_limit_per_prompt("image") or \
multimodal_config.get_limit_per_prompt("video"):
@ -1090,6 +1150,36 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
return image_embeds.split(sizes)
def _postprocess_image_embeds_evs(
self, image_embeds_split: tuple[torch.Tensor, ...],
image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]:
"""
Append mrope positions to the embeddings of each image.
This is necessary to recover correct mrope
positions after video pruning.
Args:
image_embeds_split: Tuple of image embeddings for
each image item.
image_input: Image input data.
Returns:
Tuple of image embeddings for each image item.
Resulting embeddings will have 4 extra channels holding the
computed mrope positions.
"""
merge_size = self.visual.spatial_merge_size
grid_thw = image_input["image_grid_thw"]
grid_thw_list = grid_thw.tolist()
image_embeds_out = []
for emb, size in zip(image_embeds_split, grid_thw_list):
positions = compute_mrope_for_media(size,
merge_size).to(emb.device)
emb = torch.cat([emb, positions], dim=1)
image_embeds_out.append(emb)
image_embeds_split = image_embeds_out
return tuple(image_embeds_split)
def _process_video_input(
self,
video_input: Qwen2_5_VLVideoInputs) -> tuple[torch.Tensor, ...]:
@ -1119,6 +1209,114 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
return video_embeds.split(sizes)
def _postprocess_video_embeds_evs(
self, video_embeds_split: tuple[torch.Tensor, ...],
video_input: Qwen2_5_VLVideoInputs) -> tuple[torch.Tensor, ...]:
"""
Prunes video embeddings via Efficient Video Sampling (EVS)
and then appends mrope positions for each retained embedding.
Args:
video_embeds_split: Tuple of video embeddings for each video item.
video_input: Video input data.
Returns:
Tuple of video embeddings for each video item.
Resulting embeddings will have 4 extra channels holding the
computed mrope positions.
"""
grid_thw = video_input["video_grid_thw"]
assert grid_thw.ndim == 2
grid_thw_list = grid_thw.tolist()
merge_size = self.visual.spatial_merge_size
# Cast to long to match the original code
# https://github.com/huggingface/transformers/blob/41980ce93e775f6c88500c51c8db7946fc6a2add/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py#L491 # noqa
second_per_grid_ts = video_input["second_per_grid_ts"].long()
tokens_per_second = self.config.vision_config.tokens_per_second
video_embeds_out = []
for emb, size, video_second_per_grid_t in zip(video_embeds_split,
grid_thw_list,
second_per_grid_ts):
# For each video, we compute retention mask using EVS
retention_mask = compute_retention_mask(
emb,
size,
spatial_merge_size=self.visual.spatial_merge_size,
q=self.video_pruning_rate,
)
positions = compute_mrope_for_media(
size,
merge_size,
tokens_per_second=tokens_per_second,
video_second_per_grid=video_second_per_grid_t.item(),
).to(emb.device)
emb = emb[retention_mask]
positions = positions[retention_mask]
emb = torch.cat([emb, positions], dim=1)
video_embeds_out.append(emb)
return tuple(video_embeds_out)
def recompute_mrope_positions(
self,
input_ids: list[int],
multimodal_embeddings: tuple[torch.Tensor, ...],
mrope_positions: torch.LongTensor,
num_computed_tokens: int,
) -> tuple[tuple[torch.Tensor, ...], torch.Tensor, int]:
"""
Update part of the input mrope positions (starting at the
num_computed_tokens index). The original mrope_positions are computed
for the unpruned sequence and become incorrect once pruning occurs,
so after pruning media tokens we must reflect this in the
mrope_positions before feeding them to the LLM.
Args:
input_ids: (N,) All input tokens of the prompt (containing the
entire sequence).
multimodal_embeddings: Tuple of multimodal embeddings.
mrope_positions: Existing mrope positions (3, N) for the entire
sequence.
num_computed_tokens: The number of tokens computed so far.
Returns:
Tuple of (multimodal_embeddings, mrope_positions,
mrope_position_delta).
"""
image_token_id = self.config.image_token_id
video_token_id = self.config.video_token_id
vision_start_token_id = self.config.vision_start_token_id
# Device
device = (multimodal_embeddings[0].device
if len(multimodal_embeddings) else mrope_positions.device)
# Tensors
input_ids_t = torch.as_tensor(input_ids,
device=device,
dtype=torch.long)
# fmt: off
mm_embeddings_out = [mm[:, :-4] for mm in
multimodal_embeddings]
mm_embeddings_pos = [mm[:, -4:].permute(1, 0).long() for mm in
multimodal_embeddings]
# fmt: on
positions, mrope_positions_delta = recompute_mrope_positions(
input_ids_t,
mm_embeddings_pos,
mrope_positions,
num_computed_tokens,
vision_start_token_id,
image_token_id,
video_token_id,
)
return tuple(mm_embeddings_out), positions, mrope_positions_delta
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
mm_input_by_modality = {}
@ -1156,9 +1354,17 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_input = mm_input_by_modality[modality]
if modality == "image":
vision_embeddings = self._process_image_input(multimodal_input)
if self.is_multimodal_pruning_enabled:
vision_embeddings = self._postprocess_image_embeds_evs(
vision_embeddings, multimodal_input
)
multimodal_embeddings += vision_embeddings
if modality == "video":
video_embeddings = self._process_video_input(multimodal_input)
if self.is_multimodal_pruning_enabled:
video_embeddings = self._postprocess_video_embeds_evs(
video_embeddings, multimodal_input
)
multimodal_embeddings += video_embeddings
return multimodal_embeddings
@ -1184,6 +1390,10 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = self.get_input_embeddings(input_ids)
if image_input is not None:
image_embeds = self._process_image_input(image_input)
if self.is_multimodal_pruning_enabled:
image_embeds = self._postprocess_image_embeds_evs(
image_embeds, image_input
)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
@ -1193,6 +1403,10 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
if video_input is not None:
video_embeds = self._process_video_input(video_input)
if self.is_multimodal_pruning_enabled:
video_embeds = self._postprocess_video_embeds_evs(
video_embeds, video_input
)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,

vllm/multimodal/evs.py (new file, 273 lines added)
View File

@ -0,0 +1,273 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import typing
import torch
def compute_retained_tokens_count(video_size_thw: torch.LongTensor,
spatial_merge_size: int, q: float) -> int:
"""
Compute the number of retained tokens for a given video.
Method ensures that we retain all the tokens from the first frame
regardless of the pruning rate.
Args:
video_size_thw: The size of the video in the format of (T, H, W).
spatial_merge_size: The size of the spatial merge.
q: The pruning rate.
Returns:
The number of retained tokens.
"""
T, H, W = map(int, video_size_thw)
min_num_tokens = (H // spatial_merge_size) * (W // spatial_merge_size)
evs_num_tokens = int(T * min_num_tokens * (1 - q))
return max(min_num_tokens, evs_num_tokens)
def compute_retention_mask(
video_embeds: torch.Tensor,
video_size_thw: torch.LongTensor,
spatial_merge_size: int,
q: float,
) -> torch.Tensor:
"""
Computes the retention mask for input video embeddings.
Args:
video_embeds (`torch.Tensor`): The input video embeddings
of shape `(T * H * W // spatial_merge_size ^ 2, hidden_size)`
video_size_thw (`torch.LongTensor` of shape `(3)`):
The temporal, height and width of video.
spatial_merge_size: Size reduction for rows & cols dimensions.
q (`float`): Pruning rate factor in [0, 1).
Returns:
`torch.Tensor`: The retention mask for the video embeddings of
`(T * H * W // spatial_merge_size ^ 2)` shape.
"""
T, H, W = video_size_thw
# Use reshape instead of einops to avoid graph breaks
video_embeds = video_embeds.reshape(
T,
H // spatial_merge_size,
W // spatial_merge_size,
video_embeds.size(-1),
)
# Core EVS
similarity = torch.nn.functional.cosine_similarity(video_embeds[1:, ...],
video_embeds[:-1, ...],
dim=-1)
dissimilarity = 1 - similarity
# Always ensure we include all tokens from the first frame
dissimilarity = torch.cat(
[255 * torch.ones_like(video_embeds[:1, :, :, 0]), dissimilarity],
dim=0)
dissimilarity_flat = dissimilarity.view(-1)
order = torch.argsort(dissimilarity_flat,
dim=-1,
descending=True,
stable=True)
retain_num_tokens = compute_retained_tokens_count(video_size_thw,
spatial_merge_size, q)
topk_indices = order[:retain_num_tokens]
retention_mask = torch.zeros_like(dissimilarity_flat, dtype=torch.bool)
retention_mask[topk_indices] = True
retention_mask = retention_mask.reshape(dissimilarity.size())
mask = retention_mask.view(-1) # "T H W -> (T H W)"
return mask
def compute_mrope_for_media(
video_size_thw: torch.LongTensor,
spatial_merge_size: int,
tokens_per_second: float = 1.0,
video_second_per_grid: float = 1.0,
) -> torch.Tensor:
"""
Computes the mrope for video embeddings based on the grid dimensions.
Computed mrope positions match the original Qwen2.5-VL implementation,
but positions are built as if the media were the first element in the sequence.
Args:
video_size_thw: Media size (num frames, rows, cols)
spatial_merge_size: Size reduction for rows & cols dimensions.
tokens_per_second: Number of tokens per second.
video_second_per_grid: Number of seconds per temporal grid step.
Returns:
Tensor of shape `(T * H * W, 4)` where the first 3 channels of the
last dimension hold the mrope positions, while the last channel
contains the value of llm_grid_w repeated for all positions.
"""
llm_grid_t = video_size_thw[0]
llm_grid_h = video_size_thw[1] // spatial_merge_size
llm_grid_w = video_size_thw[2] // spatial_merge_size
t_index = ((torch.arange(llm_grid_t).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w).mul(
tokens_per_second * video_second_per_grid)).long().flatten())
h_index = (torch.arange(llm_grid_h).view(1, -1,
1).expand(llm_grid_t, -1,
llm_grid_w).flatten())
w_index = (torch.arange(llm_grid_w).view(1, 1, -1).expand(
llm_grid_t, llm_grid_h, -1).flatten())
llm_grid_w = (torch.tensor([llm_grid_w
]).view(1, 1,
1).expand(llm_grid_t, llm_grid_h,
llm_grid_w).flatten())
positions = torch.stack([t_index, h_index, w_index, llm_grid_w], dim=1)
return positions
def recompute_mrope_positions(
input_ids: torch.LongTensor,
multimodal_positions: list[torch.Tensor],
mrope_positions: torch.LongTensor,
num_computed_tokens: int,
vision_start_token_id: int,
image_token_id: int,
video_token_id: int,
) -> tuple[torch.LongTensor, int]:
"""
Update part of the input mrope positions.
The original mrope_positions are computed for the unpruned sequence, so once
we prune media tokens we must reflect this in the mrope positions fed to the
LLM.
This method supports the chunked prefill approach, where
multimodal_embeddings are passed to the LLM in chunks, so the input
multimodal_embeddings may contain none, some, or only part of the
multimodal embeddings for a given prompt.
Each entry of multimodal_positions has 4 channels
(the first 3 channels correspond to the original 3 mrope positions; the last
channel is the maximum width of the media, repeated). The provided
multimodal_positions do not reflect the location of the media in the
sequence - they are computed as if the media started at position 0.
The method works as follows: it recomputes mrope_positions starting from
the `num_computed_tokens` index for the total length of the provided
multimodal embeddings, and then shifts all text tokens that come after them.
It also handles the case when multimodal_embeddings is partial
(e.g. one media item is split across two prefill stages).
Args:
input_ids: (N,) All input tokens of the prompt (entire sequence).
multimodal_positions: List of mrope positions for each media item.
mrope_positions: Existing mrope positions (3, N) for the entire sequence.
num_computed_tokens: The number of tokens computed so far.
vision_start_token_id: Token indicating start of vision media.
image_token_id: Image token id
video_token_id: Video token id
Returns:
Tuple of (mrope_positions, mrope_position_delta).
"""
# Tensors
positions: torch.LongTensor = typing.cast(
torch.LongTensor, mrope_positions.clone()) # (3, N)
N = input_ids.numel()
image_mask = input_ids.eq(image_token_id)
video_mask = input_ids.eq(video_token_id)
media_mask = image_mask | video_mask
text_mask = ~media_mask
# Early exit: no media in this chunk
if len(multimodal_positions) == 0:
delta = (int((positions.max().item() + 1) -
N) if positions.numel() else -N)
return positions, delta
total_mm_tokens = torch.count_nonzero(media_mask)
seen_mm_tokens = torch.count_nonzero(media_mask[:num_computed_tokens])
# Early exit: we've updated positions for all media tokens
# (and consequently - for all remaining text tokens)
if seen_mm_tokens == total_mm_tokens:
delta = (int((positions.max().item() + 1) -
N) if positions.numel() else -N)
return positions, delta
vision_start_indices = (input_ids == vision_start_token_id).nonzero(
as_tuple=True)[0]
for mm_pos in multimodal_positions:
# Each mm_pos can be a complete embedding for single media
# or it can be a part of a single media (due to chunked prefill)
# Cases to cover
# - Current prefill chunk has no vision start indexes at all
# - Vision start token appeared in previous prefill round
# - Regular case
seen_vision_start_indices = vision_start_indices[vision_start_indices <
num_computed_tokens]
if len(seen_vision_start_indices):
# If we have encountered some vision start indexes,
# then we should check the condition:
# | --- prefill 1 ------| ---- prefill 2 ----- |
# | TTTTTTTTTSVVVVVVVVVV|VVVVVVTTTTTTTTTTTTTTTT|
last_vision_start_token = seen_vision_start_indices[-1]
seem_mm_tokens_before_last_vision_start = torch.count_nonzero(
media_mask[:last_vision_start_token])
in_the_middle_of_media = (
seen_mm_tokens > seem_mm_tokens_before_last_vision_start)
if in_the_middle_of_media:
mm_embeddings_seen = (seen_mm_tokens -
seem_mm_tokens_before_last_vision_start)
global_mm_start = last_vision_start_token
else:
# We have completed previous mm_embedding part and
# ready to start a new one
next_vision_start_token = vision_start_indices[
vision_start_indices >= num_computed_tokens][0]
mm_embeddings_seen = 0
global_mm_start = next_vision_start_token
else:
# If there were no vision start indexes so far,
# let's find first vision start index
next_vision_start_token = vision_start_indices[
vision_start_indices >= num_computed_tokens][0]
mm_embeddings_seen = 0
global_mm_start = next_vision_start_token
# Offset right after vision_start_token
base = positions[-1, global_mm_start] + 1
local_start = global_mm_start + 1 + mm_embeddings_seen
local_end = local_start + mm_pos.shape[1]
positions[:, local_start:local_end] = mm_pos[0:3] + base
# mm_pos[3, 0] is the max width of the media
offset = mm_pos[3, 0] + base
text_pos_sum = torch.cumsum(text_mask[local_end:].long(), dim=0)
positions[:, local_end:N] = text_pos_sum + offset - 1
# Include distance to the next vision start token
num_computed_tokens += mm_pos.shape[1]
mrope_positions_delta = (positions.max() + 1 - N).item()
return positions, mrope_positions_delta
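
A worked sketch of the EVS arithmetic on dummy data (hidden size, grid and pruning rate are illustrative; real embeddings come from the vision encoder, and the model passes grid_thw.tolist() entries as the size argument):

import torch
from vllm.multimodal.evs import (compute_mrope_for_media,
                                 compute_retained_tokens_count,
                                 compute_retention_mask)

# 16-frame video with a 32x32 patch grid; spatial_merge_size=2 gives a
# 16x16 merged token grid per frame, i.e. 256 tokens per frame.
video_size_thw = [16, 32, 32]
merge_size, q = 2, 0.75

# First frame is always kept: min = (32 // 2) * (32 // 2) = 256 tokens.
# EVS budget = int(16 * 256 * (1 - 0.75)) = 1024 -> max(256, 1024) = 1024.
retained = compute_retained_tokens_count(video_size_thw, merge_size, q)
assert retained == 1024

# Dummy merged video embeddings: (T * H/m * W/m, hidden) = (4096, 64).
video_embeds = torch.randn(16 * 16 * 16, 64)
mask = compute_retention_mask(video_embeds, video_size_thw,
                              spatial_merge_size=merge_size, q=q)
assert mask.shape == (4096,) and int(mask.sum()) == retained

# Per-token mrope positions (t, h, w) plus the repeated grid width, built as
# if the video sat at position 0; pruned with the same mask before being
# appended to the retained embeddings as 4 extra channels.
positions = compute_mrope_for_media(video_size_thw, merge_size)  # (4096, 4)
pruned_embeds = video_embeds[mask]                               # (1024, 64)
pruned_positions = positions[mask]                               # (1024, 4)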

View File

@ -40,11 +40,15 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.models.interfaces import (SupportsMultiModal,
is_mixture_of_experts,
supports_eagle3,
supports_mrope,
supports_multimodal_pruning,
supports_transcription)
# yapf: enable
from vllm.model_executor.models.interfaces_base import (
VllmModelForPooling, is_pooling_model, is_text_generation_model)
from vllm.multimodal import MULTIMODAL_REGISTRY
@ -206,7 +210,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.enable_prompt_embeds = model_config.enable_prompt_embeds
self.is_multimodal_raw_input_only_model = (
model_config.is_multimodal_raw_input_only_model)
# This will be overridden in load_model()
self.is_multimodal_pruning_enabled = False
self.max_model_len = model_config.max_model_len
self.dcp_world_size = self.parallel_config.decode_context_parallel_size
self.max_num_tokens = scheduler_config.max_num_batched_tokens
@ -1530,29 +1535,47 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# encoder outputs.
model = cast(SupportsMultiModal, self.model)
encoder_outputs = []
for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs,
device=self.device,
pin_memory=self.pin_memory,
merge_by_field_config=model.merge_by_field_config,
):
# Run the encoder.
# `curr_group_outputs` is either of the following:
# 1. A tensor of shape (num_items, feature_size, hidden_size)
# in case feature_size is fixed across all multimodal items.
# 2. A list or tuple (length: num_items) of tensors, each of shape
# (feature_size, hidden_size) in case the feature size is dynamic
# depending on the input multimodal items.
curr_group_outputs = model.get_multimodal_embeddings(
**mm_kwargs_group)
# (ekhvedchenia): Temporary hack to limit peak memory usage when
# processing multimodal data. This solves the issue with the scheduler
# putting too many video samples into a single batch. The scheduler
# uses the pruned vision token count when comparing against the compute
# budget, which is incorrect (either the input media size or the
# non-pruned output vision token count should be considered).
curr_group_outputs = []
if self.is_multimodal_pruning_enabled and modality == "video":
micro_batch_size = 1
for i in range(0, num_items, micro_batch_size):
micro_batch_mm_inputs = dict(
(k, v[i:i + micro_batch_size])
for k, v in mm_kwargs_group.items())
micro_batch_outputs = model.get_multimodal_embeddings(
**micro_batch_mm_inputs)
curr_group_outputs.extend(micro_batch_outputs)
else:
# Run the encoder.
# `curr_group_outputs` is either of the following:
# 1. A tensor of shape (num_items, feature_size, hidden_size)
# in case feature_size is fixed across all multimodal items.
# 2. A list or tuple (length: num_items) of tensors,
# each of shape (feature_size, hidden_size) in case the feature
# size is dynamic depending on the input multimodal items.
curr_group_outputs = model.get_multimodal_embeddings(
**mm_kwargs_group)
sanity_check_mm_encoder_outputs(
curr_group_outputs,
expected_num_items=num_items,
)
for output in curr_group_outputs:
encoder_outputs.append(output)
encoder_outputs.extend(curr_group_outputs)
# Cache the encoder outputs by mm_hash
for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
@ -1566,8 +1589,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
scheduler_output: "SchedulerOutput",
shift_computed_tokens: int = 0,
) -> list[torch.Tensor]:
should_sync_mrope_positions = False
mm_embeds: list[torch.Tensor] = []
for req_id in self.input_batch.req_ids:
mm_embeds_req: list[torch.Tensor] = []
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
req_state = self.requests[req_id]
@ -1609,7 +1635,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
encoder_output[start_idx:end_idx],
is_embed=is_embed,
)
mm_embeds.append(mm_embeds_item)
mm_embeds_req.append(mm_embeds_item)
if self.is_multimodal_pruning_enabled and self.uses_mrope:
should_sync_mrope_positions = True
mm_embeds_req, new_mrope_positions, new_delta = (
self.model.recompute_mrope_positions(
input_ids=req_state.prompt_token_ids,
multimodal_embeddings=mm_embeds_req,
mrope_positions=req_state.mrope_positions,
num_computed_tokens=req_state.num_computed_tokens,
))
assert req_state.mrope_positions is not None
req_state.mrope_positions.copy_(new_mrope_positions)
req_state.mrope_position_delta = new_delta
mm_embeds.extend(mm_embeds_req)
if should_sync_mrope_positions:
self._calc_mrope_positions(scheduler_output)
self.mrope_positions.copy_to_gpu(
scheduler_output.total_num_scheduled_tokens)
return mm_embeds
def _extract_encoder_inputs(
@ -2589,6 +2636,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
time_after_load - time_before_load)
prepare_communication_buffer_for_model(self.model)
self.is_multimodal_pruning_enabled = (supports_multimodal_pruning(
self.model) and self.model_config.multimodal_config.
is_multimodal_pruning_enabled())
if is_mixture_of_experts(
self.model) and self.parallel_config.enable_eplb:
logger.info("EPLB is enabled for model %s.",
@ -2843,7 +2894,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
Args:
num_tokens: Number of tokens to run the dummy forward pass.
cudagraph_runtime_mode: used to control the behavior.
- if not set will determine the cudagraph mode based on using
- if not set will determine the cudagraph mode based on using
the self.cudagraph_dispatcher.
- CUDAGraphMode.NONE: No cudagraph, for warm up and profile run
- CUDAGraphMode.PIECEWISE: Piecewise cudagraph.