[V1] Support Mistral3 in V1 (#15950)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin 2025-04-02 16:36:24 -06:00 committed by GitHub
parent 1cab43c2d2
commit f021b97993
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 10 additions and 7 deletions

View File

@ -888,7 +888,7 @@ See [this page](#generative-models) for more information on how to use generativ
* `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc.
*
* ✅︎
*
* ✅︎
- * `MllamaForConditionalGeneration`
* Llama 3.2
* T + I<sup>+</sup>

View File

@ -31,12 +31,12 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
SupportsV0Only)
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .vision import get_vision_encoder_info, select_patch_features
from .vision import (get_vision_encoder_info, scatter_patch_features,
select_patch_features)
class Mistral3ImagePixelInputs(TypedDict):
@ -425,7 +425,7 @@ def init_vision_tower_for_llava(
info=_build_mistral3_info,
dummy_inputs=Mistral3DummyInputsBuilder)
class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP, SupportsV0Only):
SupportsPP):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
@ -518,7 +518,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
return Mistral3ImagePixelInputs(
type="pixel_values_pixtral",
pixel_values=flatten_bn(pixel_values),
embed_is_patch=embed_is_patch,
embed_is_patch=flatten_bn(embed_is_patch),
)
def _process_image_input(
@ -557,7 +557,10 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
return scatter_patch_features(
vision_embeddings,
image_input["embed_is_patch"],
)
def get_input_embeddings(
self,