[VLM] Update Qwen3-VL max_num_video_tokens calculation for configurable video profiling (#25557)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Authored by Isotr0py on 2025-09-28 12:21:01 +08:00; committed by yewentao256
parent 495f368238
commit 6dee906d2c
2 changed files with 74 additions and 9 deletions

vllm/model_executor/models/qwen2_vl.py

@@ -79,7 +79,7 @@ from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
 
 logger = init_logger(__name__)
 
 # For profile run
-_MAX_FRAMES_PER_VIDEO = 32
+_MAX_FRAMES_PER_VIDEO = 14
 
 # === Vision Inputs === #
@@ -932,6 +932,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
         _, num_image_tokens = self._get_vision_info(
             image_width=image_width,
             image_height=image_height,
+            num_frames=1,
             image_processor=image_processor,
         )
         return num_image_tokens
@@ -956,6 +957,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
         max_image_size, _ = self._get_vision_info(
             image_width=9999999,
             image_height=9999999,
+            num_frames=1,
             image_processor=None,
         )
         return max_image_size
@@ -969,10 +971,12 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
             image_processor=None,
         )
 
-    def _get_max_video_frames(self, max_tokens: int) -> int:
+    def _get_max_video_frames(self,
+                              max_tokens: int,
+                              start_num_frames: int = 1) -> int:
         target_width, target_height = self.get_image_size_with_most_features()
 
-        num_frames = 0
+        num_frames = start_num_frames
 
         while True:
             next_num_frames = num_frames + 1
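The loop this hunk parameterizes grows the frame count one frame at a time until the next frame would exceed the token budget; `start_num_frames` lets a subclass begin the search above 0 (Qwen3-VL starts at 2, matching its temporal patch size). A minimal standalone sketch of that search, with a flat tokens-per-frame cost model assumed purely for illustration in place of get_num_video_tokens:

def max_video_frames(max_tokens: int, tokens_per_frame: int,
                     start_num_frames: int = 1) -> int:
    """Largest frame count whose token cost stays within max_tokens."""
    num_frames = start_num_frames
    while True:
        next_num_frames = num_frames + 1
        if next_num_frames * tokens_per_frame > max_tokens:
            break
        num_frames = next_num_frames
    return num_frames

# e.g. at 256 tokens per frame, a 4096-token budget fits 16 frames:
assert max_video_frames(4096, 256) == 16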
@@ -994,12 +998,13 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
+        max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO,
     ) -> int:
         max_videos = mm_counts.get("video", 0)
 
         max_total_frames = self._get_max_video_frames(seq_len)
         max_frames_per_video = min(max_total_frames // max(max_videos, 1),
-                                   _MAX_FRAMES_PER_VIDEO)
+                                   max_frames_per_video)
 
         return max(max_frames_per_video, 1)
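With `max_frames_per_video` threaded through as a keyword argument, a subclass can now raise or lower the per-video cap without shadowing the module-level `_MAX_FRAMES_PER_VIDEO` constant. A minimal sketch of the capping arithmetic this method performs (standalone function and defaults are hypothetical):

def num_frames_with_most_features(max_total_frames: int, max_videos: int,
                                  max_frames_per_video: int = 14) -> int:
    """Split the total frame budget across videos, capped per video."""
    capped = min(max_total_frames // max(max_videos, 1), max_frames_per_video)
    return max(capped, 1)

# Two videos sharing a 64-frame budget are capped at 14 frames each;
# passing a larger cap (as Qwen3-VL now does) lifts that ceiling to 32.
assert num_frames_with_most_features(64, 2) == 14
assert num_frames_with_most_features(64, 2, max_frames_per_video=24576) == 32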

vllm/model_executor/models/qwen3_vl.py

@@ -33,11 +33,14 @@ import torch.nn as nn
 import torch.nn.functional as F
 from transformers import BatchFeature
 from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast
-from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
+from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
+    smart_resize as image_smart_resize)
 from transformers.models.qwen3_vl import (Qwen3VLProcessor,
                                           Qwen3VLVideoProcessor)
 from transformers.models.qwen3_vl.configuration_qwen3_vl import (
     Qwen3VLConfig, Qwen3VLVisionConfig)
+from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
+    smart_resize as video_smart_resize)
 from transformers.video_utils import VideoMetadata
 
 from vllm.attention.layer import check_upstream_fa_availability
@@ -85,6 +88,9 @@ from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
 
 logger = init_logger(__name__)
 
+# Official recommended max pixels is 24576 * 32 * 32
+_MAX_FRAMES_PER_VIDEO = 24576
+
 
 class Qwen3_VisionPatchEmbed(nn.Module):
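The new constant states the recommended budget in merged vision tokens: with a 16-pixel patch and a 2x2 spatial merge (values assumed here from the Qwen3-VL vision config), one merged token covers a 32 x 32 pixel area, so 24576 tokens reproduces the quoted max-pixels figure. A back-of-envelope check:

patch_size = 16                                    # vision patch edge, px (assumed)
merge_size = 2                                     # 2x2 patches -> 1 merged token
pixels_per_token = (patch_size * merge_size) ** 2  # 32 * 32 = 1024
print(24576 * pixels_per_token)                    # 25165824 = 24576 * 32 * 32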
@@ -593,11 +599,16 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
         image_height: int,
         num_frames: int = 2,
         do_resize: bool = True,
-        image_processor: Optional[Qwen2VLImageProcessorFast],
+        image_processor: Optional[Union[Qwen2VLImageProcessorFast,
+                                        Qwen3VLVideoProcessor]],
     ) -> tuple[ImageSize, int]:
-        if image_processor is None:
+        if image_processor is None and num_frames > 1:
+            image_processor = self.get_video_processor()
+        elif image_processor is None:
             image_processor = self.get_image_processor()
+        is_video = isinstance(image_processor, Qwen3VLVideoProcessor)
 
         hf_config = self.get_hf_config()
         vision_config = hf_config.vision_config
         patch_size = vision_config.patch_size
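The selection rule added here: when no processor is passed in, more than one frame (and `num_frames` defaults to 2 in this method) routes to the video processor, which is why the qwen2_vl.py hunks above now pass `num_frames=1` explicitly on the image paths. A minimal sketch of the rule in isolation, with hypothetical getter callables:

from typing import Callable, Optional

def pick_processor(processor: Optional[object], num_frames: int,
                   get_image: Callable[[], object],
                   get_video: Callable[[], object]) -> object:
    """Default to the video processor only for multi-frame inputs."""
    if processor is None and num_frames > 1:
        return get_video()
    elif processor is None:
        return get_image()
    return processor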
@@ -605,12 +616,22 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
         temporal_patch_size = vision_config.temporal_patch_size
 
         if do_resize:
+            if is_video:
+                smart_resize = video_smart_resize
+                extra_kwargs = {
+                    "num_frames": num_frames,
+                    "temporal_factor": temporal_patch_size
+                }
+            else:
+                smart_resize = image_smart_resize
+                extra_kwargs = {}
             resized_height, resized_width = smart_resize(
                 height=image_height,
                 width=image_width,
                 factor=patch_size * merge_size,
                 min_pixels=image_processor.size["shortest_edge"],
                 max_pixels=image_processor.size["longest_edge"],
+                **extra_kwargs,
             )
             preprocessed_size = ImageSize(width=resized_width,
                                           height=resized_height)
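Both `smart_resize` variants are real transformers helpers; the video one additionally takes `num_frames` and `temporal_factor` so the resize respects temporal patching as well as the spatial factor. A hedged usage sketch with illustrative pixel bounds (the real bounds come from `image_processor.size`, not these literals):

from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
    smart_resize as image_smart_resize)
from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
    smart_resize as video_smart_resize)

# Image path: snap 1280x720 to multiples of patch_size * merge_size = 32.
h, w = image_smart_resize(height=720, width=1280, factor=32,
                          min_pixels=256 * 1024, max_pixels=16384 * 1024)

# Video path: same spatial snapping, plus the frame count and temporal
# factor so per-frame pixel budgeting accounts for temporal patches.
h, w = video_smart_resize(height=720, width=1280, factor=32,
                          num_frames=16, temporal_factor=2,
                          min_pixels=128 * 1024, max_pixels=768 * 1024)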
@@ -629,6 +650,39 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
 
         return preprocessed_size, num_vision_tokens
 
+    def _get_max_video_frames(self,
+                              max_tokens: int,
+                              start_num_frames: int = 2) -> int:
+        return super()._get_max_video_frames(max_tokens,
+                                             start_num_frames=start_num_frames)
+
+    def get_num_frames_with_most_features(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> int:
+        return super().get_num_frames_with_most_features(
+            seq_len, mm_counts, max_frames_per_video=_MAX_FRAMES_PER_VIDEO)
+
+    def get_max_video_tokens(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> int:
+        target_width, target_height = self.get_image_size_with_most_features()
+        video_soft_tokens = self.get_num_video_tokens(
+            image_width=target_width,
+            image_height=target_height,
+            num_frames=self.get_num_frames_with_most_features(
+                seq_len, mm_counts),
+            image_processor=None,
+        )
+
+        # NOTE: By default in Qwen3-VL, one video token is converted to
+        # "<{timestamp} seconds>" (on average 9.5 tokens) + vision_start_token + video_token + vision_end_token  # noqa: E501
+        formatted_video_soft_tokens = video_soft_tokens * 12.5
+        return int(formatted_video_soft_tokens)
+
     def _calculate_timestamps(self, indices: list[int] | torch.Tensor,
                               video_fps: float, merge_size: int):
         if not isinstance(indices, list):
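The 12.5 factor in the NOTE above is the per-soft-token expansion after prompt formatting: roughly 9.5 tokens of timestamp text plus the three wrapper tokens. The arithmetic, restated with an illustrative token count:

timestamp_text_tokens = 9.5  # avg length of "<{timestamp} seconds>" (per the NOTE)
wrapper_tokens = 3           # vision_start_token + video_token + vision_end_token
expansion = timestamp_text_tokens + wrapper_tokens  # 12.5

video_soft_tokens = 1024                   # illustrative soft-token count
print(int(video_soft_tokens * expansion))  # 12800 prompt tokens budgeted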
@@ -698,6 +752,12 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
             self.info.get_image_size_with_most_features())
         target_num_frames = self.info.get_num_frames_with_most_features(
             seq_len, mm_counts)
+        target_video_size, _ = self.info._get_vision_info(
+            image_width=target_width,
+            image_height=target_height,
+            num_frames=target_num_frames,
+            image_processor=self.info.get_video_processor(),
+        )
 
         return {
             "image":
             self._get_dummy_images(width=target_width,
@@ -705,8 +765,8 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
                                    num_images=num_images),
             "video":
             self._get_dummy_videos(
-                width=target_width,
-                height=target_height,
+                width=target_video_size.width,
+                height=target_video_size.height,
                 num_frames=target_num_frames,
                 num_videos=num_videos,
             ),
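Net effect on profiling: the dummy video is now generated at the size the video processor would actually produce, so the token count measured during the profile run matches what inference will see. A hedged sketch stringing together the calls from this diff (`info` stands for a Qwen3VLProcessingInfo instance supplied by the caller):

def dummy_video_shape(info, seq_len: int, mm_counts: dict) -> tuple[int, int, int]:
    """Return (width, height, num_frames) for the profiling dummy video."""
    target_w, target_h = info.get_image_size_with_most_features()
    num_frames = info.get_num_frames_with_most_features(seq_len, mm_counts)
    # Resize through the *video* processor, matching real preprocessing.
    video_size, _ = info._get_vision_info(
        image_width=target_w,
        image_height=target_h,
        num_frames=num_frames,
        image_processor=info.get_video_processor(),
    )
    return video_size.width, video_size.height, num_frames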