From a44c4f1d2f7cb882e0045b0c7d7cbcf8e08ef9bd Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 29 Apr 2025 22:10:30 -0600 Subject: [PATCH] Support LoRA for Mistral3 (#17428) Signed-off-by: mgoin --- docs/source/models/supported_models.md | 2 +- vllm/model_executor/models/mistral3.py | 17 ++++++++++++++--- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 95e7d5d602ac..8489ebe713a3 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -990,7 +990,7 @@ See [this page](#generative-models) for more information on how to use generativ * Mistral3 * T + I+ * `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. - * + * ✅︎ * ✅︎ * ✅︎ - * `MllamaForConditionalGeneration` diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index 12c87dc0f2af..c9abe4142be5 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -18,6 +18,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, @@ -31,7 +32,8 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors -from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .interfaces import (MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsPP) from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel from .utils import 
(AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) @@ -382,8 +384,8 @@ def init_vision_tower_for_llava( _build_mistral3_processor, info=_build_mistral3_info, dummy_inputs=Mistral3DummyInputsBuilder) -class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP): +class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA, + SupportsMultiModal, SupportsPP): packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], @@ -594,3 +596,12 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix of each component of the multimodal model. + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="multi_modal_projector", + tower_model="vision_tower")