[Misc] Allow LM only loading for Pixtral (#29451)

Signed-off-by: Roger Wang <hey@rogerw.io>
Roger Wang 2025-11-26 05:00:00 -08:00 committed by GitHub
parent d9d342d214
commit 452a7c9f7c
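
This change gates construction of all Pixtral vision modules (vision encoder, pre-projector norm, patch merger, vision-language adapter) on the per-prompt image limit, so a deployment that disables images loads only the language model. A minimal usage sketch, assuming a standard Pixtral checkpoint (the model name and flags are illustrative, not part of this commit):

    from vllm import LLM

    # With the image limit set to 0, multimodal_config.get_limit_per_prompt("image")
    # is falsy, so the vision modules are never instantiated and their
    # checkpoint weights are skipped during loading.
    llm = LLM(
        model="mistralai/Pixtral-12B-2409",  # illustrative checkpoint
        tokenizer_mode="mistral",
        limit_mm_per_prompt={"image": 0},    # text-only: LM-only loading
    )
    outputs = llm.generate("What is the capital of France?")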


@@ -400,21 +400,30 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
             prefix=maybe_prefix(prefix, "language_model"),
         )
-        self.vision_encoder = VisionTransformer(self.vision_args)
-        if self.vision_args.add_pre_mm_projector_layer_norm:
-            self.pre_mm_projector_norm = RMSNorm(self.vision_args.hidden_size, eps=1e-5)
-        if self.vision_args.mm_projector_id == PATCH_MERGE:
-            self.patch_merger = PatchMerger(
-                vision_encoder_dim=self.vision_args.hidden_size,
-                spatial_merge_size=self.vision_args.spatial_merge_size,
-                use_mlp_bias=False,
-            )
-        self.vision_language_adapter = VisionLanguageAdapter(
-            self.vision_args, dim=config.text_config.hidden_size
-        )
+        if multimodal_config.get_limit_per_prompt("image"):
+            self.vision_encoder = VisionTransformer(self.vision_args)
+            self.pre_mm_projector_norm = (
+                RMSNorm(self.vision_args.hidden_size, eps=1e-5)
+                if self.vision_args.add_pre_mm_projector_layer_norm
+                else None
+            )
+            self.patch_merger = (
+                PatchMerger(
+                    vision_encoder_dim=self.vision_args.hidden_size,
+                    spatial_merge_size=self.vision_args.spatial_merge_size,
+                    use_mlp_bias=False,
+                )
+                if self.vision_args.mm_projector_id == PATCH_MERGE
+                else None
+            )
+            self.vision_language_adapter = VisionLanguageAdapter(
+                self.vision_args, dim=config.text_config.hidden_size
+            )
+        else:
+            self.vision_encoder = None
+            self.pre_mm_projector_norm = None
+            self.patch_merger = None
+            self.vision_language_adapter = None
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors
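
The constructor pattern in the hunk above, distilled into a standalone sketch (the class and layer sizes are hypothetical, not vLLM code): modules for a disabled modality are assigned None instead of being constructed, and every downstream consumer, including the weight loader further below, must tolerate the None case.

    import torch.nn as nn

    class TextOnlyCapableModel(nn.Module):
        # Hypothetical illustration: build vision submodules only when images
        # are allowed; otherwise leave them as None so weight loading can skip
        # the corresponding checkpoint tensors.
        def __init__(self, image_limit: int, hidden: int = 16):
            super().__init__()
            self.language_model = nn.Linear(hidden, hidden)  # stand-in for the LM
            self.vision_encoder = (
                nn.Linear(hidden, hidden) if image_limit else None  # stand-in for the ViT
            )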
@@ -436,13 +445,17 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
         self,
         image_input: PixtralImagePixelInputs,
     ) -> tuple[torch.Tensor, ...]:
+        assert (
+            self.vision_encoder is not None and self.vision_language_adapter is not None
+        )
         images = image_input["images"]
         image_features = self.vision_encoder(images)
         feature_sizes = [image_feature.shape[0] for image_feature in image_features]
         image_features = torch.cat(image_features)
-        if self.vision_args.add_pre_mm_projector_layer_norm:
+        if self.pre_mm_projector_norm is not None:
             image_features = self.pre_mm_projector_norm(image_features)
-        if self.vision_args.mm_projector_id == PATCH_MERGE:
+        if self.patch_merger is not None:
             patch_size = self.vision_args.patch_size
             spatial_merge_size_square = self.vision_args.spatial_merge_size**2
             img_patch_dims = [
@@ -508,41 +521,57 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
             return weight[0].startswith("pre_mm_projector_norm")

         # Get references to parameters for direct loading
-        vision_encoder_dict = dict(self.vision_encoder.named_parameters())
+        vision_encoder_dict = (
+            dict(self.vision_encoder.named_parameters())
+            if self.vision_encoder is not None
+            else {}
+        )
         patch_merger_dict = (
             dict(self.patch_merger.named_parameters())
-            if self.vision_args.mm_projector_id == PATCH_MERGE
-            else dict()
+            if self.patch_merger is not None
+            else {}
         )
         pre_mm_projector_norm_dict = (
             dict(self.pre_mm_projector_norm.named_parameters())
-            if self.vision_args.add_pre_mm_projector_layer_norm
-            else dict()
+            if self.pre_mm_projector_norm is not None
+            else {}
         )
-        vision_lang_adapter_dict = dict(self.vision_language_adapter.named_parameters())
+        vision_lang_adapter_dict = (
+            dict(self.vision_language_adapter.named_parameters())
+            if self.vision_language_adapter is not None
+            else {}
+        )

         def llm_weights_generator():
             # Single pass over weights
             for name, w in weights:
                 if is_vision_encoder_weights((name, w)):
+                    if self.vision_encoder is None:
+                        continue
                     # Load vision encoder weights directly
                     trimmed_name = ".".join(name.split(".")[1:])
                     param = vision_encoder_dict[trimmed_name]
                     with torch.no_grad():
                         default_weight_loader(param, w)
                 elif is_patch_merger((name, w)):
+                    if self.patch_merger is None:
+                        continue
                     # Load vision patch merger weights directly
                     trimmed_name = ".".join(name.split(".")[1:])
                     param = patch_merger_dict[trimmed_name]
                     with torch.no_grad():
                         default_weight_loader(param, w)
                 elif is_pre_mm_projector_norm((name, w)):
+                    if self.pre_mm_projector_norm is None:
+                        continue
                     # Load vision pre_mm_projector_norm weights directly
                     trimmed_name = ".".join(name.split(".")[1:])
                     param = pre_mm_projector_norm_dict[trimmed_name]
                     with torch.no_grad():
                         default_weight_loader(param, w)
                 elif is_vision_lang_adapter_weights((name, w)):
+                    if self.vision_language_adapter is None:
+                        continue
                     # Load vision-language adapter weights directly
                     trimmed_name = ".".join(name.split(".")[1:])
                     param = vision_lang_adapter_dict[trimmed_name]