[Feature] Add EVS (Efficient Video Sampling) Support for Qwen3-VL (#29752)
Signed-off-by: zitian.zhao <zitian.zhao@tencentmusic.com>
Co-authored-by: deitxfge <huhaibo1990@126.com>
parent 5ccf0efa84
commit ae88aada38
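
EVS trims redundant video tokens before they reach the language model, trading some visual detail for shorter prompts and faster prefill at a configurable pruning rate. Below is a minimal usage sketch; it assumes the `video_pruning_rate` multimodal-config field touched by this diff is also accepted as an engine argument under the same name (the engine/CLI plumbing is not shown in this commit), and the checkpoint name is only a placeholder.

# Hedged sketch: enabling EVS for a Qwen3-VL checkpoint served with vLLM.
# Assumptions (not confirmed by this diff): `video_pruning_rate` is accepted
# as an engine argument with the same name as the multimodal-config field,
# and the model name below is only a placeholder.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-VL-30B-A3B-Instruct",  # placeholder checkpoint
    video_pruning_rate=0.75,  # ask EVS to drop roughly 75% of video tokens
)
# Video requests are then issued as usual; the pruning itself happens in the
# multimodal processing and embedding paths added by this commit.
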
@@ -67,12 +67,19 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.evs import (
+    compute_mrope_for_media,
+    compute_retained_tokens_count,
+    compute_retention_mask,
+    recompute_mrope_positions,
+)
 from vllm.multimodal.inputs import (
     MultiModalDataDict,
     MultiModalFeatureSpec,
     MultiModalFieldConfig,
     MultiModalKwargsItem,
     MultiModalKwargsItems,
+    PlaceholderRange,
     VideoItem,
 )
 from vllm.multimodal.parse import ImageSize, MultiModalDataItems, MultiModalDataParser
@@ -92,6 +99,7 @@ from .interfaces import (
     SupportsLoRA,
     SupportsMRoPE,
     SupportsMultiModal,
+    SupportsMultiModalPruning,
     SupportsPP,
     _require_is_multimodal,
 )
@@ -1043,13 +1051,39 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo])
                 tokenizer.encode(f"<{curr_time:.1f} seconds>", add_special_tokens=False)
                 for curr_time in timestamps
             ]
-            num_tokens_per_frame = int(grid_thw[1:].prod()) // merge_length
+            tokens_per_frame = int(grid_thw[1:].prod()) // merge_length
+            per_frame_token_counts = [tokens_per_frame for _ in frames_idx_token]
+
+            video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate
+            if video_pruning_rate is not None and video_pruning_rate > 0.0:
+                total_retained = compute_retained_tokens_count(
+                    tokens_per_frame,
+                    len(frames_idx_token),
+                    video_pruning_rate,
+                )
+                if len(frames_idx_token) == 0:
+                    per_frame_token_counts = []
+                elif len(frames_idx_token) == 1:
+                    per_frame_token_counts = [tokens_per_frame]
+                else:
+                    first_frame_tokens = tokens_per_frame
+                    remaining_tokens = max(total_retained - first_frame_tokens, 0)
+                    base = remaining_tokens // (len(frames_idx_token) - 1)
+                    remainder = remaining_tokens % (len(frames_idx_token) - 1)
+                    per_frame_token_counts = [first_frame_tokens]
+                    for frame_idx in range(1, len(frames_idx_token)):
+                        extra = base + (1 if (frame_idx - 1) < remainder else 0)
+                        per_frame_token_counts.append(extra)
+
             placeholder = []
-            for frame_idx in frames_idx_token:
-                placeholder.extend(frame_idx)
+            for frame_idx, timestamp_tokens in enumerate(frames_idx_token):
+                placeholder.extend(timestamp_tokens)
+                tokens_this_frame = per_frame_token_counts[
+                    frame_idx if frame_idx < len(per_frame_token_counts) else -1
+                ]
                 placeholder.extend(
                     [vision_start_token_id]
-                    + [video_token_id] * num_tokens_per_frame
+                    + [video_token_id] * tokens_this_frame
                     + [vision_end_token_id]
                 )
             return PromptUpdateDetails.select_token_id(placeholder, video_token_id)
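
The processor hunk above allocates the pruned token budget across frames: the first frame keeps its full, unpruned token count and the remaining budget is spread as evenly as possible over the later frames. The following standalone sketch mirrors that allocation logic (the total retained count itself comes from compute_retained_tokens_count in vllm.multimodal.evs; the helper name below is just for illustration).

# Standalone sketch of the per-frame budget split used above: the first frame
# keeps its full token count, the rest of the retained budget is distributed
# evenly, with the remainder going to the earliest frames.
def split_token_budget(
    tokens_per_frame: int, num_frames: int, total_retained: int
) -> list[int]:
    if num_frames == 0:
        return []
    if num_frames == 1:
        return [tokens_per_frame]
    counts = [tokens_per_frame]
    remaining = max(total_retained - tokens_per_frame, 0)
    base, remainder = divmod(remaining, num_frames - 1)
    for i in range(num_frames - 1):
        counts.append(base + (1 if i < remainder else 0))
    return counts


# Example: 8 frames of 144 tokens each with a 75% pruning rate, leaving a
# budget of roughly 288 tokens (the exact figure is decided by
# compute_retained_tokens_count).
print(split_token_budget(tokens_per_frame=144, num_frames=8, total_retained=288))
# -> [144, 21, 21, 21, 21, 20, 20, 20]
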
@@ -1190,6 +1224,7 @@ class Qwen3VLForConditionalGeneration(
     SupportsPP,
     SupportsMRoPE,
     SupportsEagle3,
+    SupportsMultiModalPruning,
 ):
     packed_modules_mapping = {
         "qkv_proj": [
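
Listing SupportsMultiModalPruning among the base classes is what lets the rest of the engine discover at runtime that this model can consume pruned multimodal embeddings. The sketch below shows the general shape of such a capability check; it is illustrative only and does not claim to match vLLM's actual interface definition.

# Illustrative only: the runtime-checkable capability pattern that interfaces
# such as SupportsMultiModalPruning typically follow. Names here are
# hypothetical, not vLLM's real definitions.
from typing import Protocol, runtime_checkable


@runtime_checkable
class SupportsPruningLike(Protocol):
    # Hypothetical marker attribute advertising the capability.
    supports_multimodal_pruning: bool


def model_prunes_multimodal_tokens(model: object) -> bool:
    # isinstance() on a runtime-checkable Protocol only checks that the
    # attribute exists, so we also read its value defensively.
    return isinstance(model, SupportsPruningLike) and getattr(
        model, "supports_multimodal_pruning", False
    )
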
@@ -1232,6 +1267,11 @@ class Qwen3VLForConditionalGeneration(
         self.config = config
         self.multimodal_config = multimodal_config
         self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
+        self.video_pruning_rate = multimodal_config.video_pruning_rate
+        self.is_multimodal_pruning_enabled = (
+            multimodal_config.is_multimodal_pruning_enabled()
+        )
+
         if not multimodal_config.get_limit_per_prompt(
             "image"
         ) and not multimodal_config.get_limit_per_prompt("video"):
@@ -1418,6 +1458,109 @@ class Qwen3VLForConditionalGeneration(
         sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
         return video_embeds.split(sizes)
 
+    def _postprocess_image_embeds_evs(
+        self,
+        image_embeds_split: tuple[torch.Tensor, ...],
+        image_input: Qwen2_5_VLImageInputs,
+    ) -> tuple[torch.Tensor, ...]:
+        """
+        Append mrope positions to each image embedding.
+        This is necessary to recover correct mrope
+        positions after video pruning.
+
+        Args:
+            image_embeds_split: Tuple of image embeddings for
+                each image item.
+            image_input: Image input data.
+
+        Returns:
+            Tuple of image embeddings for each image item.
+            Resulting embeddings will have 4 extra channels for
+            the computed mrope positions.
+        """
+        merge_size = self.visual.spatial_merge_size
+        grid_thw = image_input["image_grid_thw"]
+        grid_thw_list = grid_thw.tolist()
+        image_embeds_out = []
+        for emb, size in zip(image_embeds_split, grid_thw_list):
+            positions = compute_mrope_for_media(size, merge_size).to(emb.device)
+            emb = torch.cat([emb, positions], dim=1)
+            image_embeds_out.append(emb)
+        image_embeds_split = image_embeds_out
+        return tuple(image_embeds_split)
+
+    def _postprocess_video_embeds_evs(
+        self,
+        video_embeds_split: tuple[torch.Tensor, ...],
+        video_input: Qwen2_5_VLVideoInputs,
+    ) -> tuple[torch.Tensor, ...]:
+        """
+        Prunes video embeddings via Efficient Video Sampling (EVS)
+        and then appends mrope positions to each retained embedding.
+
+        Args:
+            video_embeds_split: Tuple of video embeddings for each video item.
+            video_input: Video input data.
+
+        Returns:
+            Tuple of video embeddings for each video item.
+            Resulting embeddings will have 4 extra channels for
+            the computed mrope positions.
+        """
+        grid_thw = video_input["video_grid_thw"]
+        assert grid_thw.ndim == 2
+        grid_thw_list = grid_thw.tolist()
+        merge_size = self.visual.spatial_merge_size
+
+        # Cast to long to match the original code
+        # https://github.com/huggingface/transformers/blob/41980ce93e775f6c88500c51c8db7946fc6a2add/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py#L491 # noqa
+        second_per_grid_ts = video_input.get("second_per_grid_ts")
+        if second_per_grid_ts is None:
+            # For Qwen3-VL, second_per_grid_ts might not be available.
+            # Use a default value of 1.0 for each video.
+            second_per_grid_ts = torch.ones(len(grid_thw_list), dtype=torch.long)
+        else:
+            second_per_grid_ts = second_per_grid_ts.long()
+        tokens_per_second = getattr(self.config.vision_config, "tokens_per_second", 1.0)
+
+        video_embeds_out = []
+        for emb, size, video_second_per_grid_t in zip(
+            video_embeds_split, grid_thw_list, second_per_grid_ts
+        ):
+            # For each video, compute the retention mask using EVS
+            retention_mask = compute_retention_mask(
+                emb,
+                size,
+                spatial_merge_size=self.visual.spatial_merge_size,
+                q=self.video_pruning_rate,
+            )
+
+            # Debug logging for EVS pruning
+            logger.debug(
+                "EVS: Video tokens pruned from %d to %d (T=%d,H=%d,W=%d, "
+                "pruning_rate=%.2f, reduction=%.1f%%)",
+                emb.shape[0],
+                retention_mask.sum().item(),
+                size[0],
+                size[1],
+                size[2],
+                self.video_pruning_rate,
+                (1 - retention_mask.float().mean().item()) * 100,
+            )
+
+            positions = compute_mrope_for_media(
+                size,
+                merge_size,
+                tokens_per_second=tokens_per_second,
+                video_second_per_grid=video_second_per_grid_t.item(),
+            ).to(emb.device)
+
+            emb = emb[retention_mask]
+            positions = positions[retention_mask]
+            emb = torch.cat([emb, positions], dim=1)
+            video_embeds_out.append(emb)
+        return tuple(video_embeds_out)
+
     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
         mm_input_by_modality = {}
         for input_key in kwargs:
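
The two postprocessing helpers above smuggle the recomputed mrope positions alongside the embeddings by concatenating four extra position channels onto each (tokens, hidden) tensor; the recompute_mrope_positions wrapper further down peels them off again with `[:, :-4]` / `[:, -4:]`. A small self-contained illustration of that packing convention (toy shapes, not the model's real tensors):

# Illustration of the pack/unpack convention used above: four extra position
# channels are appended to each embedding row and stripped off again before
# the embeddings are fed to the language model.
import torch

num_tokens, hidden = 6, 32
emb = torch.randn(num_tokens, hidden)
positions = torch.arange(num_tokens * 4, dtype=emb.dtype).reshape(num_tokens, 4)

packed = torch.cat([emb, positions], dim=1)    # (6, 36)

emb_out = packed[:, :-4]                       # (6, 32) embeddings
pos_out = packed[:, -4:].permute(1, 0).long()  # (4, 6) positions per channel

assert torch.equal(emb_out, emb)
assert torch.equal(pos_out, positions.permute(1, 0).long())
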
@@ -1440,6 +1583,20 @@ class Qwen3VLForConditionalGeneration(
     def iter_mm_grid_hw(
         self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec]
     ) -> Iterator[tuple[int, int, int]]:
+        """
+        Iterate over multimodal features and yield grid information.
+
+        For videos with EVS (Efficient Video Sampling) enabled, this function
+        computes the offset based on the pruned token count rather than relying
+        on input_tokens.index(), which would fail when tokens are pruned.
+
+        Args:
+            input_tokens: List of token IDs in the prompt
+            mm_features: List of multimodal feature specifications
+
+        Yields:
+            Tuple of (offset, grid_h, grid_w) for each frame/image
+        """
         video_token_id = self.config.video_token_id
         spatial_merge_size = self.config.vision_config.spatial_merge_size
         for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
@@ -1452,42 +1609,289 @@ class Qwen3VLForConditionalGeneration(
                 t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
                 llm_grid_h = h // spatial_merge_size
                 llm_grid_w = w // spatial_merge_size
-                for _ in range(t):
-                    offset = input_tokens.index(video_token_id, offset)
-                    yield offset, llm_grid_h, llm_grid_w
-                    offset += llm_grid_h * llm_grid_w
+
+                # Check if EVS (Efficient Video Sampling) is enabled
+                is_evs_enabled = (
+                    hasattr(self, "video_pruning_rate")
+                    and self.video_pruning_rate is not None
+                    and self.video_pruning_rate > 0.0
+                )
+
+                if is_evs_enabled:
+                    frame_offsets = self._extract_frame_offsets_from_mask(
+                        mm_feature.mm_position, t
+                    )
+                    if frame_offsets is not None:
+                        for rel_offset in frame_offsets:
+                            yield offset + rel_offset, llm_grid_h, llm_grid_w
+                        continue
+
+                    # If EVS is enabled but the mask is missing, this indicates
+                    # a bug in the prompt processing pipeline. The is_embed mask
+                    # should always be present when video_pruning_rate > 0.
+                    raise RuntimeError(
+                        f"EVS is enabled (pruning_rate={self.video_pruning_rate}) "
+                        "but is_embed mask is missing from mm_position. "
+                        "This indicates a bug in prompt processing."
+                    )
+                else:
+                    # Non-EVS mode: use the original logic with input_tokens.index()
+                    for _ in range(t):
+                        offset = input_tokens.index(video_token_id, offset)
+                        yield offset, llm_grid_h, llm_grid_w
+                        offset += llm_grid_h * llm_grid_w
             else:
                 raise ValueError(f"Unsupported modality: {mm_feature.modality}")
 
+    def _get_evs_mask_segments(
+        self, mm_position: PlaceholderRange, expected_frames: int
+    ) -> list[torch.Tensor] | None:
+        """Extract contiguous segments from the EVS is_embed mask.
+
+        The EVS (Efficient Video Sampling) mask marks which placeholder
+        positions should be filled with video embeddings. This method splits
+        the mask into contiguous segments, where each segment represents one
+        retained frame.
+
+        This is a pure function: it does not modify any state and always
+        returns the same output for the same input (idempotent).
+
+        Args:
+            mm_position: Multimodal position containing the is_embed mask
+            expected_frames: Expected number of frame segments
+
+        Returns:
+            List of tensors, each containing indices for one frame segment,
+            or None if EVS is not enabled or validation fails.
+        """
+        is_embed_mask = getattr(mm_position, "is_embed", None)
+        if is_embed_mask is None:
+            return None
+
+        # Find all True positions in the mask
+        mask_tensor = torch.as_tensor(is_embed_mask, dtype=torch.bool).view(-1)
+        true_indices = torch.nonzero(mask_tensor, as_tuple=False).flatten()
+        if true_indices.numel() == 0:
+            return None
+
+        # Split into contiguous segments (where diff > 1 indicates a gap)
+        if true_indices.numel() == 1:
+            segments = [true_indices]
+        else:
+            diffs = torch.diff(true_indices)
+            split_points = torch.nonzero(diffs != 1, as_tuple=False).flatten()
+            if split_points.numel() == 0:
+                segments = [true_indices]
+            else:
+                segments = torch.tensor_split(
+                    true_indices, split_points.add(1).tolist()
+                )
+
+        # Validate that the segment count matches the expected number of frames
+        if len(segments) < expected_frames:
+            logger.debug(
+                "EVS mask segments (%d) do not match expected frames (%d)",
+                len(segments),
+                expected_frames,
+            )
+            return None
+
+        return segments[:expected_frames]
+
+    def _extract_frame_offsets_from_mask(
+        self, mm_position: PlaceholderRange, expected_frames: int
+    ) -> list[int] | None:
+        """Return relative offsets for each EVS-retained frame.
+
+        The prompt processor stores a boolean mask inside ``mm_position`` that
+        marks which placeholder locations should be populated with video
+        embeddings. By splitting that mask into contiguous runs we can recover
+        the start of every retained frame without probing ``input_tokens``.
+
+        Args:
+            mm_position: Multimodal position containing the is_embed mask
+            expected_frames: Expected number of frames
+
+        Returns:
+            List of starting offsets (relative to mm_position) for each frame,
+            or None if EVS is not enabled.
+        """
+        segments = self._get_evs_mask_segments(mm_position, expected_frames)
+        if segments is None:
+            return None
+
+        return [int(segment[0].item()) for segment in segments]
+
+    def _get_actual_frame_token_counts(
+        self, mm_position: PlaceholderRange, expected_frames: int
+    ) -> list[int] | None:
+        """Return the actual token count for each EVS-retained frame.
+
+        This function calculates the actual number of tokens per frame by
+        analyzing the is_embed mask, accounting for EVS pruning. Each frame
+        may have a different token count due to content-aware pruning.
+
+        Args:
+            mm_position: Multimodal position containing the is_embed mask
+            expected_frames: Expected number of frames
+
+        Returns:
+            List of token counts for each frame, or None if EVS is not enabled.
+        """
+        segments = self._get_evs_mask_segments(mm_position, expected_frames)
+        if segments is None:
+            return None
+
+        return [len(seg) for seg in segments]
+
+    def recompute_mrope_positions(
+        self,
+        input_ids: list[int],
+        multimodal_embeddings: tuple[torch.Tensor, ...],
+        mrope_positions: torch.LongTensor,
+        num_computed_tokens: int,
+    ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor, int]:
+        """
+        Update part of the input mrope positions (starting at index
+        num_computed_tokens). The original mrope_positions are computed
+        for the unpruned sequence and become incorrect once pruning occurs,
+        so after pruning media tokens we must reflect this in the
+        mrope_positions before feeding them to the LLM.
+
+        Args:
+            input_ids: (N,) All input tokens of the prompt (containing
+                the entire sequence).
+            multimodal_embeddings: Tuple of multimodal embeddings.
+            mrope_positions: Existing mrope positions (3, N) for the entire
+                sequence.
+            num_computed_tokens: Number of tokens computed so far.
+
+        Returns:
+            Tuple of (multimodal_embeddings, mrope_positions,
+            mrope_position_delta).
+        """
+        image_token_id = self.config.image_token_id
+        video_token_id = self.config.video_token_id
+        vision_start_token_id = self.config.vision_start_token_id
+
+        # Device
+        device = (
+            multimodal_embeddings[0].device
+            if len(multimodal_embeddings)
+            else mrope_positions.device
+        )
+
+        # Tensors
+        input_ids_t = torch.as_tensor(input_ids, device=device, dtype=torch.long)
+
+        mm_embeddings_out = [mm[:, :-4] for mm in multimodal_embeddings]
+        mm_embeddings_pos = [
+            mm[:, -4:].permute(1, 0).long() for mm in multimodal_embeddings
+        ]
+
+        positions, mrope_positions_delta = recompute_mrope_positions(
+            input_ids_t,
+            mm_embeddings_pos,
+            mrope_positions,
+            num_computed_tokens,
+            vision_start_token_id,
+            image_token_id,
+            video_token_id,
+        )
+
+        return tuple(mm_embeddings_out), positions, mrope_positions_delta
+
     def get_mrope_input_positions(
         self,
         input_tokens: list[int],
         mm_features: list[MultiModalFeatureSpec],
     ) -> tuple[torch.Tensor, int]:
+        # Pre-collect actual frame token counts for EVS mode
+        frame_token_counts_map = {}
+        for mm_feature in mm_features:
+            if mm_feature.modality == "video":
+                is_evs_enabled = (
+                    hasattr(self, "video_pruning_rate")
+                    and self.video_pruning_rate is not None
+                    and self.video_pruning_rate > 0.0
+                )
+                if is_evs_enabled:
+                    t = mm_feature.data["video_grid_thw"].data.tolist()[0]
+                    token_counts = self._get_actual_frame_token_counts(
+                        mm_feature.mm_position, t
+                    )
+                    assert token_counts is not None, (
+                        "EVS enabled but failed to extract frame token counts "
+                        "from is_embed mask"
+                    )
+                    frame_token_counts_map[mm_feature.mm_position.offset] = token_counts
+
         llm_pos_ids_list = []
         st = 0
+        frame_counts_idx = {}
+
         for offset, llm_grid_h, llm_grid_w in self.iter_mm_grid_hw(
             input_tokens, mm_features
         ):
             text_len = offset - st
             st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-            llm_pos_ids_list.append(
+
+            # Determine the actual token count for this frame
+            base_offset = None
+            for feat_offset in frame_token_counts_map:
+                if offset >= feat_offset:
+                    base_offset = feat_offset
+
+            if base_offset is not None:
+                # EVS mode: use the actual token count from the is_embed mask
+                assert base_offset in frame_token_counts_map, (
+                    f"Found base_offset {base_offset} but not in frame_token_counts_map"
+                )
+
+                if base_offset not in frame_counts_idx:
+                    frame_counts_idx[base_offset] = 0
+
+                counts = frame_token_counts_map[base_offset]
+                idx = frame_counts_idx[base_offset]
+
+                assert idx < len(counts), (
+                    f"EVS frame index {idx} out of range (total frames: {len(counts)})"
+                )
+
+                actual_frame_tokens = counts[idx]
+                frame_counts_idx[base_offset] += 1
+            else:
+                # Non-EVS mode (or image): use the theoretical grid size
+                actual_frame_tokens = llm_grid_h * llm_grid_w
+
+            # Add the text segment
+            text_positions = (
                 np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
             )
+            llm_pos_ids_list.append(text_positions)
+            st_idx += text_len
+
+            # Add the frame segment with the actual token count (not theoretical)
             grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1)
-            llm_pos_ids_list.append(grid_indices + text_len + st_idx)
-            st = offset + llm_grid_h * llm_grid_w
+            # Only take the first actual_frame_tokens positions
+            frame_positions = grid_indices[:, :actual_frame_tokens] + st_idx
+            llm_pos_ids_list.append(frame_positions)
+
+            # Update st using the actual token count
+            st = offset + actual_frame_tokens
+
+        # Handle the final text segment
         if st < len(input_tokens):
             st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
             text_len = len(input_tokens) - st
-            llm_pos_ids_list.append(
+            final_text_positions = (
                 np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
             )
+            llm_pos_ids_list.append(final_text_positions)
+
         llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
         mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
+
         return torch.from_numpy(llm_positions), mrope_position_delta
 
     def get_language_model(self) -> torch.nn.Module:
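
The mask-handling helpers above recover per-frame geometry purely from the is_embed mask: contiguous runs of True correspond to retained frames, each run's first index gives the frame's relative offset, and its length gives the retained token count. The snippet below reproduces that splitting with the same torch calls, outside the model class and on a toy mask.

# Minimal reproduction of the is_embed-mask splitting used by
# _get_evs_mask_segments: each contiguous run of True is one retained frame.
import torch

# Toy mask for two retained frames: 5 tokens starting at index 2,
# then 3 tokens starting at index 10.
is_embed = torch.zeros(16, dtype=torch.bool)
is_embed[2:7] = True
is_embed[10:13] = True

true_indices = torch.nonzero(is_embed, as_tuple=False).flatten()
diffs = torch.diff(true_indices)
split_points = torch.nonzero(diffs != 1, as_tuple=False).flatten()
segments = torch.tensor_split(true_indices, split_points.add(1).tolist())

frame_offsets = [int(seg[0].item()) for seg in segments]  # [2, 10]
frame_token_counts = [len(seg) for seg in segments]       # [5, 3]
print(frame_offsets, frame_token_counts)
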
@@ -1508,9 +1912,17 @@ class Qwen3VLForConditionalGeneration(
             multimodal_input = mm_input_by_modality[modality]
             if modality == "image":
                 image_embeddings = self._process_image_input(multimodal_input)
+                if self.is_multimodal_pruning_enabled:
+                    image_embeddings = self._postprocess_image_embeds_evs(
+                        image_embeddings, multimodal_input
+                    )
                 multimodal_embeddings += tuple(image_embeddings)
             if modality == "video":
                 video_embeddings = self._process_video_input(multimodal_input)
+                if self.is_multimodal_pruning_enabled:
+                    video_embeddings = self._postprocess_video_embeds_evs(
+                        video_embeddings, multimodal_input
+                    )
                 multimodal_embeddings += tuple(video_embeddings)
         return multimodal_embeddings