EVS Support (Video tokens pruning) (#22980)
Signed-off-by: Eugene Khvedchenia <ekhvedchenia@nvidia.com>
Signed-off-by: Eugene Khvedchenya <ekhvedchenya@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
parent 983056e456
commit 392edee34a

tests/models/multimodal/generation/test_qwen2_5_vl.py (new file, 132 lines)
@@ -0,0 +1,132 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from vllm.multimodal.video import sample_frames_from_video

from ....conftest import VIDEO_ASSETS

models = ["Qwen/Qwen2.5-VL-3B-Instruct"]
target_dtype = "bfloat16"

VIDEO_PLACEHOLDER = "<|vision_start|><|video_pad|><|vision_end|>"


def qwen2_5_vl_chat_template(*query):
    return f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{''.join(query)}<|im_end|><|im_start|>assistant\n"  # noqa: E501


VIDEO_PROMPTS = VIDEO_ASSETS.prompts({
    "baby_reading":
    qwen2_5_vl_chat_template(
        VIDEO_PLACEHOLDER,
        "Describe this video with a short sentence ",
        "(no more than 20 words)",
    ),
})


@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("video_pruning_rate", [0.0, 0.75])
@pytest.mark.parametrize("num_frames", [16])
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
def test_qwen2_5_vl_evs_functionality(vllm_runner, video_assets, model,
                                      video_pruning_rate: float,
                                      num_frames: int, dtype: str,
                                      max_tokens: int) -> None:
    """Test EVS (Efficient Video Sampling) functionality with different
    pruning rates.
    """

    # Sample frames from video assets
    sampled_vids = [
        sample_frames_from_video(asset.np_ndarrays, num_frames)
        for asset in video_assets
    ]

    prompts = [VIDEO_PROMPTS[0]]
    videos = [sampled_vids[0]]

    # Initialize model with EVS configuration
    with vllm_runner(model,
                     runner="generate",
                     max_model_len=4000,
                     max_num_seqs=1,
                     dtype=dtype,
                     limit_mm_per_prompt={"video": 1},
                     tensor_parallel_size=1,
                     video_pruning_rate=video_pruning_rate) as vllm_model:

        # Generate output - this should not crash
        outputs = vllm_model.generate_greedy(prompts,
                                             max_tokens,
                                             videos=videos)

        # Basic validation that we got a response
        assert len(outputs) == 1
        output_ids, output_text = outputs[0]

        # Ensure we got some output
        assert len(output_ids) > 0
        assert len(output_text) > 0

        # Ensure the output is a string
        assert isinstance(output_text, str)


@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("video_pruning_rate", [0.0, 0.75])
@pytest.mark.parametrize("num_frames", [16])
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
def test_qwen2_5_vl_evs_batched_videos(vllm_runner, video_assets, model,
                                       video_pruning_rate: float,
                                       num_frames: int, dtype: str,
                                       max_tokens: int) -> None:
    """Test EVS functionality with batched videos.

    This test validates that:
    1. The model handles batched video inputs correctly with EVS
    2. Both pruning configurations work with multiple videos
    3. The model doesn't crash when processing multiple videos simultaneously
    """
    # Sample frames from video assets
    sampled_vids = [
        sample_frames_from_video(asset.np_ndarrays, num_frames)
        for asset in video_assets
    ]

    # Test batched videos
    prompts = [VIDEO_PROMPTS[0], VIDEO_PROMPTS[0]]
    videos = [sampled_vids[0],
              sampled_vids[0]]  # Use same video twice for testing

    # Initialize model with EVS configuration
    with vllm_runner(model,
                     runner="generate",
                     max_model_len=4000,
                     max_num_seqs=2,
                     dtype=dtype,
                     limit_mm_per_prompt={"video": 2},
                     tensor_parallel_size=1,
                     video_pruning_rate=video_pruning_rate) as vllm_model:

        # Generate output - this should not crash
        outputs = vllm_model.generate_greedy(prompts,
                                             max_tokens,
                                             videos=videos)

        # Basic validation that we got responses for both videos
        assert len(outputs) == 2

        for output_ids, output_text in outputs:
            # Ensure we got some output for each video
            assert len(output_ids) > 0
            assert len(output_text) > 0

            # Ensure the output is a string
            assert isinstance(output_text, str)
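For orientation before the config plumbing below: end to end, EVS is switched on by the single video_pruning_rate knob this commit threads through the stack. A minimal offline-inference sketch, assuming video_frames is a decoded video as a numpy array of frames and reusing the model and chat template from the test above (the sketch itself is not part of this diff):

# Hedged sketch: enabling EVS from the offline LLM API.
# `video_frames` is assumed to be a numpy array of shape (T, H, W, C).
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    max_model_len=4000,
    limit_mm_per_prompt={"video": 1},
    video_pruning_rate=0.75,  # prune ~75% of each video's tokens via EVS
)

prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
          "<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>"
          "Describe this video.<|im_end|>\n<|im_start|>assistant\n")

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"video": video_frames}},
    SamplingParams(max_tokens=128),
)
print(outputs[0].outputs[0].text)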
@@ -283,6 +283,7 @@ class ModelConfig:
    mm_encoder_tp_mode: InitVar[Optional[MMEncoderTPMode]] = None
    interleave_mm_strings: InitVar[Optional[bool]] = None
    skip_mm_profiling: InitVar[Optional[bool]] = None
    video_pruning_rate: InitVar[Optional[float]] = None

    def compute_hash(self) -> str:
        """
@@ -311,6 +312,7 @@ class ModelConfig:
        factors.append(self.override_generation_config)
        factors.append(self.rope_scaling)
        factors.append(self.rope_theta)
        factors.append(self.video_pruning_rate)

        # hf_config can control how the model looks!
        try:
@@ -338,17 +340,19 @@ class ModelConfig:
        return hashlib.sha256(str(factors).encode()).hexdigest()

    def __post_init__(
            self,
            # Multimodal config init vars
            limit_mm_per_prompt: Optional[dict[str, int]],
            media_io_kwargs: Optional[dict[str, dict[str, Any]]],
            mm_processor_kwargs: Optional[dict[str, Any]],
            mm_processor_cache_gb: Optional[float],
            mm_processor_cache_type: Optional[MMCacheType],
            mm_shm_cache_max_object_size_mb: Optional[int],
            mm_encoder_tp_mode: Optional[MMEncoderTPMode],
            interleave_mm_strings: Optional[bool],
            skip_mm_profiling: Optional[bool]) -> None:
            self,
            # Multimodal config init vars
            limit_mm_per_prompt: Optional[dict[str, int]],
            media_io_kwargs: Optional[dict[str, dict[str, Any]]],
            mm_processor_kwargs: Optional[dict[str, Any]],
            mm_processor_cache_gb: Optional[float],
            mm_processor_cache_type: Optional[MMCacheType],
            mm_shm_cache_max_object_size_mb: Optional[int],
            mm_encoder_tp_mode: Optional[MMEncoderTPMode],
            interleave_mm_strings: Optional[bool],
            skip_mm_profiling: Optional[bool],
            video_pruning_rate: Optional[float],
    ) -> None:
        # Set the default seed to 0 in V1.
        # NOTE(woosuk): In V0, we set the default seed to None because the
        # driver worker shares the same process as the user process, and thus
@@ -612,6 +616,7 @@ class ModelConfig:
            mm_encoder_tp_mode=mm_encoder_tp_mode,
            interleave_mm_strings=interleave_mm_strings,
            skip_mm_profiling=skip_mm_profiling,
            video_pruning_rate=video_pruning_rate,
        )

        mm_config_kwargs = {

@@ -78,6 +78,11 @@ class MultiModalConfig:
    This reduces engine startup time but shifts the responsibility to users for
    estimating the peak memory usage of the activation of multimodal encoder and
    embedding cache."""
    video_pruning_rate: Optional[float] = None
    """Sets pruning rate for video pruning via Efficient Video Sampling.
    Value sits in range [0;1) and determines the fraction of media tokens
    from each video to be pruned.
    """

    def compute_hash(self) -> str:
        """
@@ -118,3 +123,7 @@ class MultiModalConfig:
        """
        kwargs = self.mm_processor_kwargs or {}
        return kwargs | dict(inference_kwargs)

    def is_multimodal_pruning_enabled(self):
        return (self.video_pruning_rate is not None
                and self.video_pruning_rate > 0)
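The new field is a plain fraction: None (the default) or 0.0 leaves EVS off, while any value in (0, 1) prunes that share of each video's tokens. A tiny standalone sketch of the predicate's semantics, mirroring is_multimodal_pruning_enabled() above without importing vLLM:

# Hedged sketch of the is_multimodal_pruning_enabled() semantics.
def is_multimodal_pruning_enabled(video_pruning_rate):
    return video_pruning_rate is not None and video_pruning_rate > 0

assert not is_multimodal_pruning_enabled(None)  # default: EVS disabled
assert not is_multimodal_pruning_enabled(0.0)   # explicit 0.0 also disabled
assert is_multimodal_pruning_enabled(0.75)      # prune ~75% of video tokens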
@@ -391,6 +391,7 @@ class EngineArgs:
    mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
    io_processor_plugin: Optional[str] = None
    skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
    video_pruning_rate: float = MultiModalConfig.video_pruning_rate
    # LoRA fields
    enable_lora: bool = False
    enable_lora_bias: bool = LoRAConfig.bias_enabled
@@ -813,6 +814,9 @@ class EngineArgs:
        multimodal_group.add_argument("--skip-mm-profiling",
                                      **multimodal_kwargs["skip_mm_profiling"])

        multimodal_group.add_argument(
            "--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"])

        # LoRA related configs
        lora_kwargs = get_kwargs(LoRAConfig)
        lora_group = parser.add_argument_group(
@@ -1032,6 +1036,7 @@ class EngineArgs:
            model_impl=self.model_impl,
            override_attention_dtype=self.override_attention_dtype,
            logits_processors=self.logits_processors,
            video_pruning_rate=self.video_pruning_rate,
            io_processor_plugin=self.io_processor_plugin,
        )

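With the argparse wiring above, the same knob is reachable as --video-pruning-rate on the CLI or as a keyword argument to EngineArgs. A hedged sketch of the programmatic path; the create_model_config() call is assumed to be the usual EngineArgs-to-ModelConfig plumbing rather than anything added by this commit:

# Hedged sketch: roughly equivalent to
# `vllm serve Qwen/Qwen2.5-VL-3B-Instruct --video-pruning-rate 0.75`.
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    limit_mm_per_prompt={"video": 1},
    video_pruning_rate=0.75,
)
model_config = engine_args.create_model_config()
assert model_config.multimodal_config.video_pruning_rate == 0.75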
@@ -115,6 +115,42 @@ class SupportsMultiModal(Protocol):
        ...


@runtime_checkable
class SupportsMultiModalPruning(Protocol):
    """The interface required for models that support returning both input
    embeddings and positions. Model may require custom positions for dynamic
    pruning of multimodal embeddings.
    """
    supports_multimodal_pruning: ClassVar[Literal[True]] = True

    def recompute_mrope_positions(
        self, input_ids: list[int],
        multimodal_embeddings: MultiModalEmbeddings,
        mrope_positions: torch.LongTensor, num_computed_tokens: int
    ) -> tuple[MultiModalEmbeddings, Tensor, int]:
        """
        Update part of input mrope positions (starting with the
        num_computed_tokens index). Original mrope_positions are computed
        for the unpruned sequence and become incorrect once pruning occurs,
        so once we prune media tokens we should reflect this in the
        mrope_positions before we feed them to the LLM.

        Args:
            input_ids: (N,) All input tokens of the prompt containing
                entire sequence.
            multimodal_embeddings: Tuple of multimodal embeddings that
                fits into the prefill chunk that is being processed.
            mrope_positions: Existing mrope positions (3, N) for entire
                sequence
            num_computed_tokens: A number of computed tokens so far.

        Returns:
            Tuple of (multimodal_embeddings, mrope_positions,
            mrope_position_delta).
        """
        ...


@overload
def supports_multimodal(
        model: type[object]) -> TypeIs[type[SupportsMultiModal]]:
@@ -142,6 +178,25 @@ def supports_multimodal_encoder_tp_data(
    return getattr(model, "supports_encoder_tp_data", False)


@overload
def supports_multimodal_pruning(
        model: type[object]) -> TypeIs[type[SupportsMultiModalPruning]]:
    ...


@overload
def supports_multimodal_pruning(
        model: object) -> TypeIs[SupportsMultiModalPruning]:
    ...


def supports_multimodal_pruning(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsMultiModalPruning]],
           TypeIs[SupportsMultiModalPruning]]:
    return getattr(model, "supports_multimodal_pruning", False)


@runtime_checkable
class SupportsScoreTemplate(Protocol):
    """The interface required for all models that support score template."""
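Since the new helper is just an attribute probe, call sites (such as the GPU model runner change further below) can gate EVS-only paths without touching model internals. A hedged sketch with illustrative stand-in classes; DummyPrunableModel and DummyPlainModel are not part of vLLM:

# Hedged sketch: gating EVS-only logic on the new interface check.
from vllm.model_executor.models.interfaces import supports_multimodal_pruning

class DummyPrunableModel:  # illustrative stand-in, not a real vLLM model
    supports_multimodal_pruning = True

class DummyPlainModel:  # lacks the marker attribute
    pass

assert supports_multimodal_pruning(DummyPrunableModel)    # works on classes
assert supports_multimodal_pruning(DummyPrunableModel())  # and on instances
assert not supports_multimodal_pruning(DummyPlainModel())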
@@ -25,9 +25,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping
from collections.abc import Iterable, Mapping, Sequence
from functools import lru_cache, partial
from typing import Annotated, Callable, Literal, Optional, Union
from typing import Annotated, Any, Callable, Literal, Optional, Union

import torch
import torch.nn as nn
@@ -58,7 +58,13 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm.multimodal.evs import (compute_mrope_for_media,
                                 compute_retained_tokens_count,
                                 compute_retention_mask,
                                 recompute_mrope_positions)
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import PromptReplacement, PromptUpdate
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
@@ -66,7 +72,8 @@ from vllm.utils import is_pin_memory_available
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                         SupportsMultiModal, SupportsPP, SupportsQuant)
                         SupportsMultiModal, SupportsMultiModalPruning,
                         SupportsPP, SupportsQuant)
from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder
from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo,
                       apply_rotary_pos_emb_vision)
@@ -86,9 +93,9 @@ class Qwen2_5_VLImagePixelInputs(TensorSchema):
        - np: Number of patches
        - ni: Number of images
        - cps: Number of channels * patch_size * patch_size

    Historical context:
        - pixel_values shape: (num_patches, num_channels * patch_size *
        - pixel_values shape: (num_patches, num_channels * patch_size *
          patch_size)
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
@@ -112,7 +119,7 @@ class Qwen2_5_VLImageEmbeddingInputs(TensorSchema):
        - nf: Number of image features
        - hs: Hidden size
        - ni: Number of images

    Historical context:
        - image_embeds shape: (num_image_features, hidden_size)
        - num_image_features varies based on the number and resolution of the
@@ -143,11 +150,11 @@ class Qwen2_5_VLVideoPixelInputs(TensorSchema):
    Dimensions:
        - np: Number of patches
        - nv: Number of videos
        - ctps: Number of channels * temporal_patch_size * patch_size *
        - ctps: Number of channels * temporal_patch_size * patch_size *
          patch_size

    Historical context:
        - pixel_values_videos shape: (num_patches, num_channels *
        - pixel_values_videos shape: (num_patches, num_channels *
          temporal_patch_size * patch_size * patch_size)
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
@@ -179,7 +186,7 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
        - nf: Number of video features
        - hs: Hidden size
        - nv: Number of videos

    Historical context:
        - video_embeds shape: (num_video_features, hidden_size)
        - num_video_features varies based on the number and resolution of the
@@ -905,6 +912,55 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor):
            second_per_grid_ts=MultiModalFieldConfig.batched("video"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(
            **hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        placeholder = {
            "image": vocab[hf_processor.image_token],
            "video": vocab[hf_processor.video_token],
        }

        merge_length = image_processor.merge_size**2

        def get_replacement_qwen2vl(item_idx: int, modality: str):
            out_item = out_mm_kwargs[modality][item_idx]
            grid_thw = out_item[f"{modality}_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            num_tokens = int(grid_thw.prod()) // merge_length

            # EVS-specific code
            video_pruning_rate = self.info.ctx.get_mm_config(
            ).video_pruning_rate
            if (modality == "video" and video_pruning_rate is not None
                    and video_pruning_rate > 0.0):
                num_tokens = compute_retained_tokens_count(
                    grid_thw,
                    image_processor.merge_size,
                    video_pruning_rate,
                )
            # End of EVS-specific code

            return [placeholder[modality]] * num_tokens

        return [
            PromptReplacement(
                modality=modality,
                target=[placeholder[modality]],
                replacement=partial(get_replacement_qwen2vl,
                                    modality=modality),
            ) for modality in ("image", "video")
        ]


@MULTIMODAL_REGISTRY.register_processor(
    Qwen2_5_VLMultiModalProcessor,
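Stepping out of the diff for a moment, the placeholder arithmetic above is easy to check by hand. A hedged worked example with an illustrative grid (16 frames of 28x28 vision patches, merge size 2, pruning rate 0.75):

# Hedged worked example for the placeholder count above (values illustrative).
import torch

from vllm.multimodal.evs import compute_retained_tokens_count

grid_thw = torch.tensor([16, 28, 28])  # (T, H, W) in vision patches
merge_size = 2

unpruned = int(grid_thw.prod()) // merge_size**2  # 16 * 14 * 14 = 3136 tokens
retained = compute_retained_tokens_count(grid_thw, merge_size, q=0.75)

assert unpruned == 3136
assert retained == 784                    # int(3136 * (1 - 0.75))
assert retained >= (28 // 2) * (28 // 2)  # the first frame (196 tokens) is always kept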
@@ -912,7 +968,8 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor):
    dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                         SupportsLoRA, SupportsPP,
                                         SupportsQuant):
                                         SupportsQuant,
                                         SupportsMultiModalPruning):

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -949,6 +1006,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
        self.config = config
        self.multimodal_config = multimodal_config
        self.video_pruning_rate = multimodal_config.video_pruning_rate
        self.is_multimodal_pruning_enabled = (
            multimodal_config.is_multimodal_pruning_enabled())

        if multimodal_config.get_limit_per_prompt("image") or \
                multimodal_config.get_limit_per_prompt("video"):
@@ -1090,6 +1150,36 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,

        return image_embeds.split(sizes)

    def _postprocess_image_embeds_evs(
            self, image_embeds_split: tuple[torch.Tensor, ...],
            image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]:
        """
        Append mrope positions for each image.
        This is necessary to recover correct mrope
        positions after video pruning.

        Args:
            image_embeds_split: Tuple of image embeddings for
                each image item.
            image_input: Image input data.

        Returns:
            Tuple of image embeddings for each image item.
            Resulting embeddings will have extra 4 channels for
            computed mrope positions.
        """
        merge_size = self.visual.spatial_merge_size
        grid_thw = image_input["image_grid_thw"]
        grid_thw_list = grid_thw.tolist()
        image_embeds_out = []
        for emb, size in zip(image_embeds_split, grid_thw_list):
            positions = compute_mrope_for_media(size,
                                                merge_size).to(emb.device)
            emb = torch.cat([emb, positions], dim=1)
            image_embeds_out.append(emb)
        image_embeds_split = image_embeds_out
        return tuple(image_embeds_split)

    def _process_video_input(
            self,
            video_input: Qwen2_5_VLVideoInputs) -> tuple[torch.Tensor, ...]:
@@ -1119,6 +1209,114 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,

        return video_embeds.split(sizes)

    def _postprocess_video_embeds_evs(
            self, video_embeds_split: tuple[torch.Tensor, ...],
            video_input: Qwen2_5_VLVideoInputs) -> tuple[torch.Tensor, ...]:
        """
        Prunes video embeddings via Efficient Video Sampling (EVS)
        and then appends mrope positions for each retained embedding.

        Args:
            video_embeds_split: Tuple of video embeddings for each video item.
            video_input: Video input data.

        Returns:
            Tuple of video embeddings for each video item.
            Resulting embeddings will have extra 4 channels for
            computed mrope positions.
        """
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()
        merge_size = self.visual.spatial_merge_size

        # Cast to long to match the original code
        # https://github.com/huggingface/transformers/blob/41980ce93e775f6c88500c51c8db7946fc6a2add/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py#L491  # noqa
        second_per_grid_ts = video_input["second_per_grid_ts"].long()
        tokens_per_second = self.config.vision_config.tokens_per_second

        video_embeds_out = []
        for emb, size, video_second_per_grid_t in zip(video_embeds_split,
                                                      grid_thw_list,
                                                      second_per_grid_ts):
            # For each video, we compute retention mask using EVS
            retention_mask = compute_retention_mask(
                emb,
                size,
                spatial_merge_size=self.visual.spatial_merge_size,
                q=self.video_pruning_rate,
            )
            positions = compute_mrope_for_media(
                size,
                merge_size,
                tokens_per_second=tokens_per_second,
                video_second_per_grid=video_second_per_grid_t.item(),
            ).to(emb.device)

            emb = emb[retention_mask]
            positions = positions[retention_mask]
            emb = torch.cat([emb, positions], dim=1)
            video_embeds_out.append(emb)
        return tuple(video_embeds_out)

    def recompute_mrope_positions(
        self,
        input_ids: list[int],
        multimodal_embeddings: tuple[torch.Tensor, ...],
        mrope_positions: torch.LongTensor,
        num_computed_tokens: int,
    ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor, int]:
        """
        Update part of input mrope positions (starting with the
        num_computed_tokens index). Original mrope_positions are computed
        for the unpruned sequence and become incorrect once pruning occurs,
        so once we prune media tokens we should reflect this in the
        mrope_positions before we feed them to the LLM.

        Args:
            input_ids: (N,) All input tokens of the prompt (containing
                entire sequence).
            multimodal_embeddings: Tuple of multimodal embeddings.
            mrope_positions: Existing mrope positions (3, N) for entire
                sequence
            num_computed_tokens: A number of computed tokens so far.

        Returns:
            Tuple of (multimodal_embeddings, mrope_positions,
            mrope_position_delta).
        """
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id

        # Device
        device = (multimodal_embeddings[0].device
                  if len(multimodal_embeddings) else mrope_positions.device)

        # Tensors
        input_ids_t = torch.as_tensor(input_ids,
                                      device=device,
                                      dtype=torch.long)

        # fmt: off
        mm_embeddings_out = [mm[:, :-4] for mm in
                             multimodal_embeddings]
        mm_embeddings_pos = [mm[:, -4:].permute(1, 0).long() for mm in
                             multimodal_embeddings]
        # fmt: on

        positions, mrope_positions_delta = recompute_mrope_positions(
            input_ids_t,
            mm_embeddings_pos,
            mrope_positions,
            num_computed_tokens,
            vision_start_token_id,
            image_token_id,
            video_token_id,
        )

        return tuple(mm_embeddings_out), positions, mrope_positions_delta

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        mm_input_by_modality = {}

@@ -1156,9 +1354,17 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
            multimodal_input = mm_input_by_modality[modality]
            if modality == "image":
                vision_embeddings = self._process_image_input(multimodal_input)
                if self.is_multimodal_pruning_enabled:
                    vision_embeddings = self._postprocess_image_embeds_evs(
                        vision_embeddings, multimodal_input
                    )
                multimodal_embeddings += vision_embeddings
            if modality == "video":
                video_embeddings = self._process_video_input(multimodal_input)
                if self.is_multimodal_pruning_enabled:
                    video_embeddings = self._postprocess_video_embeds_evs(
                        video_embeddings, multimodal_input
                    )
                multimodal_embeddings += video_embeddings
        return multimodal_embeddings

@@ -1184,6 +1390,10 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
            inputs_embeds = self.get_input_embeddings(input_ids)
            if image_input is not None:
                image_embeds = self._process_image_input(image_input)
                if self.is_multimodal_pruning_enabled:
                    image_embeds = self._postprocess_image_embeds_evs(
                        image_embeds, image_input
                    )
                inputs_embeds = merge_multimodal_embeddings(
                    input_ids,
                    inputs_embeds,
@@ -1193,6 +1403,10 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,

            if video_input is not None:
                video_embeds = self._process_video_input(video_input)
                if self.is_multimodal_pruning_enabled:
                    video_embeds = self._postprocess_video_embeds_evs(
                        video_embeds, video_input
                    )
                inputs_embeds = merge_multimodal_embeddings(
                    input_ids,
                    inputs_embeds,
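One convention worth calling out from the hunks above: the EVS post-processing helpers return each media embedding with 4 extra trailing channels holding its precomputed mrope positions, and recompute_mrope_positions strips them off again before the embeddings reach the language model. A hedged sketch of that pack/unpack round trip (shapes are illustrative, not taken from this diff):

# Hedged sketch of the "+4 trailing channels" convention used above.
import torch

hidden_size, num_tokens = 1280, 784
emb = torch.randn(num_tokens, hidden_size)               # pruned video embeddings
positions = torch.zeros(num_tokens, 4, dtype=emb.dtype)  # (t, h, w, llm_grid_w)

packed = torch.cat([emb, positions], dim=1)  # what _postprocess_*_embeds_evs returns
assert packed.shape == (num_tokens, hidden_size + 4)

unpacked_emb = packed[:, :-4]                       # what the LLM actually consumes
unpacked_pos = packed[:, -4:].permute(1, 0).long()  # (4, num_tokens) for mrope
assert torch.equal(unpacked_emb, emb)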
vllm/multimodal/evs.py (new file, 273 lines)
@@ -0,0 +1,273 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import typing

import torch


def compute_retained_tokens_count(video_size_thw: torch.LongTensor,
                                  spatial_merge_size: int, q: float) -> int:
    """
    Compute the number of retained tokens for a given video.
    Method ensures that we retain all the tokens from the first frame
    regardless of the pruning rate.

    Args:
        video_size_thw: The size of the video in the format of (T, H, W).
        spatial_merge_size: The size of the spatial merge.
        q: The pruning rate.

    Returns:
        The number of retained tokens.
    """
    T, H, W = map(int, video_size_thw)
    min_num_tokens = (H // spatial_merge_size) * (W // spatial_merge_size)
    evs_num_tokens = int(T * min_num_tokens * (1 - q))
    return max(min_num_tokens, evs_num_tokens)


def compute_retention_mask(
    video_embeds: torch.Tensor,
    video_size_thw: torch.LongTensor,
    spatial_merge_size: int,
    q: float,
) -> torch.Tensor:
    """
    Computes the retention mask for input video embeddings.

    Args:
        video_embeds (`torch.Tensor`): The input video embeddings
            of shape `(T * H * W // spatial_merge_size ^ 2, hidden_size)`
        video_size_thw (`torch.LongTensor` of shape `(3)`):
            The temporal, height and width of video.
        spatial_merge_size: Size reduction for rows & cols dimensions.
        q: (`float`): Pruning rate factor [0,1)

    Returns:
        `torch.Tensor`: The retention mask for the video embeddings of
        `(T * H * W // spatial_merge_size ^ 2)` shape.
    """
    T, H, W = video_size_thw

    # Use reshape instead of einops to avoid graph breaks
    video_embeds = video_embeds.reshape(
        T,
        H // spatial_merge_size,
        W // spatial_merge_size,
        video_embeds.size(-1),
    )

    # Core EVS
    similarity = torch.nn.functional.cosine_similarity(video_embeds[1:, ...],
                                                       video_embeds[:-1, ...],
                                                       dim=-1)
    dissimilarity = 1 - similarity

    # Always ensure we include all tokens from the first frame
    dissimilarity = torch.cat(
        [255 * torch.ones_like(video_embeds[:1, :, :, 0]), dissimilarity],
        dim=0)

    dissimilarity_flat = dissimilarity.view(-1)
    order = torch.argsort(dissimilarity_flat,
                          dim=-1,
                          descending=True,
                          stable=True)
    retain_num_tokens = compute_retained_tokens_count(video_size_thw,
                                                      spatial_merge_size, q)
    topk_indices = order[:retain_num_tokens]

    retention_mask = torch.zeros_like(dissimilarity_flat, dtype=torch.bool)
    retention_mask[topk_indices] = True
    retention_mask = retention_mask.reshape(dissimilarity.size())

    mask = retention_mask.view(-1)  # "T H W -> (T H W)"
    return mask


def compute_mrope_for_media(
    video_size_thw: torch.LongTensor,
    spatial_merge_size: int,
    tokens_per_second: float = 1.0,
    video_second_per_grid: float = 1.0,
) -> torch.Tensor:
    """
    Computes the mrope for video embeddings based on the grid dimensions.
    Computed mrope positions match original qwen 2.5 implementation,
    but positions are built for media being the first element in sequence.

    Args:
        video_size_thw: Media size (num frames, rows, cols)
        spatial_merge_size: Size reduction for rows & cols dimensions.
        tokens_per_second: Number of tokens per second.
        video_second_per_grid: Number of seconds per video.

    Returns:
        Tensor of shape `(T * H * W, 4)` where last dimension
        represents mrope positions [0:3), while the last channel
        contains value of llm_grid_w repeated for all positions.
    """
    llm_grid_t = video_size_thw[0]
    llm_grid_h = video_size_thw[1] // spatial_merge_size
    llm_grid_w = video_size_thw[2] // spatial_merge_size

    t_index = ((torch.arange(llm_grid_t).view(-1, 1).expand(
        -1, llm_grid_h * llm_grid_w).mul(
            tokens_per_second * video_second_per_grid)).long().flatten())
    h_index = (torch.arange(llm_grid_h).view(1, -1,
                                             1).expand(llm_grid_t, -1,
                                                       llm_grid_w).flatten())
    w_index = (torch.arange(llm_grid_w).view(1, 1, -1).expand(
        llm_grid_t, llm_grid_h, -1).flatten())
    llm_grid_w = (torch.tensor([llm_grid_w
                                ]).view(1, 1,
                                        1).expand(llm_grid_t, llm_grid_h,
                                                  llm_grid_w).flatten())

    positions = torch.stack([t_index, h_index, w_index, llm_grid_w], dim=1)
    return positions


def recompute_mrope_positions(
    input_ids: torch.LongTensor,
    multimodal_positions: list[torch.Tensor],
    mrope_positions: torch.LongTensor,
    num_computed_tokens: int,
    vision_start_token_id: int,
    image_token_id: int,
    video_token_id: int,
) -> tuple[torch.LongTensor, int]:
    """
    Update part of input mrope positions.
    Original mrope_positions are computed incorrectly, so once we prune media
    tokens we should reflect this in the mrope positions for the LLM.

    This method supports chunked prefill approach where
    multimodal_embeddings are passed to LLM in chunks, so input
    multimodal_embeddings may contain zero, some or even some part of all
    multimodal_embeddings for a given prompt.

    Each multimodal_positions has 4 extra channels
    (First 3 channels corresponds to original 3 mrope positions, last channel
    is the maximum width of the media repeated). Provided multimodal_positions
    do not reflect location of media position in sequence - they are computed
    like the media is in the 0-th position in the sequence.

    Method works as follows: it recomputes mrope_positions starting from the
    `num_computed_tokens` for `total_len_of_multimodal_embeddings` and then
    shifts all text tokens that goes after total_len_of_multimodal_embeddings.

    It also handles case when multimodal_embeddings is partial
    (e.g. one media is split into two prefill stages)

    Args:
        input_ids: (N,) All input tokens of the prompt (entire sequence).
        multimodal_positions: List of mrope positions for each media.
        mrope_positions: Existing mrope positions (4, N) for entire sequence.
        num_computed_tokens: A number of computed tokens so far.
        vision_start_token_id: Token indicating start of vision media.
        image_token_id: Image token id
        video_token_id: Video token id

    Returns:
        Tuple of (mrope_positions, mrope_position_delta).
    """

    # Tensors
    positions: torch.LongTensor = typing.cast(
        torch.LongTensor, mrope_positions.clone())  # (3, N)
    N = input_ids.numel()

    image_mask = input_ids.eq(image_token_id)
    video_mask = input_ids.eq(video_token_id)
    media_mask = image_mask | video_mask
    text_mask = ~media_mask

    # Early exit: no media in this chunk
    if len(multimodal_positions) == 0:
        delta = (int((positions.max().item() + 1) -
                     N) if positions.numel() else -N)
        return positions, delta

    total_mm_tokens = torch.count_nonzero(media_mask)
    seen_mm_tokens = torch.count_nonzero(media_mask[:num_computed_tokens])

    # Early exit: we've updated positions for all media tokens
    # (and consequently - for all remaining text tokens)
    if seen_mm_tokens == total_mm_tokens:
        delta = (int((positions.max().item() + 1) -
                     N) if positions.numel() else -N)
        return positions, delta

    vision_start_indices = (input_ids == vision_start_token_id).nonzero(
        as_tuple=True)[0]

    for mm_pos in multimodal_positions:
        # Each mm_pos can be a complete embedding for single media
        # or it can be a part of a single media (due to chunked prefill)

        # Cases to cover
        # - Current prefill chunk has no vision start indexes at all
        # - Vision start token appeared in previous prefill round
        # - Regular case
        seen_vision_start_indices = vision_start_indices[vision_start_indices <
                                                         num_computed_tokens]

        if len(seen_vision_start_indices):
            # If we have encountered some vision start indexes,
            # then we should check the condition:
            # | --- prefill 1 ------| ---- prefill 2 ----- |
            # | TTTTTTTTTSVVVVVVVVVV|VVVVVVTTTTTTTTTTTTTTTT|
            last_vision_start_token = seen_vision_start_indices[-1]
            seem_mm_tokens_before_last_vision_start = torch.count_nonzero(
                media_mask[:last_vision_start_token])
            in_the_middle_of_media = (
                seen_mm_tokens > seem_mm_tokens_before_last_vision_start)

            if in_the_middle_of_media:
                mm_embeddings_seen = (seen_mm_tokens -
                                      seem_mm_tokens_before_last_vision_start)
                global_mm_start = last_vision_start_token
            else:
                # We have completed previous mm_embedding part and
                # ready to start a new one
                next_vision_start_token = vision_start_indices[
                    vision_start_indices >= num_computed_tokens][0]
                mm_embeddings_seen = 0
                global_mm_start = next_vision_start_token

        else:
            # If there were no vision start indexes so far,
            # let's find first vision start index
            next_vision_start_token = vision_start_indices[
                vision_start_indices >= num_computed_tokens][0]

            mm_embeddings_seen = 0
            global_mm_start = next_vision_start_token

        # Offset right after vision_start_token
        base = positions[-1, global_mm_start] + 1
        local_start = global_mm_start + 1 + mm_embeddings_seen
        local_end = local_start + mm_pos.shape[1]
        positions[:, local_start:local_end] = mm_pos[0:3] + base

        # mm_pos[3, 0] is the max width of the media
        offset = mm_pos[3, 0] + base

        text_pos_sum = torch.cumsum(text_mask[local_end:].long(), dim=0)

        positions[:, local_end:N] = text_pos_sum + offset - 1

        # Include distance to the next vision start token
        num_computed_tokens += mm_pos.shape[1]

    mrope_positions_delta = (positions.max() + 1 - N).item()
    return positions, mrope_positions_delta
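The new module is self-contained enough to exercise on random data. A hedged usage sketch for compute_retention_mask (CPU, illustrative sizes; the hidden size 1280 is just an example and not dictated by this diff):

# Hedged usage sketch: EVS retention mask on random embeddings.
import torch

from vllm.multimodal.evs import compute_retention_mask

size_thw = [4, 16, 16]                  # 4 frames of 16x16 vision patches
merge_size = 2
num_tokens = 4 * (16 // 2) * (16 // 2)  # 256 merged video tokens
video_embeds = torch.randn(num_tokens, 1280)

mask = compute_retention_mask(video_embeds, size_thw,
                              spatial_merge_size=merge_size, q=0.5)

assert mask.shape == (num_tokens,)
assert int(mask.sum()) == 128           # max(64, int(256 * 0.5)) tokens kept
assert bool(mask[:64].all())            # the whole first frame is retained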
@@ -40,11 +40,15 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.models.interfaces import (SupportsMultiModal,
                                                   is_mixture_of_experts,
                                                   supports_eagle3,
                                                   supports_mrope,
                                                   supports_multimodal_pruning,
                                                   supports_transcription)
# yapf: enable
from vllm.model_executor.models.interfaces_base import (
    VllmModelForPooling, is_pooling_model, is_text_generation_model)
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -206,7 +210,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        self.enable_prompt_embeds = model_config.enable_prompt_embeds
        self.is_multimodal_raw_input_only_model = (
            model_config.is_multimodal_raw_input_only_model)

        # This will be overridden in load_model()
        self.is_multimodal_pruning_enabled = False
        self.max_model_len = model_config.max_model_len
        self.dcp_world_size = self.parallel_config.decode_context_parallel_size
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
@@ -1530,29 +1535,47 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        # encoder outputs.
        model = cast(SupportsMultiModal, self.model)
        encoder_outputs = []
        for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
        for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
                mm_kwargs,
                device=self.device,
                pin_memory=self.pin_memory,
                merge_by_field_config=model.merge_by_field_config,
        ):
            # Run the encoder.
            # `curr_group_outputs` is either of the following:
            # 1. A tensor of shape (num_items, feature_size, hidden_size)
            # in case feature_size is fixed across all multimodal items.
            # 2. A list or tuple (length: num_items) of tensors, each of shape
            # (feature_size, hidden_size) in case the feature size is dynamic
            # depending on the input multimodal items.
            curr_group_outputs = model.get_multimodal_embeddings(
                **mm_kwargs_group)
            # (ekhvedchenia): Temporary hack to limit peak memory usage when
            # processing multimodal data. This solves the issue with scheduler
            # putting too many video samples into a single batch. Scheduler
            # uses pruned vision tokens count to compare it versus compute
            # budget which is incorrect (Either input media size or non-pruned
            # output vision tokens count should be considered)
            curr_group_outputs = []

            if self.is_multimodal_pruning_enabled and modality == "video":
                micro_batch_size = 1
                for i in range(0, num_items, micro_batch_size):
                    micro_batch_mm_inputs = dict(
                        (k, v[i:i + micro_batch_size])
                        for k, v in mm_kwargs_group.items())

                    micro_batch_outputs = model.get_multimodal_embeddings(
                        **micro_batch_mm_inputs)

                    curr_group_outputs.extend(micro_batch_outputs)
            else:
                # Run the encoder.
                # `curr_group_outputs` is either of the following:
                # 1. A tensor of shape (num_items, feature_size, hidden_size)
                # in case feature_size is fixed across all multimodal items.
                # 2. A list or tuple (length: num_items) of tensors,
                # each of shape (feature_size, hidden_size) in case the feature
                # size is dynamic depending on the input multimodal items.
                curr_group_outputs = model.get_multimodal_embeddings(
                    **mm_kwargs_group)

            sanity_check_mm_encoder_outputs(
                curr_group_outputs,
                expected_num_items=num_items,
            )

            for output in curr_group_outputs:
                encoder_outputs.append(output)
            encoder_outputs.extend(curr_group_outputs)

        # Cache the encoder outputs by mm_hash
        for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
@@ -1566,8 +1589,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        scheduler_output: "SchedulerOutput",
        shift_computed_tokens: int = 0,
    ) -> list[torch.Tensor]:
        should_sync_mrope_positions = False
        mm_embeds: list[torch.Tensor] = []
        for req_id in self.input_batch.req_ids:
            mm_embeds_req: list[torch.Tensor] = []

            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                req_id]
            req_state = self.requests[req_id]
@@ -1609,7 +1635,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                    encoder_output[start_idx:end_idx],
                    is_embed=is_embed,
                )
                mm_embeds.append(mm_embeds_item)
                mm_embeds_req.append(mm_embeds_item)

            if self.is_multimodal_pruning_enabled and self.uses_mrope:
                should_sync_mrope_positions = True
                mm_embeds_req, new_mrope_positions, new_delta = (
                    self.model.recompute_mrope_positions(
                        input_ids=req_state.prompt_token_ids,
                        multimodal_embeddings=mm_embeds_req,
                        mrope_positions=req_state.mrope_positions,
                        num_computed_tokens=req_state.num_computed_tokens,
                    ))
                assert req_state.mrope_positions is not None
                req_state.mrope_positions.copy_(new_mrope_positions)
                req_state.mrope_position_delta = new_delta

            mm_embeds.extend(mm_embeds_req)

        if should_sync_mrope_positions:
            self._calc_mrope_positions(scheduler_output)
            self.mrope_positions.copy_to_gpu(
                scheduler_output.total_num_scheduled_tokens)

        return mm_embeds

    def _extract_encoder_inputs(
@@ -2589,6 +2636,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                            time_after_load - time_before_load)
            prepare_communication_buffer_for_model(self.model)

            self.is_multimodal_pruning_enabled = (supports_multimodal_pruning(
                self.model) and self.model_config.multimodal_config.
                is_multimodal_pruning_enabled())

            if is_mixture_of_experts(
                    self.model) and self.parallel_config.enable_eplb:
                logger.info("EPLB is enabled for model %s.",
@@ -2843,7 +2894,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        Args:
            num_tokens: Number of tokens to run the dummy forward pass.
            cudagraph_runtime_mode: used to control the behavior.
            - if not set will determine the cudagraph mode based on using
            - if not set will determine the cudagraph mode based on using
              the self.cudagraph_dispatcher.
            - CUDAGraphMode.NONE: No cudagraph, for warm up and profile run
            - CUDAGraphMode.PIECEWISE: Piecewise cudagraph.