From cbd5e07a513759892554a7af5a49432bc8dd3386 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 27 Oct 2025 13:38:05 +0800 Subject: [PATCH] [Model] Use merge_by_field_config for MM models (Qwen series) (#27546) Signed-off-by: DarkLight1337 --- .../models/qwen2_5_omni_thinker.py | 95 +++---------------- vllm/model_executor/models/qwen2_5_vl.py | 49 +--------- vllm/model_executor/models/qwen2_audio.py | 25 +---- vllm/model_executor/models/qwen2_vl.py | 48 +--------- .../models/qwen3_omni_moe_thinker.py | 42 ++------ vllm/model_executor/models/qwen3_vl.py | 64 +------------ vllm/model_executor/models/qwen_vl.py | 18 +--- 7 files changed, 36 insertions(+), 305 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index a5d6004faf381..6338ea93b8c8a 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -126,12 +126,12 @@ class Qwen2_5OmniAudioFeatureInputs(TensorSchema): type: Literal["audio_features"] input_features: Annotated[ torch.Tensor | list[torch.Tensor], - TensorShape("nmb", "tsl"), + TensorShape("nmb", "tsl", dynamic_dims={"tsl"}), ] feature_attention_mask: Annotated[ - torch.Tensor, - TensorShape("na", "msl"), + torch.Tensor | list[torch.Tensor], + TensorShape("na", "msl", dynamic_dims={"msl"}), ] @@ -651,18 +651,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor( class Qwen2_5OmniConditionalGenerationMixin: - def _validate_and_reshape_mm_tensor( - self, mm_input: object, name: str, dim: int = 0 - ) -> torch.Tensor: - if not isinstance(mm_input, (torch.Tensor, list)): - raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}") - if isinstance(mm_input, torch.Tensor): - if dim == 0: - return mm_input.reshape(-1, *mm_input.shape[2:]) - return torch.concat(list(mm_input), dim=dim) - else: - return torch.concat(mm_input, dim=dim) - def _parse_and_validate_audio_input( self, **kwargs: object ) -> Qwen2_5OmniAudioFeatureInputs | None: @@ -671,18 +659,7 @@ class Qwen2_5OmniConditionalGenerationMixin: feature_attention_mask = kwargs.pop("feature_attention_mask", None) if input_audio_features is None: return None - input_audio_features = self._validate_and_reshape_mm_tensor( - input_audio_features, "input_audio_features", dim=1 - ) - if feature_attention_mask is not None: - feature_attention_mask = self._validate_and_reshape_mm_tensor( - feature_attention_mask, "feature_attention_mask" - ) - if not isinstance(input_audio_features, (torch.Tensor, list)): - raise ValueError( - "Incorrect type of audio input features. " - f"Got type: {type(input_audio_features)}" - ) + return Qwen2_5OmniAudioFeatureInputs( type="audio_features", input_features=input_audio_features, @@ -702,19 +679,6 @@ class Qwen2_5OmniConditionalGenerationMixin: return None if pixel_values is not None: - pixel_values = self._validate_and_reshape_mm_tensor( - pixel_values, "image pixel values" - ) - image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw" - ) - - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError( - "Incorrect type of image pixel values. 
" - f"Got type: {type(pixel_values)}" - ) - return Qwen2_5_VLImagePixelInputs( type="pixel_values", pixel_values=pixel_values, @@ -722,18 +686,6 @@ class Qwen2_5OmniConditionalGenerationMixin: ) if image_embeds is not None: - image_embeds = self._validate_and_reshape_mm_tensor( - image_embeds, "image embeds" - ) - image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw" - ) - - if not isinstance(image_embeds, torch.Tensor): - raise ValueError( - "Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}" - ) return Qwen2_5_VLImageEmbeddingInputs( type="image_embeds", image_embeds=image_embeds, @@ -752,13 +704,6 @@ class Qwen2_5OmniConditionalGenerationMixin: return None if pixel_values_videos is not None: - pixel_values_videos = self._validate_and_reshape_mm_tensor( - pixel_values_videos, "video pixel values" - ) - video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw" - ) - return Qwen2_5_VLVideoPixelInputs( type="pixel_values_videos", pixel_values_videos=pixel_values_videos, @@ -766,13 +711,6 @@ class Qwen2_5OmniConditionalGenerationMixin: ) if video_embeds is not None: - video_embeds = self._validate_and_reshape_mm_tensor( - video_embeds, "video embeds" - ) - video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw" - ) - if not isinstance(video_embeds, torch.Tensor): raise ValueError( "Incorrect type of video embeddings. " @@ -787,23 +725,18 @@ class Qwen2_5OmniConditionalGenerationMixin: def _process_audio_input( self, audio_input: Qwen2_5OmniAudioFeatureInputs, - audio_hashes: list[str] = None, - cached_audio_features: torch.Tensor = None, + audio_hashes: list[str] | None = None, + cached_audio_features: torch.Tensor | None = None, ) -> torch.Tensor: input_features = audio_input["input_features"] audio_feature_lengths = audio_input["audio_feature_lengths"] - if input_features.ndim == 3: - assert input_features.shape[0] == 1 - input_features = input_features.squeeze(0) - if audio_feature_lengths.ndim == 2: - assert ( - audio_feature_lengths.shape[0] == 1 - or audio_feature_lengths.shape[1] == 1 - ) - if audio_feature_lengths.shape[0] == 1: - audio_feature_lengths = audio_feature_lengths.squeeze(0) - else: - audio_feature_lengths = audio_feature_lengths.squeeze(1) + + if audio_feature_lengths.shape[0] == 1: + audio_feature_lengths = audio_feature_lengths.squeeze(0) + elif audio_feature_lengths.shape[1] == 1: + audio_feature_lengths = audio_feature_lengths.squeeze(1) + else: + raise AssertionError(audio_feature_lengths.shape) audio_feat_lengths, audio_output_lengths = ( self.audio_tower._get_feat_extract_output_lengths(audio_feature_lengths) @@ -867,6 +800,8 @@ class Qwen2_5OmniThinkerForConditionalGeneration( SupportsMRoPE, Qwen2_5OmniConditionalGenerationMixin, ): + merge_by_field_config = True + hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ "thinker.lm_head.": "language_model.lm_head.", diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index a3436201a1db6..b622021e225ca 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -1071,6 +1071,8 @@ class Qwen2_5_VLForConditionalGeneration( SupportsMultiModalPruning, SupportsMRoPE, ): + merge_by_field_config = True + packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], @@ -1273,24 +1275,6 @@ class Qwen2_5_VLForConditionalGeneration( num_layers = 
len(self.language_model.model.layers) return (2, num_layers // 2, num_layers - 3) - def _validate_and_reshape_mm_tensor( - self, mm_input: object, name: str - ) -> torch.Tensor: - if not isinstance(mm_input, (torch.Tensor, list)): - raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}") - if isinstance(mm_input, torch.Tensor): - if mm_input.ndim == 2: - return mm_input - if mm_input.ndim != 3: - raise ValueError( - f"{name} should be 2D or batched 3D tensor. " - f"Got ndim: {mm_input.ndim} " - f"(shape={mm_input.shape})" - ) - return mm_input.reshape(-1, mm_input.shape[-1]) - else: - return torch.concat(mm_input) - def _parse_and_validate_image_input( self, **kwargs: object ) -> Qwen2_5_VLImageInputs | None: @@ -1302,13 +1286,6 @@ class Qwen2_5_VLForConditionalGeneration( return None if pixel_values is not None: - pixel_values = self._validate_and_reshape_mm_tensor( - pixel_values, "image pixel values" - ) - image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw" - ) - return Qwen2_5_VLImagePixelInputs( type="pixel_values", pixel_values=pixel_values, @@ -1316,13 +1293,6 @@ class Qwen2_5_VLForConditionalGeneration( ) if image_embeds is not None: - image_embeds = self._validate_and_reshape_mm_tensor( - image_embeds, "image embeds" - ) - image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw" - ) - return Qwen2_5_VLImageEmbeddingInputs( type="image_embeds", image_embeds=image_embeds, @@ -1341,14 +1311,6 @@ class Qwen2_5_VLForConditionalGeneration( return None if pixel_values_videos is not None: - pixel_values_videos = self._validate_and_reshape_mm_tensor( - pixel_values_videos, "video pixel values" - ) - video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw" - ) - if second_per_grid_ts is not None and second_per_grid_ts.ndim == 2: - second_per_grid_ts = second_per_grid_ts.squeeze(-1) return Qwen2_5_VLVideoPixelInputs( type="pixel_values_videos", pixel_values_videos=pixel_values_videos, @@ -1357,13 +1319,6 @@ class Qwen2_5_VLForConditionalGeneration( ) if video_embeds is not None: - video_embeds = self._validate_and_reshape_mm_tensor( - video_embeds, "video embeds" - ) - video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw" - ) - return Qwen2_5_VLVideoEmbeddingInputs( type="video_embeds", video_embeds=video_embeds, diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 553fdc4a9e179..4de6a19c1ff0c 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -313,6 +313,8 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor[Qwen2AudioProcessing dummy_inputs=Qwen2AudioDummyInputsBuilder, ) class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("audio"): @@ -346,16 +348,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports self.language_model.make_empty_intermediate_tensors ) - def _validate_and_reshape_mm_tensor( - self, mm_input: object, name: str - ) -> torch.Tensor: - if not isinstance(mm_input, (torch.Tensor, list)): - raise ValueError(f"Incorrect type of {name}. 
Got type: {type(mm_input)}") - if isinstance(mm_input, torch.Tensor): - return mm_input.reshape(-1, *mm_input.shape[2:]) - else: - return torch.concat(mm_input) - def _parse_and_validate_audio_input( self, **kwargs: object ) -> Qwen2AudioInputs | None: @@ -367,24 +359,11 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports return None if audio_embeds is not None: - if not isinstance(audio_embeds, (torch.Tensor, list)): - raise ValueError( - f"Incorrect type of audio embeds. Got type: {type(audio_embeds)}" - ) - audio_embeds = self._validate_and_reshape_mm_tensor( - audio_embeds, "audio_embeds" - ) return Qwen2AudioEmbeddingInputs( type="audio_embeds", audio_embeds=audio_embeds ) if input_features is not None: - input_features = self._validate_and_reshape_mm_tensor( - input_features, "input_features" - ) - feature_attention_mask = self._validate_and_reshape_mm_tensor( - feature_attention_mask, "feature_attention_mask" - ) return Qwen2AudioFeatureInputs( type="audio_features", input_features=input_features, diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 47ce3ee744edd..f0d7e2e7d7eca 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -1213,6 +1213,8 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]) class Qwen2VLForConditionalGeneration( nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE ): + merge_by_field_config = True + # To ensure correct weight loading and mapping. hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ @@ -1406,24 +1408,6 @@ class Qwen2VLForConditionalGeneration( self.language_model.make_empty_intermediate_tensors ) - def _validate_and_reshape_mm_tensor( - self, mm_input: object, name: str - ) -> torch.Tensor: - if not isinstance(mm_input, (torch.Tensor, list)): - raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}") - if isinstance(mm_input, torch.Tensor): - if mm_input.ndim == 2: - return mm_input - if mm_input.ndim != 3: - raise ValueError( - f"{name} should be 2D or batched 3D tensor. 
" - f"Got ndim: {mm_input.ndim} " - f"(shape={mm_input.shape})" - ) - return mm_input.reshape(-1, mm_input.shape[-1]) - else: - return torch.concat(mm_input) - def _parse_and_validate_image_input( self, **kwargs: object ) -> Qwen2VLImageInputs | None: @@ -1435,13 +1419,6 @@ class Qwen2VLForConditionalGeneration( return None if pixel_values is not None: - pixel_values = self._validate_and_reshape_mm_tensor( - pixel_values, "image pixel values" - ) - image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw" - ) - return Qwen2VLImagePixelInputs( type="pixel_values", pixel_values=pixel_values, @@ -1449,13 +1426,6 @@ class Qwen2VLForConditionalGeneration( ) if image_embeds is not None: - image_embeds = self._validate_and_reshape_mm_tensor( - image_embeds, "image embeds" - ) - image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw" - ) - return Qwen2VLImageEmbeddingInputs( type="image_embeds", image_embeds=image_embeds, @@ -1473,13 +1443,6 @@ class Qwen2VLForConditionalGeneration( return None if pixel_values_videos is not None: - pixel_values_videos = self._validate_and_reshape_mm_tensor( - pixel_values_videos, "video pixel values" - ) - video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw" - ) - return Qwen2VLVideoPixelInputs( type="pixel_values_videos", pixel_values_videos=pixel_values_videos, @@ -1487,13 +1450,6 @@ class Qwen2VLForConditionalGeneration( ) if video_embeds is not None: - video_embeds = self._validate_and_reshape_mm_tensor( - video_embeds, "video embeds" - ) - video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw" - ) - return Qwen2VLVideoEmbeddingInputs( type="video_embeds", video_embeds=video_embeds, diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index 89ce0068fb1ab..f3b6ad495db42 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -63,10 +63,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.qwen2_audio import ( - Qwen2AudioFeatureInputs, - Qwen2AudioProcessingInfo, -) +from vllm.model_executor.models.qwen2_audio import Qwen2AudioProcessingInfo from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalKwargsItems from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataItems @@ -86,6 +83,7 @@ from .interfaces import ( SupportsPP, ) from .qwen2_5_omni_thinker import ( + Qwen2_5OmniAudioFeatureInputs, Qwen2_5OmniConditionalGenerationMixin, Qwen2_5OmniThinkerDummyInputsBuilder, Qwen2_5OmniThinkerMultiModalProcessor, @@ -101,6 +99,7 @@ from .utils import ( AutoWeightsLoader, WeightsMapper, _merge_multimodal_embeddings, + flatten_bn, maybe_prefix, ) from .vision import ( @@ -1056,41 +1055,16 @@ class Qwen3OmniMoeThinkerMultiModalProcessor( class Qwen3OmniMoeConditionalGenerationMixin(Qwen2_5OmniConditionalGenerationMixin): - def _validate_and_reshape_mm_tensor( - self, mm_input: object, name: str, dim: int = 0 - ) -> torch.Tensor: - if not isinstance(mm_input, (torch.Tensor, list)): - raise ValueError(f"Incorrect type of {name}. 
Got type: {type(mm_input)}") - if name == "feature_attention_mask": - dim = -1 - if isinstance(mm_input, torch.Tensor): - return torch.concat(list(mm_input), dim=dim) - else: - if isinstance(mm_input[0], list): - return torch.concat( - [torch.concat(mm_input[i], dim=dim) for i in range(len(mm_input))], - dim=dim, - ) - else: - return torch.concat(mm_input, dim=dim) - def _process_audio_input( self, - audio_input: Qwen2AudioFeatureInputs, - audio_hashes: list[str] = None, - cached_audio_features: torch.Tensor = None, + audio_input: Qwen2_5OmniAudioFeatureInputs, + audio_hashes: list[str] | None = None, + cached_audio_features: torch.Tensor | None = None, ) -> torch.Tensor: input_features = audio_input["input_features"] audio_feature_lengths = audio_input["audio_feature_lengths"] - if input_features.ndim == 3: - assert input_features.shape[0] == 1 - input_features = input_features.squeeze(0) - - if not isinstance(audio_feature_lengths, torch.Tensor): - audio_feature_lengths = torch.cat(audio_feature_lengths) - if audio_feature_lengths.ndim == 2: - audio_feature_lengths = audio_feature_lengths.reshape(-1) + audio_feature_lengths = flatten_bn(audio_feature_lengths, concat=True) audio_feat_lengths, audio_output_lengths = _get_feat_extract_output_lengths( audio_feature_lengths @@ -1117,6 +1091,8 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( SupportsMRoPE, Qwen3OmniMoeConditionalGenerationMixin, ): + merge_by_field_config = True + hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ "thinker.lm_head.": "language_model.lm_head.", diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 940fa50ff8035..10c0eb4eb65ea 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -1175,6 +1175,8 @@ class Qwen3LLMForCausalLM(Qwen3ForCausalLM): class Qwen3VLForConditionalGeneration( nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE ): + merge_by_field_config = True + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -1298,24 +1300,6 @@ class Qwen3VLForConditionalGeneration( for idx in range(self.deepstack_num_level): self.deepstack_input_embeds[idx][:num_tokens].zero_() - def _validate_and_reshape_mm_tensor( - self, mm_input: object, name: str - ) -> torch.Tensor: - if not isinstance(mm_input, (torch.Tensor, list)): - raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}") - if isinstance(mm_input, torch.Tensor): - if mm_input.ndim == 2: - return mm_input - if mm_input.ndim != 3: - raise ValueError( - f"{name} should be 2D or batched 3D tensor. " - f"Got ndim: {mm_input.ndim} " - f"(shape={mm_input.shape})" - ) - return mm_input.reshape(-1, mm_input.shape[-1]) - else: - return torch.concat(mm_input) - def _parse_and_validate_image_input( self, **kwargs: object ) -> Qwen2_5_VLImageInputs | None: @@ -1327,19 +1311,6 @@ class Qwen3VLForConditionalGeneration( return None if pixel_values is not None: - pixel_values = self._validate_and_reshape_mm_tensor( - pixel_values, "image pixel values" - ) - image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw" - ) - - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError( - "Incorrect type of image pixel values. 
" - f"Got type: {type(pixel_values)}" - ) - return Qwen2_5_VLImagePixelInputs( type="pixel_values", pixel_values=pixel_values, @@ -1347,18 +1318,6 @@ class Qwen3VLForConditionalGeneration( ) if image_embeds is not None: - image_embeds = self._validate_and_reshape_mm_tensor( - image_embeds, "image embeds" - ) - image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw" - ) - - if not isinstance(image_embeds, torch.Tensor): - raise ValueError( - "Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}" - ) return Qwen2_5_VLImageEmbeddingInputs( type="image_embeds", image_embeds=image_embeds, @@ -1377,13 +1336,6 @@ class Qwen3VLForConditionalGeneration( return None if pixel_values_videos is not None: - pixel_values_videos = self._validate_and_reshape_mm_tensor( - pixel_values_videos, "video pixel values" - ) - video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw" - ) - return Qwen2_5_VLVideoPixelInputs( type="pixel_values_videos", pixel_values_videos=pixel_values_videos, @@ -1392,18 +1344,6 @@ class Qwen3VLForConditionalGeneration( ) if video_embeds is not None: - video_embeds = self._validate_and_reshape_mm_tensor( - video_embeds, "video embeds" - ) - video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw" - ) - - if not isinstance(video_embeds, torch.Tensor): - raise ValueError( - "Incorrect type of video embeddings. " - f"Got type: {type(video_embeds)}" - ) return Qwen2_5_VLVideoEmbeddingInputs( type="video_embeds", video_embeds=video_embeds, diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index f011229985c87..cf74f72fe633d 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -58,7 +58,6 @@ from .interfaces import ( SupportsPP, ) from .qwen import QWenBaseModel, QWenModel -from .utils import flatten_bn class QwenImagePixelInputs(TensorSchema): @@ -703,6 +702,8 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): class QwenVLForConditionalGeneration( QWenBaseModel, SupportsPP, SupportsLoRA, SupportsMultiModal ): + merge_by_field_config = True + packed_modules_mapping = { "c_attn": ["c_attn"], "gate_up_proj": [ @@ -750,30 +751,19 @@ class QwenVLForConditionalGeneration( image_embeds = kwargs.pop("image_embeds", None) if pixel_values is not None: - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError( - f"Incorrect type of pixel values. Got type: {type(pixel_values)}" - ) - expected_h = expected_w = self.config.visual["image_size"] resolve_bindings = {"h": expected_h, "w": expected_w} return QwenImagePixelInputs( type="pixel_values", - data=flatten_bn(pixel_values, concat=True), + data=pixel_values, resolve_bindings=resolve_bindings, ) if image_embeds is not None: - if not isinstance(image_embeds, (torch.Tensor, list)): - raise ValueError( - "Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}" - ) - return QwenImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds, concat=True), + data=image_embeds, ) return None