From 39b643dc1ac68bb595a9e0a3ea0834e9adc00e92 Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Fri, 3 Oct 2025 13:38:29 +0800
Subject: [PATCH] [Model] Use `merge_by_field_config` for MM models (G) (#26117)

Signed-off-by: DarkLight1337
---
 vllm/model_executor/models/gemma3_mm.py      | 35 ++++------
 vllm/model_executor/models/gemma3n_mm.py     | 69 +++++++++-----
 vllm/model_executor/models/glm4_1v.py        | 38 +----
 vllm/model_executor/models/glm4v.py          | 15 ++---
 vllm/model_executor/models/granite_speech.py |  7 +-
 5 files changed, 56 insertions(+), 108 deletions(-)

diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
index 36f8651371ba..b6aa78ac53e0 100644
--- a/vllm/model_executor/models/gemma3_mm.py
+++ b/vllm/model_executor/models/gemma3_mm.py
@@ -36,7 +36,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix)
 
 logger = init_logger(__name__)
@@ -289,7 +289,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
                                         processor=hf_processor)
                 for size in image_sizes
             ]
-            processed_outputs["num_crops"] = torch.tensor(num_crops)
+            processed_outputs["num_patches"] = torch.tensor(num_crops) + 1
 
         return processed_outputs
 
@@ -298,12 +298,12 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        num_crops = hf_inputs.get("num_crops", torch.empty(0))
+        num_patches = hf_inputs.get("num_patches", torch.empty(0))
 
         return dict(
             pixel_values=MultiModalFieldConfig.flat_from_sizes(
-                "image", num_crops + 1),
-            num_crops=MultiModalFieldConfig.batched("image"),
+                "image", num_patches),
+            num_patches=MultiModalFieldConfig.batched("image"),
         )
 
     def _get_prompt_updates(
@@ -460,6 +460,8 @@ class Gemma3MultiModalProjector(nn.Module):
                                         dummy_inputs=Gemma3DummyInputsBuilder)
 class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
                                      SupportsLoRA):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -526,29 +528,20 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
-        num_crops = kwargs.pop("num_crops", None)
+        num_patches = kwargs.pop("num_patches", None)
         image_embeds = kwargs.pop("image_embeds", None)
         assert image_embeds is None, "Gemma3 does not support image_embeds."
         if pixel_values is None:
             return None
 
-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
-
-        if not isinstance(num_crops, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of num_crops. "
-                             f"Got type: {type(num_crops)}")
-
         image_size = self.config.vision_config.image_size
 
-        return Gemma3ImagePixelInputs(
-            pixel_values=flatten_bn(pixel_values, concat=True),
-            num_patches=flatten_bn(num_crops, concat=True) + 1,
-            resolve_bindings={
-                "h": image_size,
-                "w": image_size
-            })
+        return Gemma3ImagePixelInputs(pixel_values=pixel_values,
+                                      num_patches=num_patches,
+                                      resolve_bindings={
+                                          "h": image_size,
+                                          "w": image_size
+                                      })
 
     def _image_pixels_to_features(
             self,
diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py
index 101e083ac123..83b9d7fa4133 100644
--- a/vllm/model_executor/models/gemma3n_mm.py
+++ b/vllm/model_executor/models/gemma3n_mm.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Literal, Optional, TypedDict, Union, cast
+from typing import Annotated, Any, Literal, Optional, Union, cast
 
 import numpy as np
 import torch
@@ -41,6 +41,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 # yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
                          SupportsTranscription)
@@ -54,17 +55,28 @@ TOKENS_PER_IMAGE = 256
 TOKENS_PER_AUDIO = 188
 
 
-class Gemma3nImagePixelInputs(TypedDict):
-    pixel_values: torch.Tensor
-    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
+class Gemma3nImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - c: Number of channels (3)
+        - h: Height of each patch
+        - w: Width of each patch
+    """
+    type: Literal["pixel_values"] = "pixel_values"
+    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
 
 
-class Gemma3nAudioInputs(TypedDict):
-    input_features: Union[torch.Tensor, list[torch.Tensor]]
-    input_features_padded: torch.Tensor
-    """Shape: `(batch_size * num_audio, seq_length, num_features)`"""
-    input_features_mask: torch.Tensor
-    """Shape: `(batch_size * num_audio, seq_length)`"""
+class Gemma3nAudioInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of audios
+        - s: seq_length
+        - f: num_features
+    """
+    type: Literal["audio"] = "audio"
+    input_features_padded: Annotated[torch.Tensor, TensorShape("bn", "s", "f")]
+    input_features_mask: Annotated[torch.Tensor, TensorShape("bn", "s")]
 
 
 Gemma3nImageInputs = Gemma3nImagePixelInputs
@@ -212,9 +224,9 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
 
         return dict(
             pixel_values=MultiModalFieldConfig.batched("image"),
-            input_features=MultiModalFieldConfig.batched("audio"),
             input_features_padded=MultiModalFieldConfig.batched("audio"),
-            input_features_mask=MultiModalFieldConfig.batched("audio"))
+            input_features_mask=MultiModalFieldConfig.batched("audio"),
+        )
 
     def _get_prompt_updates(
         self,
@@ -422,6 +434,7 @@ class Gemma3nMultimodalEmbedder(nn.Module):
                                         dummy_inputs=Gemma3nDummyInputsBuilder)
 class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
                                       SupportsTranscription):
+    merge_by_field_config = True
     supported_languages = ISO639_1_SUPPORTED_LANGS
 
     packed_modules_mapping = {
@@ -482,14 +495,6 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
             device=self.language_model.model.embed_tokens.weight.device,
             dtype=self.language_model.model.embed_tokens.weight.dtype)
 
-    @property
-    def dtype(self):
-        return next(self.parameters()).dtype
-
-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-        # TODO check if there are any
-        return data
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Gemma3nImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -499,34 +504,22 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
         if pixel_values is None:
             return None
 
-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
-
-        pixel_values = flatten_bn(pixel_values, concat=True)
-        pixel_values = pixel_values.contiguous()
-
-        return Gemma3nImagePixelInputs(
-            pixel_values=self._validate_pixel_values(pixel_values), )
+        return Gemma3nImagePixelInputs(pixel_values=pixel_values)
 
     def _parse_and_validate_audio_input(
             self, **kwargs: object) -> Optional[Gemma3nAudioInputs]:
-        input_features = kwargs.pop("input_features", None)
-        if input_features is None:
+
+        input_features_padded = kwargs.pop("input_features_padded", None)
+        if input_features_padded is None:
             return None
 
         input_features_mask = kwargs.pop("input_features_mask", None)
         if input_features_mask is None:
             return None
 
-        input_features_padded = kwargs.pop("input_features_padded", None)
-        if input_features_padded is None:
-            return None
-
         return Gemma3nAudioInputs(
-            input_features=input_features,
-            input_features_mask=input_features_mask,
             input_features_padded=input_features_padded,
+            input_features_mask=input_features_mask,
         )
 
     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
@@ -539,7 +532,7 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
                              ) and "image" not in mm_input_by_modality:
                 mm_input_by_modality[
                     "image"] = self._parse_and_validate_image_input(**kwargs)
-            if input_key == "input_features" \
+            if input_key == "input_features_padded" \
                 and "audio" not in mm_input_by_modality:
                 mm_input_by_modality[
                     "audio"] = self._parse_and_validate_audio_input(**kwargs)
diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py
index e6e294a14349..5b64941762c6 100644
--- a/vllm/model_executor/models/glm4_1v.py
+++ b/vllm/model_executor/models/glm4_1v.py
@@ -1319,6 +1319,8 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]):
 )
 class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
                                     SupportsLoRA, SupportsPP):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -1381,22 +1383,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
-    def _validate_and_reshape_mm_tensor(self, mm_input: object,
-                                        name: str) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(
-                f"Incorrect type of {name}. Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == 2:
-                return mm_input
-            if mm_input.ndim != 3:
-                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
-                                 f"Got ndim: {mm_input.ndim} "
-                                 f"(shape={mm_input.shape})")
-            return mm_input.reshape(-1, mm_input.shape[-1])
-        else:
-            return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Glm4vImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -1407,11 +1393,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
             return None
 
         if pixel_values is not None:
-            pixel_values = self._validate_and_reshape_mm_tensor(
-                pixel_values, "image pixel values")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
-
             return Glm4vImagePixelInputs(
                 type="pixel_values",
                 pixel_values=pixel_values,
@@ -1419,11 +1400,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
             )
 
         if image_embeds is not None:
-            image_embeds = self._validate_and_reshape_mm_tensor(
-                image_embeds, "image embeds")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
-
             return Glm4vImageEmbeddingInputs(
                 type="image_embeds",
                 image_embeds=image_embeds,
@@ -1440,11 +1416,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
             return None
 
         if pixel_values_videos is not None:
-            pixel_values_videos = self._validate_and_reshape_mm_tensor(
-                pixel_values_videos, "video pixel values")
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw")
-
             return Glm4vVideoPixelInputs(
                 type="pixel_values_videos",
                 pixel_values_videos=pixel_values_videos,
@@ -1452,11 +1423,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
             )
 
         if video_embeds is not None:
-            video_embeds = self._validate_and_reshape_mm_tensor(
-                video_embeds, "video embeds")
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw")
-
             return Glm4vVideoEmbeddingInputs(
                 type="video_embeds",
                 video_embeds=video_embeds,
diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py
index 22ddb1d75160..213c3b2769eb 100644
--- a/vllm/model_executor/models/glm4v.py
+++ b/vllm/model_executor/models/glm4v.py
@@ -43,7 +43,6 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .chatglm import ChatGLMBaseModel, ChatGLMModel
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
-from .utils import flatten_bn
 
 
 class GLMVImagePixelInputs(TensorSchema):
@@ -529,8 +528,9 @@ class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):
 @MULTIMODAL_REGISTRY.register_processor(GLM4VMultiModalProcessor,
                                         info=GLM4VProcessingInfo,
                                         dummy_inputs=GLM4VDummyInputsBuilder)
-class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
-                       SupportsMultiModal):
+class GLM4VForCausalLM(ChatGLMBaseModel, SupportsMultiModal, SupportsLoRA,
+                       SupportsPP):
+    merge_by_field_config = True
 
     packed_modules_mapping = {
         "query_key_value": ["query_key_value"],
@@ -574,14 +574,9 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
 
         pixel_values = kwargs.pop("pixel_values", None)
         if pixel_values is not None:
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-
             expected_h = expected_w = self.config.vision_config["image_size"]
             return GLMVImagePixelInputs(type="pixel_values",
-                                        data=flatten_bn(pixel_values,
-                                                        concat=True),
+                                        data=pixel_values,
                                         resolve_bindings={
                                             "h": expected_h,
                                             "w": expected_w
@@ -598,6 +593,8 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
     def get_language_model(self) -> torch.nn.Module:
         return self.transformer
 
+    get_input_embeddings = SupportsMultiModal.get_input_embeddings
+
     def get_multimodal_embeddings(self,
                                   **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
diff --git a/vllm/model_executor/models/granite_speech.py b/vllm/model_executor/models/granite_speech.py
index 0ec451356f5e..ea9f67723b12 100644
--- a/vllm/model_executor/models/granite_speech.py
+++ b/vllm/model_executor/models/granite_speech.py
@@ -168,10 +168,8 @@ class GraniteSpeechMultiModalProcessor(
         # Calculate the number of audio tokens per entry in the batch;
         # This is used to split the batch back out after padding.
         audio_token_index = self.info.get_hf_config().audio_token_index
-        processed_outputs["audio_embed_sizes"] = [
-            torch.sum(indices == audio_token_index).item()
-            for indices in processed_outputs["input_ids"]
-        ]
+        processed_outputs["audio_embed_sizes"] = (
+            processed_outputs["input_ids"] == audio_token_index).sum(-1)
 
         return processed_outputs
 
@@ -527,6 +525,7 @@ class GraniteSpeechForConditionalGeneration(
         SupportsPP,
         SupportsLoRA,
 ):
+    merge_by_field_config = True
 
     packed_modules_mapping = {
         "qkv_proj": [