Add Mistral Large 3 and Ministral 3 (#29757)

Signed-off-by: Julien Denize <julien.denize@mistral.ai>
Signed-off-by: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Mickael Seznec <mickael@mistral.ai>
This commit is contained in:
Julien Denize 2025-12-02 11:29:00 +01:00 committed by GitHub
parent 8bbcf8b6e7
commit d8c6210eea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 724 additions and 30 deletions

View File

@ -417,7 +417,8 @@ th {
| `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ |
| `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ |
| `MiniMaxM2ForCausalLM` | MiniMax-M2 |`MiniMaxAI/MiniMax-M2`, etc. | | ✅︎ |
| `MistralForCausalLM` | Mistral, Mistral-Instruct | `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ |
| `MistralForCausalLM` | Ministral-3, Mistral, Mistral-Instruct | `mistralai/Ministral-3-3B-Instruct-2512`, `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ |
| `MistralLarge3ForCausalLM` | Mistral-Large-3-675B-Base-2512, Mistral-Large-3-675B-Instruct-2512 | `mistralai/Mistral-Large-3-675B-Base-2512`, `mistralai/Mistral-Large-3-675B-Instruct-2512`, etc. | ✅︎ | ✅︎ |
| `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ |
| `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ |
| `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ |
@ -711,7 +712,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + I<sup>E+</sup> | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ |
| `Phi4MMForCausalLM` | Phi-4-multimodal | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ |
| `Phi4MultimodalForCausalLM` | Phi-4-multimodal (HF Transformers) | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct` (with revision `refs/pr/70`), etc. | ✅︎ | ✅︎ |
| `PixtralForConditionalGeneration` | Mistral 3 (Mistral format), Pixtral (Mistral format) | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistralai/Pixtral-12B-2409`, etc. | | ✅︎ |
| `PixtralForConditionalGeneration` | Ministral 3 (Mistral format), Mistral 3 (Mistral format), Mistral Large 3 (Mistral format), Pixtral (Mistral format) | T + I<sup>+</sup> | `mistralai/Ministral-3-3B-Instruct-2512`, `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistralai/Mistral-Large-3-675B-Instruct-2512` `mistralai/Pixtral-12B-2409` etc. | | ✅︎ |
| `QwenVLForConditionalGeneration`<sup>^</sup> | Qwen-VL | T + I<sup>E+</sup> | `Qwen/Qwen-VL`, `Qwen/Qwen-VL-Chat`, etc. | ✅︎ | ✅︎ |
| `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A<sup>+</sup> | `Qwen/Qwen2-Audio-7B-Instruct` | | ✅︎ |
| `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ |

View File

@ -358,6 +358,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True,
),
"MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"),
"MistralLarge3ForCausalLM": _HfExamplesInfo(
"mistralai/Mistral-Large-3-675B-Instruct-2512-NVFP4", is_available_online=False
),
"MixtralForCausalLM": _HfExamplesInfo(
"mistralai/Mixtral-8x7B-Instruct-v0.1",
{"tiny": "TitanML/tiny-mixtral"},
@ -770,7 +773,13 @@ _MULTIMODAL_EXAMPLE_MODELS = {
),
"PixtralForConditionalGeneration": _HfExamplesInfo(
"mistralai/Pixtral-12B-2409",
extras={
"mistral-large-3": "mistralai/Mistral-Large-3-675B-Instruct-2512-NVFP4",
"ministral-3": "mistralai/Ministral-3-3B-Instruct-2512",
},
tokenizer_mode="mistral",
# TODO: revert once Mistral-Large-3 and Ministral-3 are publicly available.
is_available_online=False,
),
"QwenVLForConditionalGeneration": _HfExamplesInfo(
"Qwen/Qwen-VL",
@ -870,6 +879,11 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
use_original_num_layers=True,
max_model_len=10240,
),
"EagleMistralLarge3ForCausalLM": _HfExamplesInfo(
"mistralai/Mistral-Large-3-675B-Instruct-2512",
speculative_model="mistralai/Mistral-Large-3-675B-Instruct-2512-Eagle",
is_available_online=False,
),
"LlamaForCausalLMEagle3": _HfExamplesInfo(
"Qwen/Qwen3-8B",
trust_remote_code=True,

View File

@ -91,6 +91,118 @@ from vllm.tokenizers.mistral import (
],
),
),
(
{
"messages": [
{
"role": "user",
"content": "What is the current local date and time?",
}
],
"tools": [
{
"type": "function",
"function": {
"description": "Fetch the current local date and time.",
"unsupported_field": False,
"name": "get_current_time",
"parameters": {},
},
},
{
"type": "function",
"function": {
"description": "Fetch the current local date and time.",
"unsupported_field2": False,
"name": "get_current_time",
"parameters": {},
},
},
],
},
(
[
{
"role": "user",
"content": "What is the current local date and time?",
}
],
[
{
"type": "function",
"function": {
"description": "Fetch the current local date and time.",
"name": "get_current_time",
"parameters": {},
},
},
{
"type": "function",
"function": {
"description": "Fetch the current local date and time.",
"name": "get_current_time",
"parameters": {},
},
},
],
),
),
(
{
"messages": [
{
"role": "user",
"content": "What is the current local date and time?",
}
],
"tools": [
{
"type": "function",
"unsupported_field": False,
"function": {
"description": "Fetch the current local date and time.",
"name": "get_current_time",
"parameters": {},
},
},
{
"type": "function",
"unsupported_field2": False,
"function": {
"description": "Fetch the current local date and time 2.",
"name": "get_current_time2",
"parameters": {"a": "1"},
},
},
],
},
(
[
{
"role": "user",
"content": "What is the current local date and time?",
}
],
[
{
"type": "function",
"function": {
"description": "Fetch the current local date and time.",
"name": "get_current_time",
"parameters": {},
},
},
{
"type": "function",
"function": {
"description": "Fetch the current local date and time 2.",
"name": "get_current_time2",
"parameters": {"a": "1"},
},
},
],
),
),
],
)
def test_prepare_apply_chat_template_tools_and_messages(
@ -1108,13 +1220,6 @@ class TestMistralTokenizer:
)
== expected_tokens[mistral_tokenizer.is_tekken]
)
assert (
mistral_tokenizer.decode(
ids[mistral_tokenizer.is_tekken],
skip_special_tokens=skip_special_tokens,
)
== expected_tokens[mistral_tokenizer.is_tekken]
)
def test_decode_empty(
self,
@ -1140,6 +1245,45 @@ class TestMistralTokenizer:
== "<s>"
)
@pytest.mark.parametrize(
"skip_special_tokens,expected_tokens",
(
(
False,
(
["<s>[INST]▁Hello▁world▁![/INST]▁Hello</s>"],
["<s>[INST]Hello world ![/INST]Hello</s>"],
),
),
(True, (["Hello world ! Hello"], ["Hello world !Hello"])),
),
)
def test_batch_decode(
self,
mistral_tokenizer: MistralTokenizer,
skip_special_tokens: bool,
expected_tokens: tuple[str, str],
):
ids = (
[[1, 3, 23325, 2294, 1686, 4, 23325, 2]],
[[1, 3, 22177, 4304, 2662, 4, 22177, 2]],
)
assert (
mistral_tokenizer.batch_decode(
ids[mistral_tokenizer.is_tekken],
skip_special_tokens=skip_special_tokens,
)
== expected_tokens[mistral_tokenizer.is_tekken]
)
def test_batch_decode_empty(
self,
mistral_tokenizer: MistralTokenizer,
):
assert mistral_tokenizer.batch_decode(
[[]],
) == [""]
def test_convert_tokens_to_string(self, mistral_tokenizer: MistralTokenizer):
tokens = (
[

View File

@ -167,6 +167,7 @@ class SpeculativeConfig:
@staticmethod
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
initial_architecture = hf_config.architectures[0]
if hf_config.model_type in ("deepseek_v3", "deepseek_v32"):
hf_config.model_type = "deepseek_mtp"
if hf_config.model_type == "deepseek_mtp":
@ -226,6 +227,9 @@ class SpeculativeConfig:
{"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]}
)
if initial_architecture == "MistralLarge3ForCausalLM":
hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]})
return hf_config
def __post_init__(self):

View File

@ -80,7 +80,7 @@ class MistralToolParser(ToolParser):
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
if _is_fn_name_regex_support(self.model_tokenizer):
self.fn_name_regex = re.compile(
r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)", re.DOTALL
r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)?", re.DOTALL
)
else:
self.fn_name_regex = None

View File

@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
}
}

View File

@ -111,6 +111,7 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
llama_4_scaling: torch.Tensor | None = None,
) -> torch.Tensor:
q_c = None
kv_lora = None
@ -159,6 +160,9 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
hidden_states, q_c, positions, self.indexer_rope_emb
)
if llama_4_scaling is not None:
q *= llama_4_scaling
attn_out = self.mla_attn(
q,
kv_c_normed,

View File

@ -238,7 +238,7 @@ def get_rope(
dtype,
**extra_kwargs,
)
elif scaling_type == "deepseek_yarn":
elif scaling_type in ["deepseek_yarn", "deepseek_llama_scaling"]:
scaling_factor = rope_parameters["factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
# assert max_position == original_max_position * scaling_factor

View File

@ -395,6 +395,16 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
return 0.1 * mscale * math.log(scale) + 1.0
def _get_llama_4_scaling(
original_max_position_embeddings: int, scaling_beta: float, positions: torch.Tensor
) -> torch.Tensor:
scaling = 1 + scaling_beta * torch.log(
1 + torch.floor(positions / original_max_position_embeddings)
)
# Broadcast over num_heads and head_dim
return scaling[..., None, None]
class DeepseekV2Attention(nn.Module):
def __init__(
self,
@ -481,7 +491,11 @@ class DeepseekV2Attention(nn.Module):
prefix=f"{prefix}.o_proj",
)
if config.rope_parameters["rope_type"] != "default":
config.rope_parameters["rope_type"] = "deepseek_yarn"
config.rope_parameters["rope_type"] = (
"deepseek_yarn"
if config.rope_parameters.get("apply_yarn_scaling", True)
else "deepseek_llama_scaling"
)
self.rotary_emb = get_rope(
qk_rope_head_dim,
@ -491,7 +505,10 @@ class DeepseekV2Attention(nn.Module):
is_neox_style=False,
)
if config.rope_parameters["rope_type"] != "default":
if (
config.rope_parameters["rope_type"] != "default"
and config.rope_parameters["rope_type"] == "deepseek_yarn"
):
mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
scaling_factor = config.rope_parameters["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
@ -511,6 +528,7 @@ class DeepseekV2Attention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
llama_4_scaling: torch.Tensor | None,
) -> torch.Tensor:
if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0]
@ -536,6 +554,11 @@ class DeepseekV2Attention(nn.Module):
k = torch.empty_like(q)
k[..., : self.qk_nope_head_dim] = k_nope
k[..., self.qk_nope_head_dim :] = k_pe
# Apply llama 4 scaling if provided
if llama_4_scaling is not None:
q *= llama_4_scaling
# padding value to qk_head_dim for alignment
v = torch.nn.functional.pad(
v, [0, self.qk_head_dim - self.v_head_dim], value=0
@ -987,7 +1010,12 @@ class DeepseekV2MLAAttention(nn.Module):
)
if config.rope_parameters["rope_type"] != "default":
config.rope_parameters["rope_type"] = "deepseek_yarn"
config.rope_parameters["rope_type"] = (
"deepseek_yarn"
if config.rope_parameters.get("apply_yarn_scaling", True)
else "deepseek_llama_scaling"
)
self.rotary_emb = get_rope(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
@ -995,7 +1023,11 @@ class DeepseekV2MLAAttention(nn.Module):
rope_parameters=config.rope_parameters,
is_neox_style=False,
)
if config.rope_parameters["rope_type"] != "default":
if (
config.rope_parameters["rope_type"] != "default"
and config.rope_parameters["rope_type"] == "deepseek_yarn"
):
mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
scaling_factor = config.rope_parameters["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
@ -1064,8 +1096,9 @@ class DeepseekV2MLAAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
llama_4_scaling: torch.Tensor | None,
) -> torch.Tensor:
return self.mla_attn(positions, hidden_states)
return self.mla_attn(positions, hidden_states, llama_4_scaling)
class DeepseekV2DecoderLayer(nn.Module):
@ -1155,6 +1188,7 @@ class DeepseekV2DecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
llama_4_scaling: torch.Tensor | None = None,
) -> torch.Tensor:
# Self Attention
if residual is None:
@ -1165,6 +1199,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
llama_4_scaling=llama_4_scaling,
)
if (
@ -1266,8 +1301,24 @@ class DeepseekV2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
# Compute llama 4 scaling once per forward pass if enabled
llama_4_scaling_config = getattr(self.config, "llama_4_scaling", None)
llama_4_scaling: torch.Tensor | None
if llama_4_scaling_config is not None:
llama_4_scaling = _get_llama_4_scaling(
original_max_position_embeddings=llama_4_scaling_config[
"original_max_position_embeddings"
],
scaling_beta=llama_4_scaling_config["beta"],
positions=positions,
)
else:
llama_4_scaling = None
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual)
hidden_states, residual = layer(
positions, hidden_states, residual, llama_4_scaling
)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
@ -1325,6 +1376,7 @@ class DeepseekV2ForCausalLM(
packed_modules_mapping = {
"gate_up_proj": ["gate_proj", "up_proj"],
}
model_cls = DeepseekV2Model
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@ -1355,7 +1407,7 @@ class DeepseekV2ForCausalLM(
"kv_a_proj_with_mqa",
]
self.model = DeepseekV2Model(
self.model = self.model_cls(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
if get_pp_group().is_last_rank:

View File

@ -0,0 +1,63 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
import regex as re
import torch
from vllm.model_executor.models.deepseek_v2 import DeepseekV3ForCausalLM
class MistralLarge3ForCausalLM(DeepseekV3ForCausalLM):
# fmt: off
remapping = {
r"layers\.(\d+)\.attention_norm\.weight": r"model.layers.\1.input_layernorm.weight", # noqa: E501
r"layers\.(\d+)\.attention\.wq_a\.(\w+)": r"model.layers.\1.self_attn.q_a_proj.\2", # noqa: E501
r"layers\.(\d+)\.attention\.q_a_norm\.weight": r"model.layers.\1.self_attn.q_a_layernorm.weight", # noqa: E501
r"layers\.(\d+)\.attention\.wq_b\.(\w+)": r"model.layers.\1.self_attn.q_b_proj.\2", # noqa: E501
r"layers\.(\d+)\.attention\.wkv_a_with_mqa\.(\w+)": r"model.layers.\1.self_attn.kv_a_proj_with_mqa.\2", # noqa: E501
r"layers\.(\d+)\.attention\.kv_a_norm\.weight": r"model.layers.\1.self_attn.kv_a_layernorm.weight", # noqa: E501
r"layers\.(\d+)\.attention\.wkv_b\.(\w+)": r"model.layers.\1.self_attn.kv_b_proj.\2", # noqa: E501
r"layers\.(\d+)\.attention\.wo\.(\w+)": r"model.layers.\1.self_attn.o_proj.\2", # noqa: E501
r"layers\.(\d+)\.ffn_norm\.weight": r"model.layers.\1.post_attention_layernorm.weight", # noqa: E501
r"layers\.(\d+)\.feed_forward\.w1\.(\w+)": r"model.layers.\1.mlp.gate_proj.\2", # noqa: E501
r"layers\.(\d+)\.feed_forward\.w2\.(\w+)": r"model.layers.\1.mlp.down_proj.\2", # noqa: E501
r"layers\.(\d+)\.feed_forward\.w3\.(\w+)": r"model.layers.\1.mlp.up_proj.\2", # noqa: E501
r"layers\.(\d+)\.gate\.weight": r"model.layers.\1.mlp.gate.weight", # noqa: E501
r"layers\.(\d+)\.shared_experts\.w1\.(\w+)": r"model.layers.\1.mlp.shared_experts.gate_proj.\2", # noqa: E501
r"layers\.(\d+)\.shared_experts\.w2\.(\w+)": r"model.layers.\1.mlp.shared_experts.down_proj.\2", # noqa: E501
r"layers\.(\d+)\.shared_experts\.w3\.(\w+)": r"model.layers.\1.mlp.shared_experts.up_proj.\2", # noqa: E501
r"layers\.(\d+)\.experts\.(\d+)\.w1\.(\w+)": r"model.layers.\1.mlp.experts.\2.gate_proj.\3", # noqa: E501
r"layers\.(\d+)\.experts\.(\d+)\.w2\.(\w+)": r"model.layers.\1.mlp.experts.\2.down_proj.\3", # noqa: E501
r"layers\.(\d+)\.experts\.(\d+)\.w3\.(\w+)": r"model.layers.\1.mlp.experts.\2.up_proj.\3", # noqa: E501
r"norm\.weight": "model.norm.weight", # noqa: E501
r"tok_embeddings\.weight": "model.embed_tokens.weight", # noqa: E501
r"output\.weight": "lm_head.weight", # noqa: E501
}
# fmt: on
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
return super().load_weights(map(self._remap_mistral_to_ds, weights))
def _remap_mistral_to_ds(
self, weight: tuple[str, torch.Tensor]
) -> tuple[str, torch.Tensor]:
"""Remap Mistral parameters to DeepseekV2 parameters."""
name, loaded_weight = weight
for k, v in self.remapping.items():
match = re.fullmatch(k, name)
if match:
name = re.sub(k, v, name)
break
else:
raise ValueError(f"Cannot remap {name}")
# Remapping scale names. We could do this in the regex above but it
# would triple the number of lines for most layers.
if name.endswith(".qscale_act"):
name = re.sub(r"\.qscale_act$", ".input_scale", name)
elif name.endswith(".qscale_weight"):
name = re.sub(r"\.qscale_weight$", ".weight_scale", name)
return name, loaded_weight

View File

@ -0,0 +1,165 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from functools import partial
import torch
import torch.nn as nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.deepseek_v2 import (
DeepseekV2DecoderLayer,
DeepseekV2Model,
)
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.mistral_large_3 import MistralLarge3ForCausalLM
from vllm.multimodal.inputs import NestedTensors
from .utils import (
_merge_multimodal_embeddings,
make_empty_intermediate_tensors_factory,
maybe_prefix,
)
logger = init_logger(__name__)
@support_torch_compile
class EagleMistralLarge3Model(DeepseekV2Model):
def __init__(
self, *, vllm_config: VllmConfig, prefix: str = "", start_layer_id: int = 0
):
nn.Module.__init__(self)
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.vllm_config = vllm_config
self.vocab_size = config.vocab_size
assert get_pp_group().world_size == 1
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens",
)
self.layers = nn.ModuleList(
[
DeepseekV2DecoderLayer(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
)
for i in range(self.config.num_hidden_layers)
]
)
self.start_layer = 0
self.end_layer = self.config.num_hidden_layers
self.fc = RowParallelLinear(
self.config.hidden_size * 2,
self.config.hidden_size,
bias=False,
input_is_parallel=False,
quant_config=quant_config,
return_bias=False,
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size
)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.embed_input_ids(input_ids)
inputs_embeds = self.fc(torch.cat((inputs_embeds, hidden_states), dim=-1))
output = super().forward(
input_ids, positions, intermediate_tensors=None, inputs_embeds=inputs_embeds
)
assert isinstance(output, torch.Tensor)
return output
class EagleMistralLarge3ForCausalLM(MistralLarge3ForCausalLM):
remapping = MistralLarge3ForCausalLM.remapping | {
r"eagle_linear\.weight": r"model.fc.weight",
r"eagle_linear\.qscale_act": r"model.fc.input_scale",
r"eagle_linear\.qscale_weight": r"model.fc.weight_scale",
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
target_layer_num = vllm_config.model_config.get_num_layers(
vllm_config.parallel_config
)
vllm_config.model_config = vllm_config.speculative_config.draft_model_config
# draft model quantization config may differ from target model
self.quant_config = VllmConfig.get_quantization_config(
vllm_config.speculative_config.draft_model_config, vllm_config.load_config
)
vllm_config.quant_config = self.quant_config
self.model_cls = partial(
EagleMistralLarge3Model, start_layer_id=target_layer_num
)
super().__init__(vllm_config=vllm_config, prefix=prefix)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = super().embed_input_ids(input_ids)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds
assert is_multimodal is not None
return _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
hidden_states = self.model(input_ids, positions, hidden_states, inputs_embeds)
return hidden_states, hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
# Pretend we've loaded the embedding and lm_head weights
# (later copied from target model)
return super().load_weights(weights) | {
"model.embed_tokens.weight",
"lm_head.weight",
}
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: NestedTensors | None = None,
is_multimodal: torch.Tensor | None = None,
) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)

View File

@ -145,6 +145,7 @@ _TEXT_GENERATION_MODELS = {
"MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
"MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
"MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
# transformers's mpt class has lower case
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
@ -424,6 +425,10 @@ _SPECULATIVE_DECODING_MODELS = {
"LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"EagleMistralLarge3ForCausalLM": (
"mistral_large_3_eagle",
"EagleMistralLarge3ForCausalLM",
),
"EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
"ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),

View File

@ -97,6 +97,8 @@ def _prepare_apply_chat_template_tools_and_messages(
continue_final_message: bool = False,
add_generation_prompt: bool = False,
) -> tuple[list["ChatCompletionMessageParam"], list[dict[str, Any]] | None]:
from mistral_common.protocol.instruct.tool_calls import Function, Tool
if add_generation_prompt and continue_final_message:
raise ValueError(
"Cannot set both `add_generation_prompt` and "
@ -139,6 +141,33 @@ def _prepare_apply_chat_template_tools_and_messages(
if function.get("description") is None:
function["description"] = ""
# We filter not supported arguments to avoid throwing an error.
# TODO(juliendenize): remove this once OpenAI API is better supported by
# `mistral-common`.
tools_fields = set(Tool.model_fields.keys())
function_fields = set(Function.model_fields.keys())
for tool in tools:
tool_keys = list(tool.keys())
for tool_key in tool_keys:
if tool_key not in tools_fields:
tool.pop(tool_key)
logger.warning_once(
f"'{tool_key}' is not supported by mistral-common for tools. "
"It has been poped from the tool definition."
)
if tool["type"] == "function":
function_keys = list(tool["function"].keys())
for function_key in function_keys:
if function_key not in function_fields:
tool["function"].pop(function_key)
logger.warning_once(
f"'{function_key}' is not supported by mistral-common "
"for function tools. It has been poped from the "
"function definition."
)
else:
raise ValueError("mistral-common only supports function tools.")
return messages, tools
@ -410,6 +439,13 @@ class MistralTokenizer(TokenizerLike):
ids, skip_special_tokens=skip_special_tokens
)
def batch_decode(
self, ids: list[list[int]] | list[int], skip_special_tokens: bool = False
) -> str:
return self.transformers_tokenizer.batch_decode(
ids, skip_special_tokens=skip_special_tokens
)
def convert_tokens_to_string(self, tokens: list[str]) -> str:
from mistral_common.tokens.tokenizers.base import (
SpecialTokenPolicy,

View File

@ -82,3 +82,9 @@ class EAGLEConfig(PretrainedConfig):
pretrained_model_name_or_path, **kwargs
)
return cls.from_dict(config_dict, **kwargs)
def to_json_string(self, use_diff: bool = True) -> str:
# we override use_diff to False as initializing
# EAGLEConfig with default arguments is not supported
del use_diff
return super().to_json_string(use_diff=False)

View File

@ -18,9 +18,31 @@ def adapt_config_dict(
if bool(config_dict.get("quantization")):
config_dict = _remap_mistral_quantization_args(config_dict)
is_moe = bool(config_dict.get("moe"))
is_mistral_large_3 = (
is_moe and (config_dict["moe"].get("num_shared_experts") or 0) > 0
)
if config_dict.get("model_type") == "mamba":
config_dict["architectures"] = ["Mamba2ForCausalLM"]
elif bool(config_dict.get("moe")):
elif is_moe and is_mistral_large_3:
config_dict = _remap_moe_args(config_dict)
config_dict["model_type"] = "deepseek_v3"
config_dict["architectures"] = ["MistralLarge3ForCausalLM"]
assert "llama_4_scaling" in config_dict, (
"MistralLarge3 expect llama4 scaling config."
)
llama_4_scaling_config_keys = ["original_max_position_embeddings", "beta"]
assert all(
[
key in config_dict["llama_4_scaling"]
for key in llama_4_scaling_config_keys
]
), (
"llama_4_scaling config should define the keys: "
f"{','.join(llama_4_scaling_config_keys)}"
)
elif is_moe:
config_dict["architectures"] = ["MixtralForCausalLM"]
else:
config_dict["architectures"] = ["MistralForCausalLM"]
@ -140,17 +162,20 @@ def _remap_general_mistral_args(config: dict) -> dict:
def _remap_mistral_quantization_args(config: dict) -> dict:
quantization = config.get("quantization", {})
if quantization.get("qformat_weight") == "fp8_e4m3":
# This maps to the FP8 static per-tensor quantization scheme
quantization_config = {"quant_method": "fp8", "activation_scheme": "static"}
elif quantization.get("quant_method") == "compressed-tensors":
# Pass through the quantization config to compressed-tensors
quantization_config = quantization
else:
raise ValueError(f"Found unknown quantization='{quantization}' in config")
config["quantization_config"] = quantization_config
if config.get("quantization"):
quantization = config.pop("quantization", {})
if quantization.get("qformat_weight") == "fp8_e4m3":
qscheme_act = quantization.get("qscheme_act")
assert qscheme_act in ("NO_SCALES", "TENSOR", None), (
"Only NO_SCALES and TENSOR (default) are supported for qscheme_act"
)
is_dynamic = qscheme_act == "NO_SCALES"
config["quantization_config"] = {
"quant_method": "fp8",
"activation_scheme": "dynamic" if is_dynamic else "static",
}
else:
raise ValueError(f"Found unknown quantization='{quantization}' in config")
return config
@ -183,3 +208,28 @@ def _remap_mistral_audio_args(config: dict) -> dict:
if quant_config:
config["quantization_config"] = quant_config
return config
def _remap_moe_args(config: dict) -> dict:
moe_config_map = {
"route_every_n": "moe_layer_freq",
"first_k_dense_replace": "first_k_dense_replace",
"num_experts_per_tok": "num_experts_per_tok",
"num_experts": "n_routed_experts",
"expert_hidden_dim": "moe_intermediate_size",
"routed_scale": "routed_scaling_factor",
"num_shared_experts": "n_shared_experts",
"num_expert_groups": "n_group",
"num_expert_groups_per_tok": "topk_group",
}
moe_config = config.get("moe", {})
for old_name, new_name in moe_config_map.items():
if old_name in moe_config:
value = moe_config.pop(old_name)
config[new_name] = value
config["topk_method"] = None
config["norm_topk_prob"] = True
config["scoring_func"] = "softmax"
return config

View File

@ -1016,6 +1016,10 @@ class EagleProposer:
"Qwen3VLForConditionalGeneration",
]:
self.model.config.image_token_index = target_model.config.image_token_id
elif self.get_model_name(target_model) == "PixtralForConditionalGeneration":
self.model.config.image_token_index = (
target_model.config.vision_config.image_token_id
)
else:
self.model.config.image_token_index = (
target_model.config.image_token_index