[Feature] Add EVS (Efficient Video Sampling) Support for Qwen3-VL (#29752)

Signed-off-by: zitian.zhao <zitian.zhao@tencentmusic.com>
Co-authored-by: deitxfge <huhaibo1990@126.com>
ZiTian Zhao 2025-12-14 21:24:56 +08:00 committed by GitHub
parent 5ccf0efa84
commit ae88aada38

@@ -67,12 +67,19 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.evs import (
compute_mrope_for_media,
compute_retained_tokens_count,
compute_retention_mask,
recompute_mrope_positions,
)
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFeatureSpec,
MultiModalFieldConfig,
MultiModalKwargsItem,
MultiModalKwargsItems,
PlaceholderRange,
VideoItem,
)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems, MultiModalDataParser
@@ -92,6 +99,7 @@ from .interfaces import (
SupportsLoRA,
SupportsMRoPE,
SupportsMultiModal,
SupportsMultiModalPruning,
SupportsPP,
_require_is_multimodal,
)
@@ -1043,13 +1051,39 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo])
tokenizer.encode(f"<{curr_time:.1f} seconds>", add_special_tokens=False)
for curr_time in timestamps
]
tokens_per_frame = int(grid_thw[1:].prod()) // merge_length
per_frame_token_counts = [tokens_per_frame for _ in frames_idx_token]
video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate
if video_pruning_rate is not None and video_pruning_rate > 0.0:
total_retained = compute_retained_tokens_count(
tokens_per_frame,
len(frames_idx_token),
video_pruning_rate,
)
if len(frames_idx_token) == 0:
per_frame_token_counts = []
elif len(frames_idx_token) == 1:
per_frame_token_counts = [tokens_per_frame]
else:
first_frame_tokens = tokens_per_frame
remaining_tokens = max(total_retained - first_frame_tokens, 0)
base = remaining_tokens // (len(frames_idx_token) - 1)
remainder = remaining_tokens % (len(frames_idx_token) - 1)
per_frame_token_counts = [first_frame_tokens]
for frame_idx in range(1, len(frames_idx_token)):
extra = base + (1 if (frame_idx - 1) < remainder else 0)
per_frame_token_counts.append(extra)
placeholder = []
for frame_idx, timestamp_tokens in enumerate(frames_idx_token):
placeholder.extend(timestamp_tokens)
tokens_this_frame = per_frame_token_counts[
frame_idx if frame_idx < len(per_frame_token_counts) else -1
]
placeholder.extend(
[vision_start_token_id]
+ [video_token_id] * tokens_this_frame
+ [vision_end_token_id]
)
return PromptUpdateDetails.select_token_id(placeholder, video_token_id)
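For intuition, a standalone sketch of the per-frame budget split performed above; the numbers are made up and the real budget comes from compute_retained_tokens_count:

# Hypothetical helper mirroring the split above; not part of the patch.
def split_budget(tokens_per_frame: int, num_frames: int, total_retained: int) -> list[int]:
    if num_frames == 0:
        return []
    if num_frames == 1:
        return [tokens_per_frame]
    # The first frame always keeps its full token count; later frames share the rest.
    remaining = max(total_retained - tokens_per_frame, 0)
    base, remainder = divmod(remaining, num_frames - 1)
    return [tokens_per_frame] + [base + (1 if i < remainder else 0) for i in range(num_frames - 1)]

print(split_budget(100, 4, 220))  # [100, 40, 40, 40]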
@@ -1190,6 +1224,7 @@ class Qwen3VLForConditionalGeneration(
SupportsPP,
SupportsMRoPE,
SupportsEagle3,
SupportsMultiModalPruning,
):
packed_modules_mapping = {
"qkv_proj": [
@@ -1232,6 +1267,11 @@ class Qwen3VLForConditionalGeneration(
self.config = config
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.video_pruning_rate = multimodal_config.video_pruning_rate
self.is_multimodal_pruning_enabled = (
multimodal_config.is_multimodal_pruning_enabled()
)
if not multimodal_config.get_limit_per_prompt(
"image"
) and not multimodal_config.get_limit_per_prompt("video"):
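A hedged usage sketch for enabling EVS: the multimodal config field referenced above is video_pruning_rate; the constructor kwarg and the checkpoint name below are assumptions and may differ across vLLM versions.

from vllm import LLM

# Assumed wiring: video_pruning_rate is forwarded into MultiModalConfig.
llm = LLM(
    model="Qwen/Qwen3-VL-8B-Instruct",  # hypothetical checkpoint name
    video_pruning_rate=0.75,            # prune roughly 75% of video tokens; 0 disables EVS
)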
@@ -1418,6 +1458,109 @@ class Qwen3VLForConditionalGeneration(
sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
return video_embeds.split(sizes)
def _postprocess_image_embeds_evs(
self,
image_embeds_split: tuple[torch.Tensor, ...],
image_input: Qwen2_5_VLImageInputs,
) -> tuple[torch.Tensor, ...]:
"""
Append mrope positions to each image embedding.
This is necessary to recover correct mrope
positions after video pruning.
Args:
image_embeds_split: Tuple of image embeddings for
each image item.
image_input: Image input data.
Returns:
Tuple of image embeddings for each image item.
Resulting embeddings will have extra 4 channels for
computed mrope positions.
"""
merge_size = self.visual.spatial_merge_size
grid_thw = image_input["image_grid_thw"]
grid_thw_list = grid_thw.tolist()
image_embeds_out = []
for emb, size in zip(image_embeds_split, grid_thw_list):
positions = compute_mrope_for_media(size, merge_size).to(emb.device)
emb = torch.cat([emb, positions], dim=1)
image_embeds_out.append(emb)
image_embeds_split = image_embeds_out
return tuple(image_embeds_split)
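The packing convention used here and in _postprocess_video_embeds_evs below: the last four channels of each returned embedding row carry the precomputed mrope position for that token, and they are stripped off again in recompute_mrope_positions. A shape-only sketch with made-up sizes:

import torch

emb = torch.randn(16, 1024)                       # (num_tokens, hidden); made-up sizes
positions = torch.zeros(16, 4, dtype=torch.long)  # (num_tokens, 4) mrope coordinates
packed = torch.cat([emb, positions.to(emb.dtype)], dim=1)            # (16, 1028)
content, pos = packed[:, :-4], packed[:, -4:].permute(1, 0).long()   # unpacked later
assert content.shape == (16, 1024) and pos.shape == (4, 16)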
def _postprocess_video_embeds_evs(
self,
video_embeds_split: tuple[torch.Tensor, ...],
video_input: Qwen2_5_VLVideoInputs,
) -> tuple[torch.Tensor, ...]:
"""
Prunes video embeddings via Efficient Video Sampling (EVS)
and then appends mrope positions to each retained embedding.
Args:
video_embeds_split: Tuple of video embeddings for each video item.
video_input: Video input data.
Returns:
Tuple of video embeddings for each video item.
Resulting embeddings will have extra 4 channels for
computed mrope positions.
"""
grid_thw = video_input["video_grid_thw"]
assert grid_thw.ndim == 2
grid_thw_list = grid_thw.tolist()
merge_size = self.visual.spatial_merge_size
# Cast to long to match the original code
# https://github.com/huggingface/transformers/blob/41980ce93e775f6c88500c51c8db7946fc6a2add/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py#L491 # noqa
second_per_grid_ts = video_input.get("second_per_grid_ts")
if second_per_grid_ts is None:
# For Qwen3-VL, second_per_grid_ts might not be available
# Use default value of 1.0 for each video
second_per_grid_ts = torch.ones(len(grid_thw_list), dtype=torch.long)
else:
second_per_grid_ts = second_per_grid_ts.long()
tokens_per_second = getattr(self.config.vision_config, "tokens_per_second", 1.0)
video_embeds_out = []
for emb, size, video_second_per_grid_t in zip(
video_embeds_split, grid_thw_list, second_per_grid_ts
):
# For each video, we compute retention mask using EVS
retention_mask = compute_retention_mask(
emb,
size,
spatial_merge_size=self.visual.spatial_merge_size,
q=self.video_pruning_rate,
)
# Debug logging for EVS pruning
logger.debug(
"EVS: Video tokens pruned from %d to %d (T=%d,H=%d,W=%d, "
"pruning_rate=%.2f, reduction=%.1f%%)",
emb.shape[0],
retention_mask.sum().item(),
size[0],
size[1],
size[2],
self.video_pruning_rate,
(1 - retention_mask.float().mean().item()) * 100,
)
positions = compute_mrope_for_media(
size,
merge_size,
tokens_per_second=tokens_per_second,
video_second_per_grid=video_second_per_grid_t.item(),
).to(emb.device)
emb = emb[retention_mask]
positions = positions[retention_mask]
emb = torch.cat([emb, positions], dim=1)
video_embeds_out.append(emb)
return tuple(video_embeds_out)
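For intuition only, a toy content-aware retention mask in the spirit of EVS. This is not vLLM's compute_retention_mask; it simply keeps the whole first frame and the least temporally redundant tokens of later frames, under an overall keep fraction of (1 - q):

import torch

def toy_retention_mask(emb: torch.Tensor, thw: tuple[int, int, int], q: float) -> torch.Tensor:
    """emb: (T*H*W, D) merged video tokens; thw: post-merge grid; q: fraction to drop."""
    t, h, w = thw
    x = emb.view(t, h * w, -1)
    # Similarity of each token to the co-located token in the previous frame.
    sim = torch.nn.functional.cosine_similarity(x[1:], x[:-1], dim=-1)  # (T-1, H*W)
    budget = max(int(round((1 - q) * t * h * w)) - h * w, 0)  # tokens kept beyond frame 0
    keep_later = torch.zeros(sim.numel(), dtype=torch.bool)
    if budget > 0:
        # Keep the least redundant (lowest-similarity) tokens.
        idx = torch.topk(sim.view(-1), k=min(budget, sim.numel()), largest=False).indices
        keep_later[idx] = True
    return torch.cat([torch.ones(h * w, dtype=torch.bool), keep_later])  # (T*H*W,) True = kept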
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
mm_input_by_modality = {}
for input_key in kwargs:
@@ -1440,6 +1583,20 @@ class Qwen3VLForConditionalGeneration(
def iter_mm_grid_hw(
self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec]
) -> Iterator[tuple[int, int, int]]:
"""
Iterate over multimodal features and yield grid information.
For videos with EVS (Efficient Video Sampling) enabled, this function
computes the offset based on the pruned token count rather than relying
on input_tokens.index(), which would fail when tokens are pruned.
Args:
input_tokens: List of token IDs in the prompt
mm_features: List of multimodal feature specifications
Yields:
Tuple of (offset, grid_h, grid_w) for each frame/image
"""
video_token_id = self.config.video_token_id
spatial_merge_size = self.config.vision_config.spatial_merge_size
for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
@@ -1452,42 +1609,289 @@ class Qwen3VLForConditionalGeneration(
t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
llm_grid_h = h // spatial_merge_size
llm_grid_w = w // spatial_merge_size
# Check if EVS (Efficient Video Sampling) is enabled
is_evs_enabled = (
hasattr(self, "video_pruning_rate")
and self.video_pruning_rate is not None
and self.video_pruning_rate > 0.0
)
if is_evs_enabled:
frame_offsets = self._extract_frame_offsets_from_mask(
mm_feature.mm_position, t
)
if frame_offsets is not None:
for rel_offset in frame_offsets:
yield offset + rel_offset, llm_grid_h, llm_grid_w
continue
# If EVS is enabled but mask is missing, this indicates a bug
# in the prompt processing pipeline. The is_embed mask should
# always be present when video_pruning_rate > 0.
raise RuntimeError(
f"EVS is enabled (pruning_rate={self.video_pruning_rate}) "
"but is_embed mask is missing from mm_position. "
"This indicates a bug in prompt processing."
)
else:
# Non-EVS mode: Use original logic with input_tokens.index()
for _ in range(t):
offset = input_tokens.index(video_token_id, offset)
yield offset, llm_grid_h, llm_grid_w
offset += llm_grid_h * llm_grid_w
else:
raise ValueError(f"Unsupported modality: {mm_feature.modality}")
def _get_evs_mask_segments(
self, mm_position: PlaceholderRange, expected_frames: int
) -> list[torch.Tensor] | None:
"""Extract contiguous segments from EVS is_embed mask.
The EVS (Efficient Video Sampling) mask marks which placeholder
positions should be filled with video embeddings. This method splits
the mask into contiguous segments, where each segment represents one
retained frame.
This is a pure function: it does not modify any state and always
returns the same output for the same input (deterministic).
Args:
mm_position: MultiModal position containing the is_embed mask
expected_frames: Expected number of frame segments
Returns:
List of tensors, each containing indices for one frame segment,
or None if EVS is not enabled or validation fails.
"""
is_embed_mask = getattr(mm_position, "is_embed", None)
if is_embed_mask is None:
return None
# Find all True positions in the mask
mask_tensor = torch.as_tensor(is_embed_mask, dtype=torch.bool).view(-1)
true_indices = torch.nonzero(mask_tensor, as_tuple=False).flatten()
if true_indices.numel() == 0:
return None
# Split into contiguous segments (where diff > 1 indicates a gap)
if true_indices.numel() == 1:
segments = [true_indices]
else:
diffs = torch.diff(true_indices)
split_points = torch.nonzero(diffs != 1, as_tuple=False).flatten()
if split_points.numel() == 0:
segments = [true_indices]
else:
segments = torch.tensor_split(
true_indices, split_points.add(1).tolist()
)
# Validate segment count matches expected frames
if len(segments) < expected_frames:
logger.debug(
"EVS mask segments (%d) do not match expected frames (%d)",
len(segments),
expected_frames,
)
return None
return segments[:expected_frames]
def _extract_frame_offsets_from_mask(
self, mm_position: PlaceholderRange, expected_frames: int
) -> list[int] | None:
"""Return relative offsets for each EVS-retained frame.
The prompt processor stores a boolean mask inside ``mm_position`` that
marks which placeholder locations should be populated with video
embeddings. By splitting that mask into contiguous runs we can recover
the start of every retained frame without probing ``input_tokens``.
Args:
mm_position: MultiModal position containing the is_embed mask
expected_frames: Expected number of frames
Returns:
List of starting offsets (relative to mm_position) for each frame,
or None if EVS is not enabled.
"""
segments = self._get_evs_mask_segments(mm_position, expected_frames)
if segments is None:
return None
return [int(segment[0].item()) for segment in segments]
def _get_actual_frame_token_counts(
self, mm_position: PlaceholderRange, expected_frames: int
) -> list[int] | None:
"""Return actual token count for each EVS-retained frame.
This function calculates the actual number of tokens per frame by
analyzing the is_embed mask, accounting for EVS pruning. Each frame
may have a different token count due to content-aware pruning.
Args:
mm_position: MultiModal position containing the is_embed mask
expected_frames: Expected number of frames
Returns:
List of token counts for each frame, or None if EVS is not enabled.
"""
segments = self._get_evs_mask_segments(mm_position, expected_frames)
if segments is None:
return None
return [len(seg) for seg in segments]
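As a concrete example of how these helpers interpret the mask, consider a toy is_embed mask in which frame 0 retains 4 tokens and frame 1 retains 2 (the values are made up):

import torch

# True positions are video-embedding slots; the gap holds timestamp/vision markers.
mask = torch.tensor([1, 1, 1, 1, 0, 0, 0, 1, 1, 0], dtype=torch.bool)
true_idx = torch.nonzero(mask, as_tuple=False).flatten()          # [0, 1, 2, 3, 7, 8]
splits = torch.nonzero(torch.diff(true_idx) != 1).flatten() + 1   # [4]
segments = torch.tensor_split(true_idx, splits.tolist())          # ([0, 1, 2, 3], [7, 8])
frame_offsets = [int(s[0]) for s in segments]                     # [0, 7]
frame_token_counts = [len(s) for s in segments]                   # [4, 2]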
def recompute_mrope_positions(
self,
input_ids: list[int],
multimodal_embeddings: tuple[torch.Tensor, ...],
mrope_positions: torch.LongTensor,
num_computed_tokens: int,
) -> tuple[tuple[torch.Tensor, ...], torch.Tensor, int]:
"""
Update part of the input mrope positions (starting at index
num_computed_tokens). The original mrope_positions are computed
for the unpruned sequence and become incorrect once pruning
occurs, so after pruning media tokens we must reflect this in
mrope_positions before feeding them to the LLM.
Args:
input_ids: (N,) All input tokens of the prompt (containing the
entire sequence).
multimodal_embeddings: Tuple of multimodal embeddings.
mrope_positions: Existing mrope positions (3, N) for the entire
sequence.
num_computed_tokens: Number of tokens computed so far.
Returns:
Tuple of (multimodal_embeddings, mrope_positions,
mrope_position_delta).
"""
image_token_id = self.config.image_token_id
video_token_id = self.config.video_token_id
vision_start_token_id = self.config.vision_start_token_id
# Device
device = (
multimodal_embeddings[0].device
if len(multimodal_embeddings)
else mrope_positions.device
)
# Tensors
input_ids_t = torch.as_tensor(input_ids, device=device, dtype=torch.long)
mm_embeddings_out = [mm[:, :-4] for mm in multimodal_embeddings]
mm_embeddings_pos = [
mm[:, -4:].permute(1, 0).long() for mm in multimodal_embeddings
]
positions, mrope_positions_delta = recompute_mrope_positions(
input_ids_t,
mm_embeddings_pos,
mrope_positions,
num_computed_tokens,
vision_start_token_id,
image_token_id,
video_token_id,
)
return tuple(mm_embeddings_out), positions, mrope_positions_delta
def get_mrope_input_positions(
self,
input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
# Pre-collect actual frame token counts for EVS mode
frame_token_counts_map = {}
for mm_feature in mm_features:
if mm_feature.modality == "video":
is_evs_enabled = (
hasattr(self, "video_pruning_rate")
and self.video_pruning_rate is not None
and self.video_pruning_rate > 0.0
)
if is_evs_enabled:
t = mm_feature.data["video_grid_thw"].data.tolist()[0]
token_counts = self._get_actual_frame_token_counts(
mm_feature.mm_position, t
)
assert token_counts is not None, (
"EVS enabled but failed to extract frame token counts "
"from is_embed mask"
)
frame_token_counts_map[mm_feature.mm_position.offset] = token_counts
llm_pos_ids_list = []
st = 0
frame_counts_idx = {}
for offset, llm_grid_h, llm_grid_w in self.iter_mm_grid_hw(
input_tokens, mm_features
):
text_len = offset - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
# Determine actual token count for this frame
base_offset = None
for feat_offset in frame_token_counts_map:
if offset >= feat_offset:
base_offset = feat_offset
if base_offset is not None:
# EVS mode: use actual token count from is_embed mask
assert base_offset in frame_token_counts_map, (
f"Found base_offset {base_offset} but not in frame_token_counts_map"
)
if base_offset not in frame_counts_idx:
frame_counts_idx[base_offset] = 0
counts = frame_token_counts_map[base_offset]
idx = frame_counts_idx[base_offset]
assert idx < len(counts), (
f"EVS frame index {idx} out of range (total frames: {len(counts)})"
)
actual_frame_tokens = counts[idx]
frame_counts_idx[base_offset] += 1
else:
# Non-EVS mode (or image): use theoretical grid size
actual_frame_tokens = llm_grid_h * llm_grid_w
# Add text segment
text_positions = (
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
)
llm_pos_ids_list.append(text_positions)
st_idx += text_len
# Add frame segment with actual token count (not theoretical)
grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1)
# Only take the first actual_frame_tokens positions
frame_positions = grid_indices[:, :actual_frame_tokens] + st_idx
llm_pos_ids_list.append(frame_positions)
# Update st using actual token count
st = offset + actual_frame_tokens
# Handle final text segment
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
final_text_positions = (
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
)
llm_pos_ids_list.append(final_text_positions)
llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
return torch.from_numpy(llm_positions), mrope_position_delta
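A small worked example of the position construction above, with made-up sizes: two text tokens followed by a single 2x2 frame from which EVS retained only 3 tokens:

import numpy as np

llm_grid_h = llm_grid_w = 2
actual_frame_tokens = 3
text_len, st_idx = 2, 0

text_pos = np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx  # t/h/w all [0, 1]
grid = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1)            # 4 grid positions
frame_pos = grid[:, :actual_frame_tokens] + (st_idx + text_len)          # keep first 3 only
positions = np.concatenate([text_pos, frame_pos], axis=1)
# positions:
# t: [0 1 2 2 2]
# h: [0 1 2 2 3]
# w: [0 1 2 3 2]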
def get_language_model(self) -> torch.nn.Module:
@@ -1508,9 +1912,17 @@ class Qwen3VLForConditionalGeneration(
multimodal_input = mm_input_by_modality[modality]
if modality == "image":
image_embeddings = self._process_image_input(multimodal_input)
if self.is_multimodal_pruning_enabled:
image_embeddings = self._postprocess_image_embeds_evs(
image_embeddings, multimodal_input
)
multimodal_embeddings += tuple(image_embeddings)
if modality == "video":
video_embeddings = self._process_video_input(multimodal_input)
if self.is_multimodal_pruning_enabled:
video_embeddings = self._postprocess_video_embeds_evs(
video_embeddings, multimodal_input
)
multimodal_embeddings += tuple(video_embeddings)
return multimodal_embeddings