From 424fb7a5d22c013ec0ef6244c62cd75ed076375a Mon Sep 17 00:00:00 2001
From: bppps <44322223+bppps@users.noreply.github.com>
Date: Sat, 23 Aug 2025 00:56:46 +0800
Subject: [PATCH] [BugFix] Fix the issue where image embeddings were
 incorrectly split.… (#23366)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: bppps
Co-authored-by: zouyu.zzx
Co-authored-by: bppps
---
 vllm/model_executor/models/glm4_1v.py         |  7 +-
 .../models/qwen2_5_omni_thinker.py            | 80 ++++++++++++-------
 vllm/model_executor/models/qwen2_vl.py        | 64 ++++++++++-----
 3 files changed, 99 insertions(+), 52 deletions(-)

diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py
index 08252c51310be..662728e6b1393 100644
--- a/vllm/model_executor/models/glm4_1v.py
+++ b/vllm/model_executor/models/glm4_1v.py
@@ -74,7 +74,8 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from ..layers.activation import SiluAndMul
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
-from .qwen2_vl import _qwen2vl_field_config, apply_rotary_pos_emb_vision
+from .qwen2_vl import (_create_qwen2vl_field_factory,
+                       apply_rotary_pos_emb_vision)
 from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings)
@@ -1153,7 +1154,9 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]):
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        return _qwen2vl_field_config(hf_inputs)
+        return _create_qwen2vl_field_factory(
+            self.info.get_hf_config().vision_config.spatial_merge_size)(
+                hf_inputs)
 
     def _get_prompt_updates(
         self,
diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py
index 5aadebc33324c..664e3f2985a59 100644
--- a/vllm/model_executor/models/qwen2_5_omni_thinker.py
+++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py
@@ -25,7 +25,7 @@
 from collections.abc import Iterable, Mapping, Sequence
 from copy import copy
 from functools import partial
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -79,40 +79,57 @@ except (ImportError, ModuleNotFoundError):
 logger = init_logger(__name__)
 
 
-def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]):
-    audio_feature_lengths = hf_inputs.get("audio_feature_lengths",
-                                          torch.empty((0, )))
+def create_qwen2_5_omni_thinker_field_factory(
+    spatial_merge_size: int
+) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str,
+                                                    MultiModalFieldConfig]]:
 
-    image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
-    image_grid_sizes = image_grid_thw.prod(-1)
+    def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str,
+                                                              torch.Tensor]):
+        audio_feature_lengths = hf_inputs.get("audio_feature_lengths",
+                                              torch.empty((0, )))
 
-    video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
-    video_grid_sizes = video_grid_thw.prod(-1)
+        image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
+        image_pixel_grid_sizes = image_grid_thw.prod(-1)
+        image_embed_grid_sizes = (image_pixel_grid_sizes //
+                                  spatial_merge_size // spatial_merge_size)
 
-    num_videos = len(video_grid_sizes)
+        video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
+        video_grid_sizes = video_grid_thw.prod(-1)
+        video_embed_grid_sizes = (video_grid_sizes // spatial_merge_size //
+                                  spatial_merge_size)
 
-    return dict(
-        input_audio_features=MultiModalFieldConfig.flat_from_sizes(
-            "audio", audio_feature_lengths, dim=1),
-        feature_attention_mask=MultiModalFieldConfig.batched("audio"),
-        audio_feature_lengths=MultiModalFieldConfig.batched("audio"),
-        pixel_values=MultiModalFieldConfig.flat_from_sizes(
-            "image", image_grid_sizes),
-        image_embeds=MultiModalFieldConfig.flat_from_sizes(
-            "image", image_grid_sizes),
-        image_grid_thw=MultiModalFieldConfig.batched("image"),
-        pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
-            "video", video_grid_sizes),
-        video_embeds=MultiModalFieldConfig.flat_from_sizes(
-            "video", video_grid_sizes),
-        video_grid_thw=MultiModalFieldConfig.batched("video"),
-        second_per_grid_ts=MultiModalFieldConfig.batched("video"),
-        use_audio_in_video=MultiModalFieldConfig.shared("video", num_videos),
-    )
+        num_videos = len(video_grid_sizes)
+
+        return dict(
+            input_audio_features=MultiModalFieldConfig.flat_from_sizes(
+                "audio", audio_feature_lengths, dim=1),
+            feature_attention_mask=MultiModalFieldConfig.batched("audio"),
+            audio_feature_lengths=MultiModalFieldConfig.batched("audio"),
+            pixel_values=MultiModalFieldConfig.flat_from_sizes(
+                "image", image_pixel_grid_sizes),
+            image_embeds=MultiModalFieldConfig.flat_from_sizes(
+                "image", image_embed_grid_sizes),
+            image_grid_thw=MultiModalFieldConfig.batched("image"),
+            pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
+                "video", video_grid_sizes),
+            video_embeds=MultiModalFieldConfig.flat_from_sizes(
+                "video", video_embed_grid_sizes),
+            video_grid_thw=MultiModalFieldConfig.batched("video"),
+            second_per_grid_ts=MultiModalFieldConfig.batched("video"),
+            use_audio_in_video=MultiModalFieldConfig.shared(
+                "video", num_videos),
+        )
+
+    return _qwen2_5_omni_thinker_field_config
 
 
 class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser):
 
+    def __init__(self, spatial_merge_size: int, *args, **kwargs):
+        self._spatial_merge_size = spatial_merge_size
+        super().__init__(self._spatial_merge_size, *args, **kwargs)
+
     def _parse_audio_data(
         self,
         data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
@@ -124,7 +141,8 @@ class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser):
                 required_fields={
                     "input_audio_features", "audio_feature_lengths"
                 },
-                fields_factory=_qwen2_5_omni_thinker_field_config,
+                fields_factory=create_qwen2_5_omni_thinker_field_factory(
+                    self._spatial_merge_size),
             )
 
         return super()._parse_audio_data(data)
@@ -214,6 +232,8 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
     def _get_data_parser(self) -> MultiModalDataParser:
         feature_extractor = self.info.get_feature_extractor()
         return Qwen2_5OmniThinkerMultiModalDataParser(
+            spatial_merge_size=self.info.get_hf_config(
+            ).vision_config.spatial_merge_size,
             target_sr=feature_extractor.sampling_rate)
 
     def _call_hf_processor(
@@ -265,7 +285,9 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        return _qwen2_5_omni_thinker_field_config(hf_inputs)
+        return create_qwen2_5_omni_thinker_field_factory(
+            self.info.get_hf_config().vision_config.spatial_merge_size)(
+                hf_inputs)
 
     def _maybe_apply_prompt_updates(
         self,
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 2315fe2ab92b5..ae7a8d8d7a5b9 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -699,29 +699,46 @@ class Qwen2VisionTransformer(nn.Module):
         return loaded_params
 
 
-def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
-    image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
-    image_grid_sizes = image_grid_thw.prod(-1)
+def _create_qwen2vl_field_factory(
+    spatial_merge_size: int
+) -> Callable[
+    [Mapping[str, torch.Tensor]],
+    Mapping[str, MultiModalFieldConfig],
+]:
 
-    video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
-    video_grid_sizes = video_grid_thw.prod(-1)
+    def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
+        image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
+        image_pixel_grid_sizes = image_grid_thw.prod(-1)
+        image_embed_grid_sizes = (image_pixel_grid_sizes //
+                                  spatial_merge_size // spatial_merge_size)
 
-    return dict(
-        pixel_values=MultiModalFieldConfig.flat_from_sizes(
-            "image", image_grid_sizes),
-        image_embeds=MultiModalFieldConfig.flat_from_sizes(
-            "image", image_grid_sizes),
-        image_grid_thw=MultiModalFieldConfig.batched("image"),
-        pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
-            "video", video_grid_sizes),
-        video_embeds=MultiModalFieldConfig.flat_from_sizes(
-            "video", video_grid_sizes),
-        video_grid_thw=MultiModalFieldConfig.batched("video"),
-    )
+        video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
+        video_grid_sizes = video_grid_thw.prod(-1)
+        video_embed_grid_sizes = (video_grid_sizes // spatial_merge_size //
+                                  spatial_merge_size)
+
+        return dict(
+            pixel_values=MultiModalFieldConfig.flat_from_sizes(
+                "image", image_pixel_grid_sizes),
+            image_embeds=MultiModalFieldConfig.flat_from_sizes(
+                "image", image_embed_grid_sizes),
+            image_grid_thw=MultiModalFieldConfig.batched("image"),
+            pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
+                "video", video_grid_sizes),
+            video_embeds=MultiModalFieldConfig.flat_from_sizes(
+                "video", video_embed_grid_sizes),
+            video_grid_thw=MultiModalFieldConfig.batched("video"),
+        )
+
+    return _qwen2vl_field_config
 
 
 class Qwen2VLMultiModalDataParser(MultiModalDataParser):
 
+    def __init__(self, spatial_merge_size: int, *args, **kwargs):
+        self._spatial_merge_size = spatial_merge_size
+        super().__init__(*args, **kwargs)
+
     def _parse_image_data(
         self,
         data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
@@ -731,7 +748,8 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
                 data,
                 modality="image",
                 required_fields={"image_embeds", "image_grid_thw"},
-                fields_factory=_qwen2vl_field_config,
+                fields_factory=_create_qwen2vl_field_factory(
+                    self._spatial_merge_size),
             )
 
         return super()._parse_image_data(data)
@@ -745,7 +763,8 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
                 data,
                 modality="video",
                 required_fields={"video_embeds", "video_grid_thw"},
-                fields_factory=_qwen2vl_field_config,
+                fields_factory=_create_qwen2vl_field_factory(
+                    self._spatial_merge_size),
             )
 
         return super()._parse_video_data(data)
@@ -967,7 +986,8 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
                                  ):
 
     def _get_data_parser(self) -> MultiModalDataParser:
-        return Qwen2VLMultiModalDataParser()
+        return Qwen2VLMultiModalDataParser(
+            self.info.get_hf_config().vision_config.spatial_merge_size)
 
     def _get_prompt_updates(
         self,
@@ -1010,7 +1030,9 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        return _qwen2vl_field_config(hf_inputs)
+        return _create_qwen2vl_field_factory(
+            self.info.get_hf_config().vision_config.spatial_merge_size)(
+                hf_inputs)
 
 
 @MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor,
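
The following sketch is illustrative only and is not part of the diff above; the grid values and spatial_merge_size=2 are assumed. It shows the size bookkeeping the change corrects: raw pixel_values are split per item by t*h*w patches, while precomputed image_embeds/video_embeds are already merged by the vision tower and therefore must be split by t*h*w // spatial_merge_size**2.

# Illustrative sketch (assumed example values, not taken from the patch):
# how per-item split sizes differ for pixel patches vs. merged embeddings.
import torch

spatial_merge_size = 2  # assumed; the patch reads this from vision_config

# Two images, each described by a (t, h, w) patch grid.
image_grid_thw = torch.tensor([[1, 28, 28], [1, 14, 14]])

# Rows of pixel_values contributed by each image: t * h * w.
image_pixel_grid_sizes = image_grid_thw.prod(-1)

# Rows of image_embeds contributed by each image: reduced by merge_size**2.
image_embed_grid_sizes = (image_pixel_grid_sizes //
                          spatial_merge_size // spatial_merge_size)

print(image_pixel_grid_sizes.tolist())  # [784, 196]
print(image_embed_grid_sizes.tolist())  # [196, 49]

Splitting image_embeds with the pixel sizes ([784, 196]) would over-consume rows from the first item and leave none for the second, which is the incorrect split this patch fixes.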