diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py
index dae98093bc6e1..d8c0234b8f42d 100644
--- a/vllm/model_executor/models/mllama4.py
+++ b/vllm/model_executor/models/mllama4.py
@@ -760,6 +760,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
         flat_data = image_input["flat_data"]
         patches_per_image = image_input["patches_per_image"].tolist()
         vision_embeddings_flat = self.vision_model(flat_data)
+        vision_embeddings_flat = self.multi_modal_projector(
+            vision_embeddings_flat)
         return vision_embeddings_flat.split(patches_per_image, dim=0)
 
     def get_multimodal_embeddings(self,
@@ -791,10 +793,9 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
 
         if multimodal_embeddings is not None:
-            multimodal_embeddings = torch.cat(multimodal_embeddings)
-            mm_embeddings = self.multi_modal_projector(multimodal_embeddings)
             inputs_embeds = merge_multimodal_embeddings(
-                input_ids, inputs_embeds, select_patch_features(mm_embeddings),
+                input_ids, inputs_embeds,
+                select_patch_features(multimodal_embeddings),
                 self.config.image_token_index)
         return inputs_embeds
 