From 452a7c9f7c949cd20c3c0c81cd4352b2a0045076 Mon Sep 17 00:00:00 2001
From: Roger Wang
Date: Wed, 26 Nov 2025 05:00:00 -0800
Subject: [PATCH] [Misc] Allow LM only loading for Pixtral (#29451)

Signed-off-by: Roger Wang
---
 vllm/model_executor/models/pixtral.py | 73 +++++++++++++++++++--------
 1 file changed, 51 insertions(+), 22 deletions(-)

diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index 6011d93a795d1..3464de472add5 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -400,21 +400,30 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
             prefix=maybe_prefix(prefix, "language_model"),
         )
 
-        self.vision_encoder = VisionTransformer(self.vision_args)
-
-        if self.vision_args.add_pre_mm_projector_layer_norm:
-            self.pre_mm_projector_norm = RMSNorm(self.vision_args.hidden_size, eps=1e-5)
-
-        if self.vision_args.mm_projector_id == PATCH_MERGE:
-            self.patch_merger = PatchMerger(
-                vision_encoder_dim=self.vision_args.hidden_size,
-                spatial_merge_size=self.vision_args.spatial_merge_size,
-                use_mlp_bias=False,
+        if multimodal_config.get_limit_per_prompt("image"):
+            self.vision_encoder = VisionTransformer(self.vision_args)
+            self.pre_mm_projector_norm = (
+                RMSNorm(self.vision_args.hidden_size, eps=1e-5)
+                if self.vision_args.add_pre_mm_projector_layer_norm
+                else None
             )
-
-        self.vision_language_adapter = VisionLanguageAdapter(
-            self.vision_args, dim=config.text_config.hidden_size
-        )
+            self.patch_merger = (
+                PatchMerger(
+                    vision_encoder_dim=self.vision_args.hidden_size,
+                    spatial_merge_size=self.vision_args.spatial_merge_size,
+                    use_mlp_bias=False,
+                )
+                if self.vision_args.mm_projector_id == PATCH_MERGE
+                else None
+            )
+            self.vision_language_adapter = VisionLanguageAdapter(
+                self.vision_args, dim=config.text_config.hidden_size
+            )
+        else:
+            self.vision_encoder = None
+            self.pre_mm_projector_norm = None
+            self.patch_merger = None
+            self.vision_language_adapter = None
 
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors
@@ -436,13 +445,17 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
         self,
         image_input: PixtralImagePixelInputs,
     ) -> tuple[torch.Tensor, ...]:
+        assert (
+            self.vision_encoder is not None and self.vision_language_adapter is not None
+        )
+
         images = image_input["images"]
         image_features = self.vision_encoder(images)
         feature_sizes = [image_feature.shape[0] for image_feature in image_features]
         image_features = torch.cat(image_features)
-        if self.vision_args.add_pre_mm_projector_layer_norm:
+        if self.pre_mm_projector_norm is not None:
             image_features = self.pre_mm_projector_norm(image_features)
-        if self.vision_args.mm_projector_id == PATCH_MERGE:
+        if self.patch_merger is not None:
             patch_size = self.vision_args.patch_size
             spatial_merge_size_square = self.vision_args.spatial_merge_size**2
             img_patch_dims = [
@@ -508,41 +521,57 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
             return weight[0].startswith("pre_mm_projector_norm")
 
         # Get references to parameters for direct loading
-        vision_encoder_dict = dict(self.vision_encoder.named_parameters())
+        vision_encoder_dict = (
+            dict(self.vision_encoder.named_parameters())
+            if self.vision_encoder is not None
+            else {}
+        )
         patch_merger_dict = (
             dict(self.patch_merger.named_parameters())
-            if self.vision_args.mm_projector_id == PATCH_MERGE
-            else dict()
+            if self.patch_merger is not None
+            else {}
         )
         pre_mm_projector_norm_dict = (
             dict(self.pre_mm_projector_norm.named_parameters())
-            if self.vision_args.add_pre_mm_projector_layer_norm
-            else dict()
+            if self.pre_mm_projector_norm is not None
+            else {}
+        )
+        vision_lang_adapter_dict = (
+            dict(self.vision_language_adapter.named_parameters())
+            if self.vision_language_adapter is not None
+            else {}
         )
-        vision_lang_adapter_dict = dict(self.vision_language_adapter.named_parameters())
 
         def llm_weights_generator():
             # Single pass over weights
             for name, w in weights:
                 if is_vision_encoder_weights((name, w)):
+                    if self.vision_encoder is None:
+                        continue
                     # Load vision encoder weights directly
                     trimmed_name = ".".join(name.split(".")[1:])
                     param = vision_encoder_dict[trimmed_name]
                     with torch.no_grad():
                         default_weight_loader(param, w)
                 elif is_patch_merger((name, w)):
+                    if self.patch_merger is None:
+                        continue
                     # Load vision patch merger weights directly
                     trimmed_name = ".".join(name.split(".")[1:])
                     param = patch_merger_dict[trimmed_name]
                     with torch.no_grad():
                         default_weight_loader(param, w)
                 elif is_pre_mm_projector_norm((name, w)):
+                    if self.pre_mm_projector_norm is None:
+                        continue
                    # Load vision pre_mm_projector_norm weights directly
                     trimmed_name = ".".join(name.split(".")[1:])
                     param = pre_mm_projector_norm_dict[trimmed_name]
                     with torch.no_grad():
                         default_weight_loader(param, w)
                 elif is_vision_lang_adapter_weights((name, w)):
+                    if self.vision_language_adapter is None:
+                        continue
                     # Load vision-language adapter weights directly
                     trimmed_name = ".".join(name.split(".")[1:])
                     param = vision_lang_adapter_dict[trimmed_name]
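
A minimal usage sketch of the code path this patch enables (not part of the patch itself). It assumes the public vllm.LLM entrypoint and that passing limit_mm_per_prompt={"image": 0} is what makes multimodal_config.get_limit_per_prompt("image") falsy; the checkpoint name and tokenizer settings are illustrative, not prescribed by the change.

# Hypothetical sketch: load Pixtral with the vision stack skipped.
# Assumption: with the image limit set to 0, the constructor above leaves
# vision_encoder, patch_merger, pre_mm_projector_norm, and
# vision_language_adapter as None, and load_weights() skips their checkpoints,
# so only the language model weights are loaded.
from vllm import LLM, SamplingParams

llm = LLM(
    model="mistralai/Pixtral-12B-2409",  # illustrative model name
    tokenizer_mode="mistral",            # assumption: mistral-format checkpoint
    limit_mm_per_prompt={"image": 0},    # disable images -> LM-only loading
)

outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(max_tokens=16),
)
print(outputs[0].outputs[0].text)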