diff --git a/tests/models/multimodal/generation/test_qwen2_5_vl.py b/tests/models/multimodal/generation/test_qwen2_5_vl.py new file mode 100644 index 0000000000000..1dc3188d60bd8 --- /dev/null +++ b/tests/models/multimodal/generation/test_qwen2_5_vl.py @@ -0,0 +1,132 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.multimodal.video import sample_frames_from_video + +from ....conftest import VIDEO_ASSETS + +models = ["Qwen/Qwen2.5-VL-3B-Instruct"] +target_dtype = "bfloat16" + +VIDEO_PLACEHOLDER = "<|vision_start|><|video_pad|><|vision_end|>" + + +def qwen2_5_vl_chat_template(*query): + return f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{''.join(query)}<|im_end|><|im_start|>assistant\n" # noqa: E501 + + +VIDEO_PROMPTS = VIDEO_ASSETS.prompts({ + "baby_reading": + qwen2_5_vl_chat_template( + VIDEO_PLACEHOLDER, + "Describe this video with a short sentence ", + "(no more than 20 words)", + ), +}) + + +@pytest.mark.core_model +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("video_pruning_rate", [0.0, 0.75]) +@pytest.mark.parametrize("num_frames", [16]) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_tokens", [128]) +def test_qwen2_5_vl_evs_functionality(vllm_runner, video_assets, model, + video_pruning_rate: float, + num_frames: int, dtype: str, + max_tokens: int) -> None: + """Test EVS (Efficient Video Sampling) functionality with different + pruning rates. + """ + + # Sample frames from video assets + sampled_vids = [ + sample_frames_from_video(asset.np_ndarrays, num_frames) + for asset in video_assets + ] + + prompts = [VIDEO_PROMPTS[0]] + videos = [sampled_vids[0]] + + # Initialize model with EVS configuration + with vllm_runner(model, + runner="generate", + max_model_len=4000, + max_num_seqs=1, + dtype=dtype, + limit_mm_per_prompt={"video": 1}, + tensor_parallel_size=1, + video_pruning_rate=video_pruning_rate) as vllm_model: + + # Generate output - this should not crash + outputs = vllm_model.generate_greedy(prompts, + max_tokens, + videos=videos) + + # Basic validation that we got a response + assert len(outputs) == 1 + output_ids, output_text = outputs[0] + + # Ensure we got some output + assert len(output_ids) > 0 + assert len(output_text) > 0 + + # Ensure the output is a string + assert isinstance(output_text, str) + + +@pytest.mark.core_model +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("video_pruning_rate", [0.0, 0.75]) +@pytest.mark.parametrize("num_frames", [16]) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_tokens", [128]) +def test_qwen2_5_vl_evs_batched_videos(vllm_runner, video_assets, model, + video_pruning_rate: float, + num_frames: int, dtype: str, + max_tokens: int) -> None: + """Test EVS functionality with batched videos. + + This test validates that: + 1. The model handles batched video inputs correctly with EVS + 2. Both pruning configurations work with multiple videos + 3. 
The model doesn't crash when processing multiple videos simultaneously + """ + # Sample frames from video assets + sampled_vids = [ + sample_frames_from_video(asset.np_ndarrays, num_frames) + for asset in video_assets + ] + + # Test batched videos + prompts = [VIDEO_PROMPTS[0], VIDEO_PROMPTS[0]] + videos = [sampled_vids[0], + sampled_vids[0]] # Use same video twice for testing + + # Initialize model with EVS configuration + with vllm_runner(model, + runner="generate", + max_model_len=4000, + max_num_seqs=2, + dtype=dtype, + limit_mm_per_prompt={"video": 2}, + tensor_parallel_size=1, + video_pruning_rate=video_pruning_rate) as vllm_model: + + # Generate output - this should not crash + outputs = vllm_model.generate_greedy(prompts, + max_tokens, + videos=videos) + + # Basic validation that we got responses for both videos + assert len(outputs) == 2 + + for output_ids, output_text in outputs: + # Ensure we got some output for each video + assert len(output_ids) > 0 + assert len(output_text) > 0 + + # Ensure the output is a string + assert isinstance(output_text, str) diff --git a/vllm/config/model.py b/vllm/config/model.py index 302260e7e9936..da01d6d4480c5 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -283,6 +283,7 @@ class ModelConfig: mm_encoder_tp_mode: InitVar[Optional[MMEncoderTPMode]] = None interleave_mm_strings: InitVar[Optional[bool]] = None skip_mm_profiling: InitVar[Optional[bool]] = None + video_pruning_rate: InitVar[Optional[float]] = None def compute_hash(self) -> str: """ @@ -311,6 +312,7 @@ class ModelConfig: factors.append(self.override_generation_config) factors.append(self.rope_scaling) factors.append(self.rope_theta) + factors.append(self.video_pruning_rate) # hf_config can control how the model looks! try: @@ -338,17 +340,19 @@ class ModelConfig: return hashlib.sha256(str(factors).encode()).hexdigest() def __post_init__( - self, - # Multimodal config init vars - limit_mm_per_prompt: Optional[dict[str, int]], - media_io_kwargs: Optional[dict[str, dict[str, Any]]], - mm_processor_kwargs: Optional[dict[str, Any]], - mm_processor_cache_gb: Optional[float], - mm_processor_cache_type: Optional[MMCacheType], - mm_shm_cache_max_object_size_mb: Optional[int], - mm_encoder_tp_mode: Optional[MMEncoderTPMode], - interleave_mm_strings: Optional[bool], - skip_mm_profiling: Optional[bool]) -> None: + self, + # Multimodal config init vars + limit_mm_per_prompt: Optional[dict[str, int]], + media_io_kwargs: Optional[dict[str, dict[str, Any]]], + mm_processor_kwargs: Optional[dict[str, Any]], + mm_processor_cache_gb: Optional[float], + mm_processor_cache_type: Optional[MMCacheType], + mm_shm_cache_max_object_size_mb: Optional[int], + mm_encoder_tp_mode: Optional[MMEncoderTPMode], + interleave_mm_strings: Optional[bool], + skip_mm_profiling: Optional[bool], + video_pruning_rate: Optional[float], + ) -> None: # Set the default seed to 0 in V1. 
# NOTE(woosuk): In V0, we set the default seed to None because the # driver worker shares the same process as the user process, and thus @@ -612,6 +616,7 @@ class ModelConfig: mm_encoder_tp_mode=mm_encoder_tp_mode, interleave_mm_strings=interleave_mm_strings, skip_mm_profiling=skip_mm_profiling, + video_pruning_rate=video_pruning_rate, ) mm_config_kwargs = { diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py index 1b93b520f33f9..569de95799002 100644 --- a/vllm/config/multimodal.py +++ b/vllm/config/multimodal.py @@ -78,6 +78,11 @@ class MultiModalConfig: This reduces engine startup time but shifts the responsibility to users for estimating the peak memory usage of the activation of multimodal encoder and embedding cache.""" + video_pruning_rate: Optional[float] = None + """Sets pruning rate for video pruning via Efficient Video Sampling. + Value sits in range [0;1) and determines fraction of media tokens + from each video to be pruned. + """ def compute_hash(self) -> str: """ @@ -118,3 +123,7 @@ class MultiModalConfig: """ kwargs = self.mm_processor_kwargs or {} return kwargs | dict(inference_kwargs) + + def is_multimodal_pruning_enabled(self): + return (self.video_pruning_rate is not None + and self.video_pruning_rate > 0) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c894477d34b5a..7b5ed67d0adbb 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -391,6 +391,7 @@ class EngineArgs: mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode io_processor_plugin: Optional[str] = None skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling + video_pruning_rate: float = MultiModalConfig.video_pruning_rate # LoRA fields enable_lora: bool = False enable_lora_bias: bool = LoRAConfig.bias_enabled @@ -813,6 +814,9 @@ class EngineArgs: multimodal_group.add_argument("--skip-mm-profiling", **multimodal_kwargs["skip_mm_profiling"]) + multimodal_group.add_argument( + "--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"]) + # LoRA related configs lora_kwargs = get_kwargs(LoRAConfig) lora_group = parser.add_argument_group( @@ -1032,6 +1036,7 @@ class EngineArgs: model_impl=self.model_impl, override_attention_dtype=self.override_attention_dtype, logits_processors=self.logits_processors, + video_pruning_rate=self.video_pruning_rate, io_processor_plugin=self.io_processor_plugin, ) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index e5cb5eb0bacb3..f13e590cd243b 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -115,6 +115,42 @@ class SupportsMultiModal(Protocol): ... +@runtime_checkable +class SupportsMultiModalPruning(Protocol): + """The interface required for models that support returning both input + embeddings and positions. Model may require custom positions for dynamic + pruning of multimodal embeddings. + """ + supports_multimodal_pruning: ClassVar[Literal[True]] = True + + def recompute_mrope_positions( + self, input_ids: list[int], + multimodal_embeddings: MultiModalEmbeddings, + mrope_positions: torch.LongTensor, num_computed_tokens: int + ) -> tuple[MultiModalEmbeddings, Tensor, int]: + """ + Update part of input mrope positions (starting with + num_computed_tokens index). Original mrope_positions are computed + for unpruned sequence and becomes incorrect once pruning occurs, + so once we prune media tokens we should reflect this in the + mrope_positions before we feed it to LLM. 
+ + Args: + input_ids: (N,) All input tokens of the prompt containing + entire sequence. + multimodal_embeddings: Tuple of multimodal embeddings that + fits into the prefill chunk that is being processed. + mrope_positions: Existing mrope positions (3, N) for entire + sequence + num_computed_tokens: A number of computed tokens so far. + + Returns: + Tuple of (multimodal_embeddings, mrope_positions, + mrope_position_delta). + """ + ... + + @overload def supports_multimodal( model: type[object]) -> TypeIs[type[SupportsMultiModal]]: @@ -142,6 +178,25 @@ def supports_multimodal_encoder_tp_data( return getattr(model, "supports_encoder_tp_data", False) +@overload +def supports_multimodal_pruning( + model: type[object]) -> TypeIs[type[SupportsMultiModalPruning]]: + ... + + +@overload +def supports_multimodal_pruning( + model: object) -> TypeIs[SupportsMultiModalPruning]: + ... + + +def supports_multimodal_pruning( + model: Union[type[object], object], +) -> Union[TypeIs[type[SupportsMultiModalPruning]], + TypeIs[SupportsMultiModalPruning]]: + return getattr(model, "supports_multimodal_pruning", False) + + @runtime_checkable class SupportsScoreTemplate(Protocol): """The interface required for all models that support score template.""" diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index b740e6d87b745..bd6c0b162cb42 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -25,9 +25,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2.5-VL model compatible with HuggingFace weights.""" -from collections.abc import Iterable, Mapping +from collections.abc import Iterable, Mapping, Sequence from functools import lru_cache, partial -from typing import Annotated, Callable, Literal, Optional, Union +from typing import Annotated, Any, Callable, Literal, Optional, Union import torch import torch.nn as nn @@ -58,7 +58,13 @@ from vllm.model_executor.layers.quantization.gptq_marlin import ( from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalFieldConfig +from vllm.multimodal.evs import (compute_mrope_for_media, + compute_retained_tokens_count, + compute_retention_mask, + recompute_mrope_positions) +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import PromptReplacement, PromptUpdate from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope @@ -66,7 +72,8 @@ from vllm.utils import is_pin_memory_available from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP, SupportsQuant) + SupportsMultiModal, SupportsMultiModalPruning, + SupportsPP, SupportsQuant) from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo, apply_rotary_pos_emb_vision) @@ -86,9 +93,9 @@ class Qwen2_5_VLImagePixelInputs(TensorSchema): - np: Number of patches - ni: Number of images - cps: Number of channels * patch_size * patch_size - + Historical context: - - pixel_values shape: (num_patches, 
num_channels * patch_size * + - pixel_values shape: (num_patches, num_channels * patch_size * patch_size) - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) formatnum_channels * patch_size * patch_size @@ -112,7 +119,7 @@ class Qwen2_5_VLImageEmbeddingInputs(TensorSchema): - nf: Number of image features - hs: Hidden size - ni: Number of images - + Historical context: - image_embeds shape: (num_image_features, hidden_size) - num_image_features varies based on the number and resolution of the @@ -143,11 +150,11 @@ class Qwen2_5_VLVideoPixelInputs(TensorSchema): Dimensions: - np: Number of patches - nv: Number of videos - - ctps: Number of channels * temporal_patch_size * patch_size * + - ctps: Number of channels * temporal_patch_size * patch_size * patch_size - + Historical context: - - pixel_values_videos shape: (num_patches, num_channels * + - pixel_values_videos shape: (num_patches, num_channels * temporal_patch_size * patch_size * patch_size) - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) format @@ -179,7 +186,7 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema): - nf: Number of video features - hs: Hidden size - nv: Number of videos - + Historical context: - video_embeds shape: (num_video_features, hidden_size) - num_video_features varies based on the number and resolution of the @@ -905,6 +912,55 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor): second_per_grid_ts=MultiModalFieldConfig.batched("video"), ) + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_processor = self.info.get_image_processor( + **hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + + placeholder = { + "image": vocab[hf_processor.image_token], + "video": vocab[hf_processor.video_token], + } + + merge_length = image_processor.merge_size**2 + + def get_replacement_qwen2vl(item_idx: int, modality: str): + out_item = out_mm_kwargs[modality][item_idx] + grid_thw = out_item[f"{modality}_grid_thw"].data + assert isinstance(grid_thw, torch.Tensor) + + num_tokens = int(grid_thw.prod()) // merge_length + + # EVS-specific code + video_pruning_rate = self.info.ctx.get_mm_config( + ).video_pruning_rate + if (modality == "video" and video_pruning_rate is not None + and video_pruning_rate > 0.0): + num_tokens = compute_retained_tokens_count( + grid_thw, + image_processor.merge_size, + video_pruning_rate, + ) + # End of EVS-specific code + + return [placeholder[modality]] * num_tokens + + return [ + PromptReplacement( + modality=modality, + target=[placeholder[modality]], + replacement=partial(get_replacement_qwen2vl, + modality=modality), + ) for modality in ("image", "video") + ] + @MULTIMODAL_REGISTRY.register_processor( Qwen2_5_VLMultiModalProcessor, @@ -912,7 +968,8 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor): dummy_inputs=Qwen2_5_VLDummyInputsBuilder) class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, - SupportsQuant): + SupportsQuant, + SupportsMultiModalPruning): packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], @@ -949,6 +1006,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.config = config self.multimodal_config = 
multimodal_config
+        self.video_pruning_rate = multimodal_config.video_pruning_rate
+        self.is_multimodal_pruning_enabled = (
+            multimodal_config.is_multimodal_pruning_enabled())
 
         if multimodal_config.get_limit_per_prompt("image") or \
             multimodal_config.get_limit_per_prompt("video"):
@@ -1090,6 +1150,36 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         return image_embeds.split(sizes)
 
+    def _postprocess_image_embeds_evs(
+            self, image_embeds_split: tuple[torch.Tensor, ...],
+            image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]:
+        """
+        Append mrope positions for each image.
+        This is necessary to recover correct mrope
+        positions after video pruning.
+
+        Args:
+            image_embeds_split: Tuple of image embeddings for
+                each image item.
+            image_input: Image input data.
+
+        Returns:
+            Tuple of image embeddings for each image item.
+            Resulting embeddings will have 4 extra channels for
+            computed mrope positions.
+        """
+        merge_size = self.visual.spatial_merge_size
+        grid_thw = image_input["image_grid_thw"]
+        grid_thw_list = grid_thw.tolist()
+        image_embeds_out = []
+        for emb, size in zip(image_embeds_split, grid_thw_list):
+            positions = compute_mrope_for_media(size,
+                                                merge_size).to(emb.device)
+            emb = torch.cat([emb, positions], dim=1)
+            image_embeds_out.append(emb)
+        image_embeds_split = image_embeds_out
+        return tuple(image_embeds_split)
+
     def _process_video_input(
             self,
             video_input: Qwen2_5_VLVideoInputs) -> tuple[torch.Tensor, ...]:
@@ -1119,6 +1209,114 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         return video_embeds.split(sizes)
 
+    def _postprocess_video_embeds_evs(
+            self, video_embeds_split: tuple[torch.Tensor, ...],
+            video_input: Qwen2_5_VLVideoInputs) -> tuple[torch.Tensor, ...]:
+        """
+        Prunes video embeddings via Efficient Video Sampling (EVS)
+        and then appends mrope positions for each retained embedding.
+
+        Args:
+            video_embeds_split: Tuple of video embeddings for each video item.
+            video_input: Video input data.
+
+        Returns:
+            Tuple of video embeddings for each video item.
+            Resulting embeddings will have 4 extra channels for
+            computed mrope positions.
+ """ + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() + merge_size = self.visual.spatial_merge_size + + # Cast to long to match the original code + # https://github.com/huggingface/transformers/blob/41980ce93e775f6c88500c51c8db7946fc6a2add/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py#L491 # noqa + second_per_grid_ts = video_input["second_per_grid_ts"].long() + tokens_per_second = self.config.vision_config.tokens_per_second + + video_embeds_out = [] + for emb, size, video_second_per_grid_t in zip(video_embeds_split, + grid_thw_list, + second_per_grid_ts): + # For each video, we compute retention mask using EVS + retention_mask = compute_retention_mask( + emb, + size, + spatial_merge_size=self.visual.spatial_merge_size, + q=self.video_pruning_rate, + ) + positions = compute_mrope_for_media( + size, + merge_size, + tokens_per_second=tokens_per_second, + video_second_per_grid=video_second_per_grid_t.item(), + ).to(emb.device) + + emb = emb[retention_mask] + positions = positions[retention_mask] + emb = torch.cat([emb, positions], dim=1) + video_embeds_out.append(emb) + return tuple(video_embeds_out) + + def recompute_mrope_positions( + self, + input_ids: list[int], + multimodal_embeddings: tuple[torch.Tensor, ...], + mrope_positions: torch.LongTensor, + num_computed_tokens: int, + ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor, int]: + """ + Update part of input mrope positions (starting with + num_computed_tokens index). Original mrope_positions are computed + for unpruned sequence and becomes incorrect once pruning occurs, + so once we prune media tokens we should reflect this in the + mrope_positions before we feed it to LLM. + + Args: + input_ids: (N,) All input tokens of the prompt (Containing + entire sequence). + multimodal_embeddings: Tuple of multimodal embeddings. + mrope_positions: Existing mrope positions (3, N) for entire + sequence + num_computed_tokens: A number of computed tokens so far. + + Returns: + Tuple of (multimodal_embeddings, mrope_positions, + mrope_position_delta). 
+        """
+        image_token_id = self.config.image_token_id
+        video_token_id = self.config.video_token_id
+        vision_start_token_id = self.config.vision_start_token_id
+
+        # Device
+        device = (multimodal_embeddings[0].device
+                  if len(multimodal_embeddings) else mrope_positions.device)
+
+        # Tensors
+        input_ids_t = torch.as_tensor(input_ids,
+                                      device=device,
+                                      dtype=torch.long)
+
+        # fmt: off
+        mm_embeddings_out = [mm[:, :-4] for mm in
+                             multimodal_embeddings]
+        mm_embeddings_pos = [mm[:, -4:].permute(1, 0).long() for mm in
+                             multimodal_embeddings]
+        # fmt: on
+
+        positions, mrope_positions_delta = recompute_mrope_positions(
+            input_ids_t,
+            mm_embeddings_pos,
+            mrope_positions,
+            num_computed_tokens,
+            vision_start_token_id,
+            image_token_id,
+            video_token_id,
+        )
+
+        return tuple(mm_embeddings_out), positions, mrope_positions_delta
+
     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
         mm_input_by_modality = {}
 
@@ -1156,9 +1354,17 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
             multimodal_input = mm_input_by_modality[modality]
             if modality == "image":
                 vision_embeddings = self._process_image_input(multimodal_input)
+                if self.is_multimodal_pruning_enabled:
+                    vision_embeddings = self._postprocess_image_embeds_evs(
+                        vision_embeddings, multimodal_input
+                    )
                 multimodal_embeddings += vision_embeddings
             if modality == "video":
                 video_embeddings = self._process_video_input(multimodal_input)
+                if self.is_multimodal_pruning_enabled:
+                    video_embeddings = self._postprocess_video_embeds_evs(
+                        video_embeddings, multimodal_input
+                    )
                 multimodal_embeddings += video_embeddings
 
         return multimodal_embeddings
@@ -1184,6 +1390,10 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
             inputs_embeds = self.get_input_embeddings(input_ids)
             if image_input is not None:
                 image_embeds = self._process_image_input(image_input)
+                if self.is_multimodal_pruning_enabled:
+                    image_embeds = self._postprocess_image_embeds_evs(
+                        image_embeds, image_input
+                    )
                 inputs_embeds = merge_multimodal_embeddings(
                     input_ids,
                     inputs_embeds,
@@ -1193,6 +1403,10 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
 
             if video_input is not None:
                 video_embeds = self._process_video_input(video_input)
+                if self.is_multimodal_pruning_enabled:
+                    video_embeds = self._postprocess_video_embeds_evs(
+                        video_embeds, video_input
+                    )
                 inputs_embeds = merge_multimodal_embeddings(
                     input_ids,
                     inputs_embeds,
diff --git a/vllm/multimodal/evs.py b/vllm/multimodal/evs.py
new file mode 100644
index 0000000000000..056f3d9059681
--- /dev/null
+++ b/vllm/multimodal/evs.py
@@ -0,0 +1,273 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import typing
+
+import torch
+
+
+def compute_retained_tokens_count(video_size_thw: torch.LongTensor,
+                                  spatial_merge_size: int, q: float) -> int:
+    """
+    Compute the number of retained tokens for a given video.
+    The method ensures that we retain all the tokens from the first frame
+    regardless of the pruning rate.
+ + Args: + video_size_thw: The size of the video in the format of (T, H, W). + spatial_merge_size: The size of the spatial merge. + q: The pruning rate. + + Returns: + The number of retained tokens. + """ + T, H, W = map(int, video_size_thw) + min_num_tokens = (H // spatial_merge_size) * (W // spatial_merge_size) + evs_num_tokens = int(T * min_num_tokens * (1 - q)) + return max(min_num_tokens, evs_num_tokens) + + +def compute_retention_mask( + video_embeds: torch.Tensor, + video_size_thw: torch.LongTensor, + spatial_merge_size: int, + q: float, +) -> torch.Tensor: + """ + Computes the retention mask for input video embeddings. + + Args: + video_embeds (`torch.Tensor`): The input video embeddings + of shape `(T * H * W // spatial_merge_size ^ 2, hidden_size)` + video_size_thw (`torch.LongTensor` of shape `(3)`): + The temporal, height and width of video. + spatial_merge_size: Size reduction for rows & cols dimensions. + q: (`float`): Pruning rate factor [0,1) + + Returns: + `torch.Tensor`: The retention mask for the video embeddings of + `(T * H * W // spatial_merge_size ^ 2)` shape. + """ + T, H, W = video_size_thw + + # Use reshape instead of einops to avoid graph breaks + video_embeds = video_embeds.reshape( + T, + H // spatial_merge_size, + W // spatial_merge_size, + video_embeds.size(-1), + ) + + # Core EVS + similarity = torch.nn.functional.cosine_similarity(video_embeds[1:, ...], + video_embeds[:-1, ...], + dim=-1) + dissimilarity = 1 - similarity + + # Always ensure we include all tokens from the first frame + dissimilarity = torch.cat( + [255 * torch.ones_like(video_embeds[:1, :, :, 0]), dissimilarity], + dim=0) + + dissimilarity_flat = dissimilarity.view(-1) + order = torch.argsort(dissimilarity_flat, + dim=-1, + descending=True, + stable=True) + retain_num_tokens = compute_retained_tokens_count(video_size_thw, + spatial_merge_size, q) + topk_indices = order[:retain_num_tokens] + + retention_mask = torch.zeros_like(dissimilarity_flat, dtype=torch.bool) + retention_mask[topk_indices] = True + retention_mask = retention_mask.reshape(dissimilarity.size()) + + mask = retention_mask.view(-1) # "T H W -> (T H W)" + return mask + + +def compute_mrope_for_media( + video_size_thw: torch.LongTensor, + spatial_merge_size: int, + tokens_per_second: float = 1.0, + video_second_per_grid: float = 1.0, +) -> torch.Tensor: + """ + Computes the mrope for video embeddings based on the grid dimensions. + Computed mrope positions match original qwen 2.5 implementation, + but positions are built for media being the first element in sequence. + + Args: + video_size_thw: Media size (num frames, rows, cols) + spatial_merge_size: Size reduction for rows & cols dimensions. + tokens_per_second: Number of tokens per second. + video_second_per_grid: Number of seconds per video. + + Returns: + Tensor of shape `(T * H * W, 4)` where last dimension + represents mrope positions [0:3), while the last channel + contains value of llm_grid_w repeated for all positions. 
+    """
+    llm_grid_t = video_size_thw[0]
+    llm_grid_h = video_size_thw[1] // spatial_merge_size
+    llm_grid_w = video_size_thw[2] // spatial_merge_size
+
+    t_index = ((torch.arange(llm_grid_t).view(-1, 1).expand(
+        -1, llm_grid_h * llm_grid_w).mul(
+            tokens_per_second * video_second_per_grid)).long().flatten())
+    h_index = (torch.arange(llm_grid_h).view(1, -1,
+                                             1).expand(llm_grid_t, -1,
+                                                       llm_grid_w).flatten())
+    w_index = (torch.arange(llm_grid_w).view(1, 1, -1).expand(
+        llm_grid_t, llm_grid_h, -1).flatten())
+    llm_grid_w = (torch.tensor([llm_grid_w
+                                ]).view(1, 1,
+                                        1).expand(llm_grid_t, llm_grid_h,
+                                                  llm_grid_w).flatten())
+
+    positions = torch.stack([t_index, h_index, w_index, llm_grid_w], dim=1)
+    return positions
+
+
+def recompute_mrope_positions(
+    input_ids: torch.LongTensor,
+    multimodal_positions: list[torch.Tensor],
+    mrope_positions: torch.LongTensor,
+    num_computed_tokens: int,
+    vision_start_token_id: int,
+    image_token_id: int,
+    video_token_id: int,
+) -> tuple[torch.LongTensor, int]:
+    """
+    Update part of the input mrope positions.
+    Original mrope_positions are computed for the unpruned sequence and become
+    incorrect once media tokens are pruned; the pruning must be reflected here.
+
+    This method supports the chunked-prefill approach, where
+    multimodal_embeddings are passed to the LLM in chunks, so the input
+    multimodal_embeddings may contain none, some, or only part of the
+    multimodal embeddings for a given prompt.
+
+    Each entry of multimodal_positions has 4 channels
+    (the first 3 match the original 3 mrope positions; the last channel
+    is the maximum width of the media, repeated). The provided positions
+    do not reflect the media's location in the sequence - they are computed
+    as if the media started at position 0 of the sequence.
+
+    The method works as follows: it recomputes mrope_positions starting from
+    `num_computed_tokens` for `total_len_of_multimodal_embeddings` and then
+    shifts all text tokens that go after total_len_of_multimodal_embeddings.
+
+    It also handles the case when multimodal_embeddings is partial
+    (e.g. one media item is split across two prefill stages).
+
+    Args:
+        input_ids: (N,) All input tokens of the prompt (entire sequence).
+        multimodal_positions: List of mrope positions for each media item.
+        mrope_positions: Existing mrope positions (3, N) for entire sequence.
+        num_computed_tokens: The number of computed tokens so far.
+        vision_start_token_id: Token indicating start of vision media.
+        image_token_id: Image token id.
+        video_token_id: Video token id.
+
+    Returns:
+        Tuple of (mrope_positions, mrope_position_delta).
+ """ + + # Tensors + positions: torch.LongTensor = typing.cast( + torch.LongTensor, mrope_positions.clone()) # (3, N) + N = input_ids.numel() + + image_mask = input_ids.eq(image_token_id) + video_mask = input_ids.eq(video_token_id) + media_mask = image_mask | video_mask + text_mask = ~media_mask + + # Early exit: no media in this chunk + if len(multimodal_positions) == 0: + delta = (int((positions.max().item() + 1) - + N) if positions.numel() else -N) + return positions, delta + + total_mm_tokens = torch.count_nonzero(media_mask) + seen_mm_tokens = torch.count_nonzero(media_mask[:num_computed_tokens]) + + # Early exit: we've updated positions for all media tokens + # (and consequently - for all remaining text tokens) + if seen_mm_tokens == total_mm_tokens: + delta = (int((positions.max().item() + 1) - + N) if positions.numel() else -N) + return positions, delta + + vision_start_indices = (input_ids == vision_start_token_id).nonzero( + as_tuple=True)[0] + + for mm_pos in multimodal_positions: + # Each mm_pos can be a complete embedding for single media + # or it can be a part of a single media (due to chunked prefill) + + # Cases to cover + # - Current prefill chunk has no vision start indexes at all + # - Vision start token appeared in previous prefill round + # - Regular case + seen_vision_start_indices = vision_start_indices[vision_start_indices < + num_computed_tokens] + + if len(seen_vision_start_indices): + # If we have encountered some vision start indexes, + # then we should check the condition: + # | --- prefill 1 ------| ---- prefill 2 ----- | + # | TTTTTTTTTSVVVVVVVVVV|VVVVVVTTTTTTTTTTTTTTTT| + last_vision_start_token = seen_vision_start_indices[-1] + seem_mm_tokens_before_last_vision_start = torch.count_nonzero( + media_mask[:last_vision_start_token]) + in_the_middle_of_media = ( + seen_mm_tokens > seem_mm_tokens_before_last_vision_start) + + if in_the_middle_of_media: + mm_embeddings_seen = (seen_mm_tokens - + seem_mm_tokens_before_last_vision_start) + global_mm_start = last_vision_start_token + else: + # We have completed previous mm_embedding part and + # ready to start a new one + next_vision_start_token = vision_start_indices[ + vision_start_indices >= num_computed_tokens][0] + mm_embeddings_seen = 0 + global_mm_start = next_vision_start_token + + else: + # If there were no vision start indexes so far, + # let's find first vision start index + next_vision_start_token = vision_start_indices[ + vision_start_indices >= num_computed_tokens][0] + + mm_embeddings_seen = 0 + global_mm_start = next_vision_start_token + + # Offset right after vision_start_token + base = positions[-1, global_mm_start] + 1 + local_start = global_mm_start + 1 + mm_embeddings_seen + local_end = local_start + mm_pos.shape[1] + positions[:, local_start:local_end] = mm_pos[0:3] + base + + # mm_pos[3, 0] is the max width of the media + offset = mm_pos[3, 0] + base + + text_pos_sum = torch.cumsum(text_mask[local_end:].long(), dim=0) + + positions[:, local_end:N] = text_pos_sum + offset - 1 + + # Include distance to the next vision start token + num_computed_tokens += mm_pos.shape[1] + + mrope_positions_delta = (positions.max() + 1 - N).item() + return positions, mrope_positions_delta diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b7a066654d703..dca6feded12e6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -40,11 +40,15 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from 
vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.model_executor.models.interfaces import (SupportsMultiModal,
                                                    is_mixture_of_experts,
                                                    supports_eagle3,
                                                    supports_mrope,
+                                                   supports_multimodal_pruning,
                                                    supports_transcription)
+# yapf: enable
 from vllm.model_executor.models.interfaces_base import (
     VllmModelForPooling, is_pooling_model, is_text_generation_model)
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -206,7 +210,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.enable_prompt_embeds = model_config.enable_prompt_embeds
         self.is_multimodal_raw_input_only_model = (
             model_config.is_multimodal_raw_input_only_model)
-
+        # This will be overridden in load_model()
+        self.is_multimodal_pruning_enabled = False
         self.max_model_len = model_config.max_model_len
         self.dcp_world_size = self.parallel_config.decode_context_parallel_size
         self.max_num_tokens = scheduler_config.max_num_batched_tokens
@@ -1530,29 +1535,47 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # encoder outputs.
         model = cast(SupportsMultiModal, self.model)
         encoder_outputs = []
-        for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
+        for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
                 mm_kwargs,
                 device=self.device,
                 pin_memory=self.pin_memory,
                 merge_by_field_config=model.merge_by_field_config,
         ):
-            # Run the encoder.
-            # `curr_group_outputs` is either of the following:
-            # 1. A tensor of shape (num_items, feature_size, hidden_size)
-            # in case feature_size is fixed across all multimodal items.
-            # 2. A list or tuple (length: num_items) of tensors, each of shape
-            # (feature_size, hidden_size) in case the feature size is dynamic
-            # depending on the input multimodal items.
-            curr_group_outputs = model.get_multimodal_embeddings(
-                **mm_kwargs_group)
+            # (ekhvedchenia): Temporary hack to limit peak memory usage when
+            # processing multimodal data. This solves the issue of the
+            # scheduler putting too many video samples into a single batch:
+            # it compares the pruned vision token count against the compute
+            # budget, which is incorrect (either the input media size or the
+            # non-pruned output vision token count should be considered).
+            curr_group_outputs = []
+
+            if self.is_multimodal_pruning_enabled and modality == "video":
+                micro_batch_size = 1
+                for i in range(0, num_items, micro_batch_size):
+                    micro_batch_mm_inputs = dict(
+                        (k, v[i:i + micro_batch_size])
+                        for k, v in mm_kwargs_group.items())
+
+                    micro_batch_outputs = model.get_multimodal_embeddings(
+                        **micro_batch_mm_inputs)
+
+                    curr_group_outputs.extend(micro_batch_outputs)
+            else:
+                # Run the encoder.
+                # `curr_group_outputs` is either of the following:
+                # 1. A tensor of shape (num_items, feature_size, hidden_size)
+                # in case feature_size is fixed across all multimodal items.
+                # 2. A list or tuple (length: num_items) of tensors,
+                # each of shape (feature_size, hidden_size) in case the feature
+                # size is dynamic depending on the input multimodal items.
+ curr_group_outputs = model.get_multimodal_embeddings( + **mm_kwargs_group) sanity_check_mm_encoder_outputs( curr_group_outputs, expected_num_items=num_items, ) - - for output in curr_group_outputs: - encoder_outputs.append(output) + encoder_outputs.extend(curr_group_outputs) # Cache the encoder outputs by mm_hash for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs): @@ -1566,8 +1589,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): scheduler_output: "SchedulerOutput", shift_computed_tokens: int = 0, ) -> list[torch.Tensor]: + should_sync_mrope_positions = False mm_embeds: list[torch.Tensor] = [] for req_id in self.input_batch.req_ids: + mm_embeds_req: list[torch.Tensor] = [] + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ req_id] req_state = self.requests[req_id] @@ -1609,7 +1635,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): encoder_output[start_idx:end_idx], is_embed=is_embed, ) - mm_embeds.append(mm_embeds_item) + mm_embeds_req.append(mm_embeds_item) + + if self.is_multimodal_pruning_enabled and self.uses_mrope: + should_sync_mrope_positions = True + mm_embeds_req, new_mrope_positions, new_delta = ( + self.model.recompute_mrope_positions( + input_ids=req_state.prompt_token_ids, + multimodal_embeddings=mm_embeds_req, + mrope_positions=req_state.mrope_positions, + num_computed_tokens=req_state.num_computed_tokens, + )) + assert req_state.mrope_positions is not None + req_state.mrope_positions.copy_(new_mrope_positions) + req_state.mrope_position_delta = new_delta + + mm_embeds.extend(mm_embeds_req) + + if should_sync_mrope_positions: + self._calc_mrope_positions(scheduler_output) + self.mrope_positions.copy_to_gpu( + scheduler_output.total_num_scheduled_tokens) + return mm_embeds def _extract_encoder_inputs( @@ -2589,6 +2636,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): time_after_load - time_before_load) prepare_communication_buffer_for_model(self.model) + self.is_multimodal_pruning_enabled = (supports_multimodal_pruning( + self.model) and self.model_config.multimodal_config. + is_multimodal_pruning_enabled()) + if is_mixture_of_experts( self.model) and self.parallel_config.enable_eplb: logger.info("EPLB is enabled for model %s.", @@ -2843,7 +2894,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): Args: num_tokens: Number of tokens to run the dummy forward pass. cudagraph_runtime_mode: used to control the behavior. - - if not set will determine the cudagraph mode based on using + - if not set will determine the cudagraph mode based on using the self.cudagraph_dispatcher. - CUDAGraphMode.NONE: No cudagraph, for warm up and profile run - CUDAGraphMode.PIECEWISE: Piecewise cudagraph.
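Note on the pruning logic added above: compute_retention_mask in vllm/multimodal/evs.py scores each merged video patch by its cosine dissimilarity to the same patch in the previous frame, forces the first frame to the top of the ranking, and keeps the compute_retained_tokens_count highest-scoring tokens. The sketch below is an illustrative, self-contained re-implementation of that idea on random data; it is not part of the patch, and the helper name evs_retention_mask and the toy shapes are invented for the example.

import torch


def evs_retention_mask(video_embeds: torch.Tensor, thw: tuple[int, int, int],
                       merge_size: int, pruning_rate: float) -> torch.Tensor:
    """Toy EVS retention mask: keep the first frame, then the most novel patches."""
    t, h, w = thw
    gh, gw = h // merge_size, w // merge_size
    x = video_embeds.reshape(t, gh, gw, -1)

    # Dissimilarity of each merged patch w.r.t. the same patch in the previous
    # frame; the first frame gets an artificially high score so it is always
    # fully retained.
    sim = torch.nn.functional.cosine_similarity(x[1:], x[:-1], dim=-1)
    dissim = torch.cat([torch.full((1, gh, gw), 255.0), 1 - sim], dim=0)

    # Never keep fewer tokens than one full frame, mirroring
    # compute_retained_tokens_count.
    keep = max(gh * gw, int(t * gh * gw * (1 - pruning_rate)))
    order = torch.argsort(dissim.view(-1), descending=True, stable=True)

    mask = torch.zeros(t * gh * gw, dtype=torch.bool)
    mask[order[:keep]] = True
    return mask


# 8 frames of an 8x8 patch grid (4x4 after 2x2 spatial merge), 16-dim features,
# pruning 75% of the video tokens.
emb = torch.randn(8 * 4 * 4, 16)
mask = evs_retention_mask(emb, (8, 8, 8), merge_size=2, pruning_rate=0.75)
print(f"retained {int(mask.sum())} of {mask.numel()} video tokens")

Because the retained count depends only on the grid size and the pruning rate, the processor can insert exactly that many placeholder tokens in _get_prompt_updates, and recompute_mrope_positions later repairs the position ids for the pruned sequence. End to end, the feature is enabled through the new video_pruning_rate engine argument / --video-pruning-rate CLI flag (e.g. 0.75), which is what the added tests exercise.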