Support LoRA for Mistral3 (#17428)
Signed-off-by: mgoin <mgoin64@gmail.com>
parent 88fcf00dda
commit a44c4f1d2f
docs/source/models/supported_models.md
@@ -990,7 +990,7 @@ See [this page](#generative-models) for more information on how to use generative models.
   * Mistral3
   * T + I<sup>+</sup>
   * `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc.
-  *
+  * ✅︎
   * ✅︎
   * ✅︎
 - * `MllamaForConditionalGeneration`
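With the LoRA column now ✅︎, Mistral3 adapters can be served through vLLM's standard multi-LoRA path. A minimal sketch of offline inference, assuming a locally available adapter at the hypothetical path ./mistral3-lora-adapter:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# enable_lora switches on adapter support; max_lora_rank=64 is an
# assumption and must be >= the rank the adapter was trained with.
llm = LLM(
    model="mistralai/Mistral-Small-3.1-24B-Instruct-2503",
    enable_lora=True,
    max_lora_rank=64,
)

outputs = llm.generate(
    ["Summarize LoRA in one sentence."],
    SamplingParams(temperature=0.0, max_tokens=64),
    # LoRARequest(name, unique integer id, path to adapter weights)
    lora_request=LoRARequest("mistral3-adapter", 1, "./mistral3-lora-adapter"),
)
print(outputs[0].outputs[0].text)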
vllm/model_executor/models/mistral3.py
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
@@ -31,7 +32,8 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
+                         SupportsMultiModal, SupportsPP)
 from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
@@ -382,8 +384,8 @@ def init_vision_tower_for_llava(
     _build_mistral3_processor,
     info=_build_mistral3_info,
     dummy_inputs=Mistral3DummyInputsBuilder)
-class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                       SupportsPP):
+class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
+                                       SupportsMultiModal, SupportsPP):
 
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
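The packed_modules_mapping entry matters for LoRA because vLLM fuses the attention projections into a single qkv_proj layer, whereas adapters trained with PEFT usually ship separate q_proj/k_proj/v_proj tensors. Conceptually, the mapping tells the loader which fused module (and which slice of it) an unfused adapter tensor belongs to. A rough, hypothetical illustration of that lookup (not vLLM's actual loader code):

packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

def fused_target(module_name: str) -> tuple[str, int] | None:
    """Map an unfused projection name to (fused module, slice index)."""
    for fused, subs in packed_modules_mapping.items():
        if module_name in subs:
            return fused, subs.index(module_name)
    return None  # not a packed module; the adapter tensor applies directly

assert fused_target("k_proj") == ("qkv_proj", 1)
assert fused_target("o_proj") is None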
@@ -594,3 +596,12 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
                                                    torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="language_model",
+            connector="multi_modal_projector",
+            tower_model="vision_tower")
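get_mm_mapping is what lets the LoRA machinery tell the language model apart from the vision tower and projector: text LoRA adapters apply under the language_model prefix, while vision_tower and multi_modal_projector weights are left untouched. A small sketch of how such a prefix mapping can be used to filter modules (hypothetical helper, not vLLM internals):

def is_lora_target(module_name: str, lm_prefix: str = "language_model") -> bool:
    # Only modules under the language-model prefix take text LoRA weights.
    return module_name.startswith(lm_prefix)

assert is_lora_target("language_model.model.layers.0.self_attn.qkv_proj")
assert not is_lora_target("vision_tower.transformer.layers.0.attention")
assert not is_lora_target("multi_modal_projector.linear_1")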