[Misc] Allow LM only loading for Pixtral (#29451)
Signed-off-by: Roger Wang <hey@rogerw.io>
parent d9d342d214
commit 452a7c9f7c
@@ -400,21 +400,30 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
             prefix=maybe_prefix(prefix, "language_model"),
         )
 
-        self.vision_encoder = VisionTransformer(self.vision_args)
-
-        if self.vision_args.add_pre_mm_projector_layer_norm:
-            self.pre_mm_projector_norm = RMSNorm(self.vision_args.hidden_size, eps=1e-5)
-
-        if self.vision_args.mm_projector_id == PATCH_MERGE:
-            self.patch_merger = PatchMerger(
-                vision_encoder_dim=self.vision_args.hidden_size,
-                spatial_merge_size=self.vision_args.spatial_merge_size,
-                use_mlp_bias=False,
+        if multimodal_config.get_limit_per_prompt("image"):
+            self.vision_encoder = VisionTransformer(self.vision_args)
+            self.pre_mm_projector_norm = (
+                RMSNorm(self.vision_args.hidden_size, eps=1e-5)
+                if self.vision_args.add_pre_mm_projector_layer_norm
+                else None
             )
-
-        self.vision_language_adapter = VisionLanguageAdapter(
-            self.vision_args, dim=config.text_config.hidden_size
-        )
+            self.patch_merger = (
+                PatchMerger(
+                    vision_encoder_dim=self.vision_args.hidden_size,
+                    spatial_merge_size=self.vision_args.spatial_merge_size,
+                    use_mlp_bias=False,
+                )
+                if self.vision_args.mm_projector_id == PATCH_MERGE
+                else None
+            )
+            self.vision_language_adapter = VisionLanguageAdapter(
+                self.vision_args, dim=config.text_config.hidden_size
+            )
+        else:
+            self.vision_encoder = None
+            self.pre_mm_projector_norm = None
+            self.patch_merger = None
+            self.vision_language_adapter = None
 
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors
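
Note: the constructor now gates all four vision modules on the per-prompt image limit. As a minimal usage sketch of how this is triggered (assuming vLLM's offline LLM entrypoint; the checkpoint name is illustrative), disabling images opts into LM-only loading:

from vllm import LLM

# With the image limit set to 0, multimodal_config.get_limit_per_prompt("image")
# is falsy, so the branch above leaves vision_encoder, pre_mm_projector_norm,
# patch_merger, and vision_language_adapter as None and their weights are
# never materialized.
llm = LLM(
    model="mistralai/Pixtral-12B-2409",  # illustrative checkpoint
    tokenizer_mode="mistral",
    limit_mm_per_prompt={"image": 0},
)

# Text-only prompts still work; image inputs would now be rejected.
outputs = llm.generate("The capital of France is")
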
@@ -436,13 +445,17 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
         self,
         image_input: PixtralImagePixelInputs,
     ) -> tuple[torch.Tensor, ...]:
+        assert (
+            self.vision_encoder is not None and self.vision_language_adapter is not None
+        )
+
         images = image_input["images"]
         image_features = self.vision_encoder(images)
         feature_sizes = [image_feature.shape[0] for image_feature in image_features]
         image_features = torch.cat(image_features)
-        if self.vision_args.add_pre_mm_projector_layer_norm:
+        if self.pre_mm_projector_norm is not None:
             image_features = self.pre_mm_projector_norm(image_features)
-        if self.vision_args.mm_projector_id == PATCH_MERGE:
+        if self.patch_merger is not None:
             patch_size = self.vision_args.patch_size
             spatial_merge_size_square = self.vision_args.spatial_merge_size**2
             img_patch_dims = [
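
Note: the forward path now tests the optional submodules themselves rather than re-reading the config flags that created them, so construction stays the single source of truth. A self-contained sketch of the pattern (illustrative names, not the vLLM API):

import torch
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self, use_norm: bool, hidden: int = 16):
        super().__init__()
        # The attribute is either a module or None, decided once here.
        self.norm = nn.LayerNorm(hidden) if use_norm else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Mirrors `if self.pre_mm_projector_norm is not None:` above.
        if self.norm is not None:
            x = self.norm(x)
        return x

x = torch.randn(2, 16)
print(Wrapper(use_norm=False)(x).shape)  # torch.Size([2, 16])
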
@@ -508,41 +521,57 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
             return weight[0].startswith("pre_mm_projector_norm")
 
         # Get references to parameters for direct loading
-        vision_encoder_dict = dict(self.vision_encoder.named_parameters())
+        vision_encoder_dict = (
+            dict(self.vision_encoder.named_parameters())
+            if self.vision_encoder is not None
+            else {}
+        )
         patch_merger_dict = (
             dict(self.patch_merger.named_parameters())
-            if self.vision_args.mm_projector_id == PATCH_MERGE
-            else dict()
+            if self.patch_merger is not None
+            else {}
         )
         pre_mm_projector_norm_dict = (
             dict(self.pre_mm_projector_norm.named_parameters())
-            if self.vision_args.add_pre_mm_projector_layer_norm
-            else dict()
+            if self.pre_mm_projector_norm is not None
+            else {}
         )
-        vision_lang_adapter_dict = dict(self.vision_language_adapter.named_parameters())
+        vision_lang_adapter_dict = (
+            dict(self.vision_language_adapter.named_parameters())
+            if self.vision_language_adapter is not None
+            else {}
+        )
 
         def llm_weights_generator():
             # Single pass over weights
             for name, w in weights:
                 if is_vision_encoder_weights((name, w)):
+                    if self.vision_encoder is None:
+                        continue
                     # Load vision encoder weights directly
                     trimmed_name = ".".join(name.split(".")[1:])
                     param = vision_encoder_dict[trimmed_name]
                     with torch.no_grad():
                         default_weight_loader(param, w)
                 elif is_patch_merger((name, w)):
+                    if self.patch_merger is None:
+                        continue
                     # Load vision patch merger weights directly
                     trimmed_name = ".".join(name.split(".")[1:])
                     param = patch_merger_dict[trimmed_name]
                     with torch.no_grad():
                         default_weight_loader(param, w)
                 elif is_pre_mm_projector_norm((name, w)):
+                    if self.pre_mm_projector_norm is None:
+                        continue
                     # Load vision pre_mm_projector_norm weights directly
                     trimmed_name = ".".join(name.split(".")[1:])
                     param = pre_mm_projector_norm_dict[trimmed_name]
                     with torch.no_grad():
                         default_weight_loader(param, w)
                 elif is_vision_lang_adapter_weights((name, w)):
+                    if self.vision_language_adapter is None:
+                        continue
                     # Load vision-language adapter weights directly
                     trimmed_name = ".".join(name.split(".")[1:])
                     param = vision_lang_adapter_dict[trimmed_name]
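
Note: with the parameter dicts defaulting to {} and the early `continue` guards, the generator simply drops vision weights from the checkpoint when the corresponding submodule was never built, and passes the rest through to the language model loader. A simplified, self-contained sketch of that skip logic (names are illustrative, not the actual vllm helpers):

from collections.abc import Iterable, Iterator
import torch

def llm_only_weights(
    weights: Iterable[tuple[str, torch.Tensor]],
    vision_encoder: object | None,
) -> Iterator[tuple[str, torch.Tensor]]:
    for name, w in weights:
        if name.startswith("vision_encoder."):
            if vision_encoder is None:
                continue  # LM-only mode: drop vision weights unread
            ...  # otherwise load into the vision tower directly, as above
        else:
            yield name, w  # hand everything else to the language model
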