[Feature] Add EVS (Efficient Video Sampling) Support for Qwen3-VL (#29752)
Signed-off-by: zitian.zhao <zitian.zhao@tencentmusic.com>
Co-authored-by: deitxfge <huhaibo1990@126.com>
parent 5ccf0efa84
commit ae88aada38
@@ -67,12 +67,19 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.evs import (
    compute_mrope_for_media,
    compute_retained_tokens_count,
    compute_retention_mask,
    recompute_mrope_positions,
)
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFeatureSpec,
    MultiModalFieldConfig,
    MultiModalKwargsItem,
    MultiModalKwargsItems,
    PlaceholderRange,
    VideoItem,
)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems, MultiModalDataParser
@@ -92,6 +99,7 @@ from .interfaces import (
    SupportsLoRA,
    SupportsMRoPE,
    SupportsMultiModal,
    SupportsMultiModalPruning,
    SupportsPP,
    _require_is_multimodal,
)
@@ -1043,13 +1051,39 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo])
            tokenizer.encode(f"<{curr_time:.1f} seconds>", add_special_tokens=False)
            for curr_time in timestamps
        ]
        tokens_per_frame = int(grid_thw[1:].prod()) // merge_length
        per_frame_token_counts = [tokens_per_frame for _ in frames_idx_token]

        video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate
        if video_pruning_rate is not None and video_pruning_rate > 0.0:
            total_retained = compute_retained_tokens_count(
                tokens_per_frame,
                len(frames_idx_token),
                video_pruning_rate,
            )
            if len(frames_idx_token) == 0:
                per_frame_token_counts = []
            elif len(frames_idx_token) == 1:
                per_frame_token_counts = [tokens_per_frame]
            else:
                # The first frame keeps its full token count; the remaining
                # retained budget is spread as evenly as possible over the
                # other frames.
                first_frame_tokens = tokens_per_frame
                remaining_tokens = max(total_retained - first_frame_tokens, 0)
                base = remaining_tokens // (len(frames_idx_token) - 1)
                remainder = remaining_tokens % (len(frames_idx_token) - 1)
                per_frame_token_counts = [first_frame_tokens]
                for frame_idx in range(1, len(frames_idx_token)):
                    extra = base + (1 if (frame_idx - 1) < remainder else 0)
                    per_frame_token_counts.append(extra)

        placeholder = []
        for frame_idx, timestamp_tokens in enumerate(frames_idx_token):
            placeholder.extend(timestamp_tokens)
            tokens_this_frame = per_frame_token_counts[
                frame_idx if frame_idx < len(per_frame_token_counts) else -1
            ]
            placeholder.extend(
                [vision_start_token_id]
                + [video_token_id] * tokens_this_frame
                + [vision_end_token_id]
            )
        return PromptUpdateDetails.select_token_id(placeholder, video_token_id)
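
The pruning branch above distributes the retained-token budget so that the first frame keeps its full token count while the remainder is split as evenly as possible across the other frames. A minimal standalone sketch of that arithmetic (the helper name distribute_retained_tokens is illustrative, not part of this change):

def distribute_retained_tokens(
    tokens_per_frame: int, num_frames: int, total_retained: int
) -> list[int]:
    # Frame 0 keeps its full token count; the remaining budget is split
    # evenly, with any remainder going to the earliest frames.
    if num_frames == 0:
        return []
    if num_frames == 1:
        return [tokens_per_frame]
    remaining = max(total_retained - tokens_per_frame, 0)
    base, remainder = divmod(remaining, num_frames - 1)
    return [tokens_per_frame] + [
        base + (1 if i < remainder else 0) for i in range(num_frames - 1)
    ]

# e.g. 64 tokens per frame, 4 frames, 130 retained tokens -> [64, 22, 22, 22]
assert distribute_retained_tokens(64, 4, 130) == [64, 22, 22, 22]
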
@@ -1190,6 +1224,7 @@ class Qwen3VLForConditionalGeneration(
    SupportsPP,
    SupportsMRoPE,
    SupportsEagle3,
    SupportsMultiModalPruning,
):
    packed_modules_mapping = {
        "qkv_proj": [
@@ -1232,6 +1267,11 @@ class Qwen3VLForConditionalGeneration(
        self.config = config
        self.multimodal_config = multimodal_config
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
        self.video_pruning_rate = multimodal_config.video_pruning_rate
        self.is_multimodal_pruning_enabled = (
            multimodal_config.is_multimodal_pruning_enabled()
        )

        if not multimodal_config.get_limit_per_prompt(
            "image"
        ) and not multimodal_config.get_limit_per_prompt("video"):
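
For reference, the is_multimodal_pruning_enabled flag stored here drives every EVS branch in this diff. A hedged sketch of the predicate it presumably reduces to (not the actual MultiModalConfig implementation), based on the video_pruning_rate checks used elsewhere in this change:

def is_multimodal_pruning_enabled(video_pruning_rate: float | None) -> bool:
    # Same condition the model code uses when it tests for EVS:
    # a positive video_pruning_rate means pruning is active.
    return video_pruning_rate is not None and video_pruning_rate > 0.0
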
@@ -1418,6 +1458,109 @@ class Qwen3VLForConditionalGeneration(
        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
        return video_embeds.split(sizes)

    def _postprocess_image_embeds_evs(
        self,
        image_embeds_split: tuple[torch.Tensor, ...],
        image_input: Qwen2_5_VLImageInputs,
    ) -> tuple[torch.Tensor, ...]:
        """
        Append mrope positions to the embeddings of each image.
        This is necessary to recover correct mrope positions after
        video pruning.

        Args:
            image_embeds_split: Tuple of image embeddings for
                each image item.
            image_input: Image input data.

        Returns:
            Tuple of image embeddings for each image item. The resulting
            embeddings have 4 extra channels holding the computed mrope
            positions.
        """
        merge_size = self.visual.spatial_merge_size
        grid_thw = image_input["image_grid_thw"]
        grid_thw_list = grid_thw.tolist()
        image_embeds_out = []
        for emb, size in zip(image_embeds_split, grid_thw_list):
            positions = compute_mrope_for_media(size, merge_size).to(emb.device)
            emb = torch.cat([emb, positions], dim=1)
            image_embeds_out.append(emb)
        return tuple(image_embeds_out)
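
The post-processing above carries the per-token mrope positions alongside the visual embeddings as four extra trailing channels so they survive until position recomputation. A small shape-only illustration with dummy tensors (compute_mrope_for_media itself is taken as given):

import torch

num_tokens, hidden = 16, 1024
emb = torch.randn(num_tokens, hidden)        # visual embeddings
positions = torch.zeros(num_tokens, 4)       # stand-in for the 4 mrope channels
packed = torch.cat([emb, positions], dim=1)  # shape: (16, 1028)

# recompute_mrope_positions later splits the channels apart again:
features, pos = packed[:, :-4], packed[:, -4:].permute(1, 0).long()
assert features.shape == (num_tokens, hidden) and pos.shape == (4, num_tokens)
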
    def _postprocess_video_embeds_evs(
        self,
        video_embeds_split: tuple[torch.Tensor, ...],
        video_input: Qwen2_5_VLVideoInputs,
    ) -> tuple[torch.Tensor, ...]:
        """
        Prune video embeddings via Efficient Video Sampling (EVS)
        and then append mrope positions to each retained embedding.

        Args:
            video_embeds_split: Tuple of video embeddings for each video item.
            video_input: Video input data.

        Returns:
            Tuple of video embeddings for each video item. The resulting
            embeddings have 4 extra channels holding the computed mrope
            positions.
        """
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()
        merge_size = self.visual.spatial_merge_size

        # Cast to long to match the original code
        # https://github.com/huggingface/transformers/blob/41980ce93e775f6c88500c51c8db7946fc6a2add/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py#L491 # noqa
        second_per_grid_ts = video_input.get("second_per_grid_ts")
        if second_per_grid_ts is None:
            # For Qwen3-VL, second_per_grid_ts might not be available;
            # use a default value of 1.0 for each video.
            second_per_grid_ts = torch.ones(len(grid_thw_list), dtype=torch.long)
        else:
            second_per_grid_ts = second_per_grid_ts.long()
        tokens_per_second = getattr(self.config.vision_config, "tokens_per_second", 1.0)

        video_embeds_out = []
        for emb, size, video_second_per_grid_t in zip(
            video_embeds_split, grid_thw_list, second_per_grid_ts
        ):
            # For each video, compute the retention mask using EVS
            retention_mask = compute_retention_mask(
                emb,
                size,
                spatial_merge_size=self.visual.spatial_merge_size,
                q=self.video_pruning_rate,
            )

            # Debug logging for EVS pruning
            logger.debug(
                "EVS: Video tokens pruned from %d to %d (T=%d,H=%d,W=%d, "
                "pruning_rate=%.2f, reduction=%.1f%%)",
                emb.shape[0],
                retention_mask.sum().item(),
                size[0],
                size[1],
                size[2],
                self.video_pruning_rate,
                (1 - retention_mask.float().mean().item()) * 100,
            )

            positions = compute_mrope_for_media(
                size,
                merge_size,
                tokens_per_second=tokens_per_second,
                video_second_per_grid=video_second_per_grid_t.item(),
            ).to(emb.device)

            emb = emb[retention_mask]
            positions = positions[retention_mask]
            emb = torch.cat([emb, positions], dim=1)
            video_embeds_out.append(emb)
        return tuple(video_embeds_out)
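
How the retention mask is applied: compute_retention_mask (taken as given here) yields a boolean mask over the merged video tokens, and boolean indexing drops the pruned rows from both the embeddings and their mrope positions so the two stay aligned. A toy example with a hand-written mask:

import torch

emb = torch.randn(6, 8)                       # 6 video tokens, hidden size 8
positions = torch.arange(24.0).reshape(6, 4)  # per-token mrope positions
retention_mask = torch.tensor([True, False, True, True, False, True])

kept_emb = emb[retention_mask]                # (4, 8): pruned rows removed
kept_pos = positions[retention_mask]          # (4, 4): positions stay aligned
packed = torch.cat([kept_emb, kept_pos], dim=1)
assert packed.shape == (4, 12)
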
    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        mm_input_by_modality = {}
        for input_key in kwargs:
@@ -1440,6 +1583,20 @@ class Qwen3VLForConditionalGeneration(
    def iter_mm_grid_hw(
        self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec]
    ) -> Iterator[tuple[int, int, int]]:
        """
        Iterate over multimodal features and yield grid information.

        For videos with EVS (Efficient Video Sampling) enabled, this function
        computes the offset based on the pruned token count rather than relying
        on input_tokens.index(), which would fail when tokens are pruned.

        Args:
            input_tokens: List of token IDs in the prompt
            mm_features: List of multimodal feature specifications

        Yields:
            Tuple of (offset, grid_h, grid_w) for each frame/image
        """
        video_token_id = self.config.video_token_id
        spatial_merge_size = self.config.vision_config.spatial_merge_size
        for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
@@ -1452,6 +1609,33 @@ class Qwen3VLForConditionalGeneration(
                t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
                llm_grid_h = h // spatial_merge_size
                llm_grid_w = w // spatial_merge_size

                # Check if EVS (Efficient Video Sampling) is enabled
                is_evs_enabled = (
                    hasattr(self, "video_pruning_rate")
                    and self.video_pruning_rate is not None
                    and self.video_pruning_rate > 0.0
                )

                if is_evs_enabled:
                    frame_offsets = self._extract_frame_offsets_from_mask(
                        mm_feature.mm_position, t
                    )
                    if frame_offsets is not None:
                        for rel_offset in frame_offsets:
                            yield offset + rel_offset, llm_grid_h, llm_grid_w
                        continue

                    # If EVS is enabled but the mask is missing, this indicates
                    # a bug in the prompt processing pipeline. The is_embed mask
                    # should always be present when video_pruning_rate > 0.
                    raise RuntimeError(
                        f"EVS is enabled (pruning_rate={self.video_pruning_rate}) "
                        "but is_embed mask is missing from mm_position. "
                        "This indicates a bug in prompt processing."
                    )
                else:
                    # Non-EVS mode: use the original logic with input_tokens.index()
                    for _ in range(t):
                        offset = input_tokens.index(video_token_id, offset)
                        yield offset, llm_grid_h, llm_grid_w
@@ -1459,35 +1643,255 @@ class Qwen3VLForConditionalGeneration(
            else:
                raise ValueError(f"Unsupported modality: {mm_feature.modality}")

    def _get_evs_mask_segments(
        self, mm_position: PlaceholderRange, expected_frames: int
    ) -> list[torch.Tensor] | None:
        """Extract contiguous segments from the EVS is_embed mask.

        The EVS (Efficient Video Sampling) mask marks which placeholder
        positions should be filled with video embeddings. This method splits
        the mask into contiguous segments, where each segment represents one
        retained frame.

        This is a pure function: it does not modify any state and always
        returns the same output for the same input (idempotent).

        Args:
            mm_position: Multimodal position containing the is_embed mask
            expected_frames: Expected number of frame segments

        Returns:
            List of tensors, each containing indices for one frame segment,
            or None if EVS is not enabled or validation fails.
        """
        is_embed_mask = getattr(mm_position, "is_embed", None)
        if is_embed_mask is None:
            return None

        # Find all True positions in the mask
        mask_tensor = torch.as_tensor(is_embed_mask, dtype=torch.bool).view(-1)
        true_indices = torch.nonzero(mask_tensor, as_tuple=False).flatten()
        if true_indices.numel() == 0:
            return None

        # Split into contiguous segments (where diff > 1 indicates a gap)
        if true_indices.numel() == 1:
            segments = [true_indices]
        else:
            diffs = torch.diff(true_indices)
            split_points = torch.nonzero(diffs != 1, as_tuple=False).flatten()
            if split_points.numel() == 0:
                segments = [true_indices]
            else:
                segments = torch.tensor_split(
                    true_indices, split_points.add(1).tolist()
                )

        # Validate that the segment count covers the expected frames
        if len(segments) < expected_frames:
            logger.debug(
                "EVS mask segments (%d) do not match expected frames (%d)",
                len(segments),
                expected_frames,
            )
            return None

        return segments[:expected_frames]
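
The segment-splitting step above can be checked in isolation: the True positions of an is_embed mask are grouped wherever consecutive indices jump by more than one, and each group corresponds to one retained frame. A minimal reproduction of the torch.diff / tensor_split pattern on a hand-written mask:

import torch

# Placeholder mask: two retained frames of 3 and 2 embedding slots,
# separated by non-embedding positions (timestamp text, vision start/end).
mask = torch.tensor([0, 1, 1, 1, 0, 0, 1, 1, 0], dtype=torch.bool)

true_indices = torch.nonzero(mask, as_tuple=False).flatten()       # [1, 2, 3, 6, 7]
split_points = torch.nonzero(torch.diff(true_indices) != 1).flatten()
segments = torch.tensor_split(true_indices, split_points.add(1).tolist())

assert [seg.tolist() for seg in segments] == [[1, 2, 3], [6, 7]]
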
    def _extract_frame_offsets_from_mask(
        self, mm_position: PlaceholderRange, expected_frames: int
    ) -> list[int] | None:
        """Return relative offsets for each EVS-retained frame.

        The prompt processor stores a boolean mask inside ``mm_position`` that
        marks which placeholder locations should be populated with video
        embeddings. By splitting that mask into contiguous runs we can recover
        the start of every retained frame without probing ``input_tokens``.

        Args:
            mm_position: Multimodal position containing the is_embed mask
            expected_frames: Expected number of frames

        Returns:
            List of starting offsets (relative to mm_position) for each frame,
            or None if EVS is not enabled.
        """
        segments = self._get_evs_mask_segments(mm_position, expected_frames)
        if segments is None:
            return None

        return [int(segment[0].item()) for segment in segments]

    def _get_actual_frame_token_counts(
        self, mm_position: PlaceholderRange, expected_frames: int
    ) -> list[int] | None:
        """Return the actual token count for each EVS-retained frame.

        This function calculates the actual number of tokens per frame by
        analyzing the is_embed mask, accounting for EVS pruning. Each frame
        may have a different token count due to content-aware pruning.

        Args:
            mm_position: Multimodal position containing the is_embed mask
            expected_frames: Expected number of frames

        Returns:
            List of token counts for each frame, or None if EVS is not enabled.
        """
        segments = self._get_evs_mask_segments(mm_position, expected_frames)
        if segments is None:
            return None

        return [len(seg) for seg in segments]
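
Given those segments, the two wrappers above reduce to taking each segment's first index (the frame's start, relative to the placeholder) and its length (the number of retained tokens in that frame). Continuing the toy mask from the previous sketch:

import torch

segments = [torch.tensor([1, 2, 3]), torch.tensor([6, 7])]

frame_offsets = [int(seg[0].item()) for seg in segments]   # [1, 6]
frame_token_counts = [len(seg) for seg in segments]        # [3, 2]

# iter_mm_grid_hw adds mm_position.offset to each relative offset, while
# get_mrope_input_positions consumes the per-frame token counts.
assert frame_offsets == [1, 6]
assert frame_token_counts == [3, 2]
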
    def recompute_mrope_positions(
        self,
        input_ids: list[int],
        multimodal_embeddings: tuple[torch.Tensor, ...],
        mrope_positions: torch.LongTensor,
        num_computed_tokens: int,
    ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor, int]:
        """
        Update part of the input mrope positions (starting at index
        num_computed_tokens). The original mrope_positions are computed
        for the unpruned sequence and become incorrect once pruning occurs,
        so after pruning media tokens we must reflect this in
        mrope_positions before feeding them to the LLM.

        Args:
            input_ids: (N,) All input tokens of the prompt (containing
                the entire sequence).
            multimodal_embeddings: Tuple of multimodal embeddings.
            mrope_positions: Existing mrope positions (3, N) for the entire
                sequence.
            num_computed_tokens: Number of tokens computed so far.

        Returns:
            Tuple of (multimodal_embeddings, mrope_positions,
            mrope_position_delta).
        """
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id

        # Device
        device = (
            multimodal_embeddings[0].device
            if len(multimodal_embeddings)
            else mrope_positions.device
        )

        # Tensors
        input_ids_t = torch.as_tensor(input_ids, device=device, dtype=torch.long)

        # The last 4 channels of each embedding hold the precomputed mrope
        # positions; split them off from the actual embedding features.
        mm_embeddings_out = [mm[:, :-4] for mm in multimodal_embeddings]
        mm_embeddings_pos = [
            mm[:, -4:].permute(1, 0).long() for mm in multimodal_embeddings
        ]

        positions, mrope_positions_delta = recompute_mrope_positions(
            input_ids_t,
            mm_embeddings_pos,
            mrope_positions,
            num_computed_tokens,
            vision_start_token_id,
            image_token_id,
            video_token_id,
        )

        return tuple(mm_embeddings_out), positions, mrope_positions_delta
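
Why the recomputation is needed: the original mrope positions are laid out for the unpruned placeholder sequence, so every position after a pruned span is too large once tokens are dropped. A deliberately simplified one-dimensional illustration of that shift (the real recompute_mrope_positions helper works on (3, N) positions and media spans and is taken as given):

import torch

unpruned_positions = torch.arange(10)           # one slot per placeholder token
kept = torch.tensor([0, 1, 2, 3, 6, 7, 8, 9])   # EVS dropped slots 4 and 5

naive = unpruned_positions[kept]                # [0, 1, 2, 3, 6, 7, 8, 9] -> gap
recomputed = torch.arange(len(kept))            # [0, 1, 2, 3, 4, 5, 6, 7]

# Tokens after the pruned span must be renumbered contiguously, which is
# what the recomputation restores for the trailing positions.
assert not torch.equal(naive, recomputed)
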
    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        # Pre-collect actual frame token counts for EVS mode
        frame_token_counts_map = {}
        for mm_feature in mm_features:
            if mm_feature.modality == "video":
                is_evs_enabled = (
                    hasattr(self, "video_pruning_rate")
                    and self.video_pruning_rate is not None
                    and self.video_pruning_rate > 0.0
                )
                if is_evs_enabled:
                    t = mm_feature.data["video_grid_thw"].data.tolist()[0]
                    token_counts = self._get_actual_frame_token_counts(
                        mm_feature.mm_position, t
                    )
                    assert token_counts is not None, (
                        "EVS enabled but failed to extract frame token counts "
                        "from is_embed mask"
                    )
                    frame_token_counts_map[mm_feature.mm_position.offset] = token_counts

        llm_pos_ids_list = []
        st = 0
        frame_counts_idx = {}

        for offset, llm_grid_h, llm_grid_w in self.iter_mm_grid_hw(
            input_tokens, mm_features
        ):
            text_len = offset - st
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0

            # Determine the actual token count for this frame
            base_offset = None
            for feat_offset in frame_token_counts_map:
                if offset >= feat_offset:
                    base_offset = feat_offset

            if base_offset is not None:
                # EVS mode: use the actual token count from the is_embed mask
                assert base_offset in frame_token_counts_map, (
                    f"Found base_offset {base_offset} but not in frame_token_counts_map"
                )

                if base_offset not in frame_counts_idx:
                    frame_counts_idx[base_offset] = 0

                counts = frame_token_counts_map[base_offset]
                idx = frame_counts_idx[base_offset]

                assert idx < len(counts), (
                    f"EVS frame index {idx} out of range (total frames: {len(counts)})"
                )

                actual_frame_tokens = counts[idx]
                frame_counts_idx[base_offset] += 1
            else:
                # Non-EVS mode (or image): use the theoretical grid size
                actual_frame_tokens = llm_grid_h * llm_grid_w

            # Add text segment
            text_positions = (
                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
            )
            llm_pos_ids_list.append(text_positions)
            st_idx += text_len

            # Add frame segment with the actual token count (not the theoretical one)
            grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1)
            # Only take the first actual_frame_tokens positions
            frame_positions = grid_indices[:, :actual_frame_tokens] + st_idx
            llm_pos_ids_list.append(frame_positions)

            # Update st using the actual token count
            st = offset + actual_frame_tokens

        # Handle the final text segment
        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            final_text_positions = (
                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
            )
            llm_pos_ids_list.append(final_text_positions)

        llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()

        return torch.from_numpy(llm_positions), mrope_position_delta
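
The position construction above alternates text blocks (three identical rows counting up from the running start index) with frame blocks (t/h/w grid indices offset by that start index); under EVS the frame block is simply truncated to the retained token count. A small numpy example of one 3-token text block followed by a 2x2 frame from which EVS kept 3 tokens:

import numpy as np

st_idx, text_len = 0, 3
text_positions = np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
# [[0 1 2]
#  [0 1 2]
#  [0 1 2]]

st_idx += text_len
llm_grid_h = llm_grid_w = 2
grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1)
# rows are the (t, h, w) indices of the 4 grid positions:
# [[0 0 0 0]
#  [0 0 1 1]
#  [0 1 0 1]]

actual_frame_tokens = 3      # EVS retained 3 of the 4 tokens in this frame
frame_positions = grid_indices[:, :actual_frame_tokens] + st_idx
# [[3 3 3]
#  [3 3 4]
#  [3 4 3]]

positions = np.concatenate([text_positions, frame_positions], axis=1)
assert positions.shape == (3, 6)
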

    def get_language_model(self) -> torch.nn.Module:
@@ -1508,9 +1912,17 @@ class Qwen3VLForConditionalGeneration(
            multimodal_input = mm_input_by_modality[modality]
            if modality == "image":
                image_embeddings = self._process_image_input(multimodal_input)
                if self.is_multimodal_pruning_enabled:
                    image_embeddings = self._postprocess_image_embeds_evs(
                        image_embeddings, multimodal_input
                    )
                multimodal_embeddings += tuple(image_embeddings)
            if modality == "video":
                video_embeddings = self._process_video_input(multimodal_input)
                if self.is_multimodal_pruning_enabled:
                    video_embeddings = self._postprocess_video_embeds_evs(
                        video_embeddings, multimodal_input
                    )
                multimodal_embeddings += tuple(video_embeddings)
        return multimodal_embeddings